Machine Learning Quick Start (English Edition), Chapter 1: Classification
What is Classification?
In this chapter, we are going to look at one of the 2 basic problems in machine learning: “classification” (the other being “regression”).
Classification means we are trying to predict a category.
Let’s make this more concrete through examples.
MNIST
This one is a machine learning classic.
In fact, for many years, this was the standard benchmark.
Whenever someone wrote a new paper, they would always report their algorithm’s performance on this dataset.
The data consists of a set of images of handwritten digits.
The possible digits are 0 through 9.
Your job, if you are building a machine learning model, is to write code that takes in as input an image, and outputs the correct digit corresponding to the image.
You may recognize that this is just a special case of handwriting recognition.
A more general version of this classifier would be able to take in an image of any character, and map it to the corresponding character (0-9, A-Z, a-z, and perhaps some symbols).
So for example, you could write down the letter “a” on a piece of paper, take a photo of it, and your code would spit out the character “a”.
An immediately useful application of this would be converting physical documents (handwritten or printed) into digital form. We call this OCR (optical character recognition).
Data Types
In machine learning, you’ll find that there are 2 very common data types.
MNIST is made up of a set of images.
Images generally fall under the field of “computer vision”.
It’s called computer vision because you’re teaching a computer to be able to intelligently process images.
The other common data format we work with in machine learning is text.
Text is everywhere, just like images.
When we’re working with text, we are doing “natural language processing”.
As a side note, these two fields have been totally transformed by deep learning.
Both computer vision and NLP algorithms have made significant strides in the recent past thanks to deep learning.
You won’t find a computer vision or NLP class today that doesn’t discuss deep learning algorithms. 10 years ago, deep learning wouldn’t even have been mentioned.
Spam Classification
Since we’ve already discussed an example of classification for images, let us now discuss an example for text.
This should be an easy one.
How about spam detection?
You may have noticed that these days, modern email providers like Gmail, Yahoo, Hotmail, etc. automatically sort your spam and non-spam email for you.
Your junk mail automatically goes to a folder called “Spam”, and your regular email just goes to your inbox.
They don’t always get it right, but they are pretty good.
How is this done?
Well it’s not magic. It’s just machine learning.
A machine learning algorithm takes in as input an email, and outputs a prediction: it will categorize that email as either spam or not spam.
Are we seeing a pattern yet?
A machine learning classification algorithm takes in an input, and outputs a categorical prediction!
Business Example
Of course, there must be some examples of classification that do not fall under images and text.
Businesses these days are taking advantage of machine learning to improve their services and make more money.
Can we think of an example of that?
You bet!
One example we can all relate to, whether you love it or hate it, is online advertising.
This usually takes the form of banner ads or commercials in videos.
Online advertising is a billion-dollar industry and the source of a lot of controversy.
Luckily, this book isn’t about the ethics of advertising.
So how can an online advertiser make use of a machine learning classifier?
Suppose the advertiser has collected some data about you.
For example, they know your age and your gender and your location.
They know that yesterday, you did a Google search for “best sunglasses for running”.
One detail that’s important to know is that the advertiser has a set of possible ads to show you (these ads are provided by their clients).
Suppose their clients consist of:
a sunglasses company
a swimwear company
an office furniture company
a makeup and beauty company
a running shoe company
Each of these 5 clients provides the advertiser with some ads, and the advertiser’s job is to show the user the best ad.
The best ad is the ad that the user is most likely to click on and buy a product from.
Thus, for every ad, we would like to be able to make a prediction about whether or not the user will click on the ad.
Once they’ve determined which ad you have the highest probability of clicking, that ad will be shown to you (the user).
Of course, out of the companies listed above, the ad belonging to the sunglasses company would probably be most appropriate for you (though based on your other attributes, maybe not!).
Both this example and the spam classification example were examples of binary classification, meaning that your prediction is always 1 of 2 categories.
(1) spam (2) not spam
(1) click (2) no click
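The ad-selection step described above can be sketched in a few lines of Python. The probabilities here are made-up numbers standing in for a real classifier’s predictions:

```python
# Hypothetical predicted click probabilities for each client's ad
# (made-up numbers; a real system would get these from a trained classifier).
click_probability = {
    "sunglasses": 0.12,
    "swimwear": 0.05,
    "office furniture": 0.01,
    "makeup and beauty": 0.03,
    "running shoes": 0.08,
}

# Show the ad the user is most likely to click on.
best_ad = max(click_probability, key=click_probability.get)
print(best_ad)  # sunglasses
```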
What’s the pattern?
Let’s try to identify the common themes in the previous 3 examples.
For each example, we have 3 things:
(1) Input data (an image, text, or a list of attributes identifying you)
(2) A machine learning classifier
(3) A prediction made by the classifier
For our examples, the predictions were:
What digit does the image represent?
Is the email spam or not spam?
Will this user click on the ad or not?
You can think of this process as a flowchart or diagram:
input > model > prediction
Making correct predictions
How does the machine learning classifier learn to make correct predictions?
This is where the “learning” in “machine learning” comes in.
Normally, in machine learning, we have a table of data from which to learn (we’ll see exactly how that works programmatically very shortly).
But not only do we have the input data, like the images or the text or the user attributes, but we also have the correct answer.
Usually, people refer to these as the labels or the targets.
In my courses we usually use these 2 words interchangeably.
The short answer for “how” a classifier learns to make correct predictions is that the model is given lots of examples of inputs and their true labels, and some algorithm learns to identify this pattern.
The long answer is that once you’re given a set of data (consisting of inputs and corresponding targets), there are actually many different kinds of learning algorithms you can apply.
If you move beyond this course, those are the algorithms you’d learn about (linear regression, logistic regression, decision trees, random forests, neural networks, etc.).
Of course, that’s where the math comes in, and in this book I’ve promised you no math.
What does data “look like”?
At this point, we’ve looked at some examples (image classification, text classification, click classification), but what does the data actually look like when it sits on your computer?
We know that data takes the form of “inputs” and “targets”, but how are these represented?
What a great time for another example!
This one will be simple so that it can actually fit on the page.
Suppose I’d like to predict whether or not the students in my class will pass or fail my exam (another binary classification problem).
I thus take some measurements from past students to help me make this prediction: the number of hours they studied and the number of hours they spent playing video games during the study period.
I can make a table containing all these measurements. Each row represents the measurements for a different student.
Thus, you can imagine your data being stored as a CSV (comma-separated values) spreadsheet.
In memory, your input data (the first 2 columns of the table) would be represented by a 2-D array. Your targets (the last column of the table) would be represented by a 1-D array of the same length.
The idea behind machine learning is that I can learn to identify patterns from this data.
So for next year’s class, I can ask a student, “How many hours did you study and how long have you spent playing video games?” and using that data I can predict whether or not that student will pass my exam.
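To make this concrete, here is a minimal sketch of that table in code. The numbers are invented for illustration; the point is the shapes of the arrays:

```python
import io
import numpy as np

# A tiny, made-up CSV of past students:
# hours studied, hours playing video games, passed (1) or failed (0)
csv_text = """hours_studied,hours_gaming,passed
12,3,1
2,20,0
8,5,1
1,15,0
"""

data = np.genfromtxt(io.StringIO(csv_text), delimiter=",", skip_header=1)

X = data[:, :2]  # the first 2 columns -> inputs, a 2-D array
Y = data[:, 2]   # the last column  -> targets, a 1-D array of the same length

print(X.shape)  # (4, 2): 4 students, 2 measurements each
print(Y.shape)  # (4,)
```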
The black box perspective
What we need in our case is a useful abstraction.
Let’s treat the machine learning algorithm like a black box.
We can assume this black box is capable of doing 2 things.
1 - It can learn a pattern based on training data. Remember, this training data consists of 2 parts: the input, and the corresponding targets.
People usually just call these X and Y, respectively.
2 - After it learns the pattern from the training data, it can make predictions on new data.
For example, before today, Gmail’s spam classifier had probably never seen the emails I received today.
So how can Gmail’s spam classifier take an email it has never seen and predict whether it’s spam or not spam?
The reason is due to #1 - it has learned to correctly identify the patterns between spam and non-spam.
How does it work in code?
Now that we have a basic idea of how classification works, we can start looking at code.
Luckily, the scikit-learn API has exactly the 2 functions we need to perform the 2 tasks we just talked about.
Let’s first consider our data. Remember, we usually call these X and Y.
We discussed earlier that X will be a 2-D array and Y will be a 1-D array of the same length.
Our data X, generally speaking, can be thought of as an NxD matrix.
N stands for number of samples, and D stands for number of dimensions.
Notice that, if X has N rows, then Y also must have N rows. This is because for each input sample we have, we must have a corresponding target.
Assume for now we have some data loading function that returns the 2 arrays we need:
X, Y = load_data()
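If you’d like something you can actually run, one possible stand-in for load_data is to wrap one of the small example datasets that ship with scikit-learn. This particular implementation is my assumption for illustration, not something the book prescribes:

```python
from sklearn.datasets import load_iris

def load_data():
    # Return the inputs X (an NxD 2-D array) and targets Y (a 1-D array).
    dataset = load_iris()
    return dataset.data, dataset.target

X, Y = load_data()
print(X.shape)  # (150, 4): N=150 samples, D=4 dimensions
print(Y.shape)  # (150,): one target per sample
```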
Now that we have our data, we need a classifier. Let’s choose scikit-learn’s RandomForestClassifier, which we first have to import:
from sklearn.ensemble import RandomForestClassifier
Since the RandomForestClassifier is an object, we’ll need to instantiate it:
model = RandomForestClassifier()
I like this choice because it has “classifier” in its name, which makes it obvious it is a classifier.
It also happens to be quite a powerful classifier, which is nice.
Side note: We’ll talk about different classifiers and some of their pros and cons later in this book.
Remember, the 2 things a classifier has to do: #1: learn, and #2: make predictions.
Learning is very simple. We just call the instance method “fit”, and we pass in X and Y.
model.fit(X, Y)
Making predictions is similarly very simple. We call the instance method “predict”, and we pass in some data X.
predictions = model.predict(X)
Notice how predict does not have to take in any Y argument.
Why might that be?
When we’re making predictions, we simply don’t need to be told the right answer. If we knew the right answer already, there would be no need to make any predictions!
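Putting the pieces together, here is a minimal end-to-end sketch. The data is made up (it mimics the earlier pass/fail example: hours studied vs. hours gaming), and random_state is fixed only to make the run repeatable:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Made-up training data: [hours studied, hours playing video games]
X = np.array([[12, 3], [2, 20], [8, 5], [1, 15], [10, 2], [3, 18]])
Y = np.array([1, 0, 1, 0, 1, 0])  # 1 = passed, 0 = failed

model = RandomForestClassifier(random_state=0)

# Step 1: learn the pattern from the training data
model.fit(X, Y)

# Step 2: make predictions on new, unseen inputs (no Y needed!)
X_new = np.array([[11, 4], [2, 19]])
predictions = model.predict(X_new)
print(predictions)  # [1 0]: the first student passes, the second fails
```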
Model Evaluation
One thing we’re also going to want to do is evaluate our model.
In other words, after we’ve trained our model we want to be able to ask: “How good is this model?”
With classification, we’re interested in the classification rate, also known as the accuracy.
These 2 terms will also be used interchangeably.
It’s just the number of predictions you get correct, divided by the total number of predictions you made.
accuracy = # correct / # total
In scikit-learn, this is just another instance method, called “score”. Of course, you can also try implementing the score function yourself, as an exercise. You should be able to accomplish this using just the model’s predict function.
accuracy = model.score(X, Y)
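As a sketch of that exercise, here is score compared against a hand-rolled accuracy on some made-up pass/fail data. By definition the two numbers must agree, since score simply compares predict’s output to Y:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Made-up pass/fail data: [hours studied, hours playing video games]
X = np.array([[12, 3], [2, 20], [8, 5], [1, 15]])
Y = np.array([1, 0, 1, 0])

model = RandomForestClassifier(random_state=0)
model.fit(X, Y)

# scikit-learn's built-in accuracy:
builtin = model.score(X, Y)

# The same quantity by hand, using only predict:
predictions = model.predict(X)
manual = np.mean(predictions == Y)  # correct / total

print(builtin == manual)  # True
```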