Chapter 1. Introduction
Machine learning is about extracting knowledge from data. It is a research field at the intersection of statistics, artificial intelligence, and computer science and is also known as predictive analytics or statistical learning. The application of machine learning methods has in recent years become ubiquitous in everyday life. From automatic recommendations of which movies to watch, to what food to order or which products to buy, to personalized online radio and recognizing your friends in your photos, many modern websites and devices have machine learning algorithms at their core. When you look at a complex website like Facebook, Amazon, or Netflix, it is very likely that every part of the site contains multiple machine learning models.
Outside of commercial applications, machine learning has had a tremendous influence on the way datadriven research is done today. The tools introduced in this book have been applied to diverse scientific problems such as understanding stars, finding distant planets, discovering new particles, analyzing DNA sequences, and providing personalized cancer treatments.
Your application doesn’t need to be as largescale or worldchanging as these examples in order to benefit from machine learning, though. In this chapter, we will explain why machine learning has become so popular, and discuss what kind of problem can be solved using machine learning. Then, we will show you how to build your first machine learning model, introducing important concepts along the way.
1.1 Why Machine Learning?
In the early days of “intelligent” applications, many systems used handcoded rules of “if” and “else” decisions to process data or adjust to user input. Think of a spam filter whose job is to move the appropriate incoming email messages to a spam folder. You could make up a blacklist of words that would result in an email being marked as spam. This would be an example of using an expertdesigned rule system to design an “intelligent” application. Manually crafting decision rules is feasible for some applications, particularly those in which humans have a good understanding of the process to model. However, using handcoded rules to make decisions has two major disadvantages:

The logic required to make a decision is specific to a single domain and task. Changing the task even slightly might require a rewrite of the whole system.

Designing rules requires a deep understanding of how a decision should be made by a human expert.
One example of where this handcoded approach will fail is in detecting faces in images. Today, every smartphone can detect a face in an image. However, face detection was an unsolved problem until as recently as 2001. The main problem is that the way in which pixels (which make up an image in a computer) are “perceived” by the computer is very different from how humans perceive a face. This difference in representation makes it basically impossible for a human to come up with a good set of rules to describe what constitutes a face in a digital image.
Using machine learning, however, simply presenting a program with a large collection of images of faces is enough for an algorithm to determine what characteristics are needed to identify a face.
1.1.1 Problems Machine Learning Can Solve
The most successful kinds of machine learning algorithms are those that automate decisionmaking processes by generalizing from known examples. In this setting, which is known as supervised learning, the user provides the algorithm with pairs of inputs and desired outputs, and the algorithm finds a way to produce the desired output given an input. In particular, the algorithm is able to create an output for an input it has never seen before without any help from a human. Going back to our example of spam classification, using machine learning, the user provides the algorithm with a large number of emails (which are the input), together with information about whether any of these emails are spam (which is the desired output). Given a new email, the algorithm will then produce a prediction as to whether the new email is spam.
Machine learning algorithms that learn from input/output pairs are called supervised learning algorithms because a “teacher” provides supervision to the algorithms in the form of the desired outputs for each example that they learn from. While creating a dataset of inputs and outputs is often a laborious manual process, supervised learning algorithms are well understood and their performance is easy to measure. If your application can be formulated as a supervised learning problem, and you are able to create a dataset that includes the desired outcome, machine learning will likely be able to solve your problem.
Examples of supervised machine learning tasks include:
 Identifying the zip code from handwritten digits on an envelope

Here the input is a scan of the handwriting, and the desired output is the actual digits in the zip code. To create a dataset for building a machine learning model, you need to collect many envelopes. Then you can read the zip codes yourself and store the digits as your desired outcomes.
 Determining whether a tumor is benign based on a medical image

Here the input is the image, and the output is whether the tumor is benign. To create a dataset for building a model, you need a database of medical images. You also need an expert opinion, so a doctor needs to look at all of the images and decide which tumors are benign and which are not. It might even be necessary to do additional diagnosis beyond the content of the image to determine whether the tumor in the image is cancerous or not.
 Detecting fraudulent activity in credit card transactions

Here the input is a record of the credit card transaction, and the output is whether it is likely to be fraudulent or not. Assuming that you are the entity distributing the credit cards, collecting a dataset means storing all transactions and recording if a user reports any transaction as fraudulent.
An interesting thing to note about these examples is that although the inputs and outputs look fairly straightforward, the data collection process for these three tasks is vastly different. While reading envelopes is laborious, it is easy and cheap. Obtaining medical imaging and diagnoses, on the other hand, requires not only expensive machinery but also rare and expensive expert knowledge, not to mention the ethical concerns and privacy issues. In the example of detecting credit card fraud, data collection is much simpler. Your customers will provide you with the desired output, as they will report fraud. All you have to do to obtain the input/output pairs of fraudulent and nonfraudulent activity is wait.
Unsupervised algorithms are the other type of algorithm that we will cover in this book. In unsupervised learning, only the input data is known, and no known output data is given to the algorithm. While there are many successful applications of these methods, they are usually harder to understand and evaluate. Examples of unsupervised learning include:
 Identifying topics in a set of blog posts

If you have a large collection of text data, you might want to summarize it and find prevalent themes in it. You might not know beforehand what these topics are, or how many topics there might be. Therefore, there are no known outputs.
 Segmenting customers into groups with similar preferences

Given a set of customer records, you might want to identify which customers are similar, and whether there are groups of customers with similar preferences. For a shopping site, these might be “parents,” “bookworms,” or “gamers.” Because you don’t know in advance what these groups might be, or even how many there are, you have no known outputs.
 Detecting abnormal access patterns to a website

To identify abuse or bugs, it is often helpful to find access patterns that are different from the norm. Each abnormal pattern might be very different, and you might not have any recorded instances of abnormal behavior. Because in this example you only observe traffic, and you don’t know what constitutes normal and abnormal behavior, this is an unsupervised problem.
For both supervised and unsupervised learning tasks, it is important to have a representation of your input data that a computer can understand. Often it is helpful to think of your data as a table. Each data point that you want to reason about (each email, each customer, each transaction) is a row, and each property that describes that data point (say, the age of a customer or the amount or location of a transaction) is a column. You might describe users by their age, their gender, when they created an account, and how often they have bought from your online shop. You might describe the image of a tumor by the grayscale values of each pixel, or maybe by using the size, shape, and color of the tumor.
Each entity or row here is known as a sample (or data point) in machine learning, while the columns—the properties that describe these entities—are called features.
Later in this book we will go into more detail on the topic of building a good representation of your data, which is called feature extraction or feature engineering. You should keep in mind, however, that no machine learning algorithm will be able to make a prediction on data for which it has no information. For example, if the only feature that you have for a patient is their last name, no algorithm will be able to predict their gender. This information is simply not contained in your data. If you add another feature that contains the patient’s first name, you will have much better luck, as it is often possible to tell the gender by a person’s first name.
1.1.2 Knowing Your Task and Knowing Your Data
Quite possibly the most important part in the machine learning process is understanding the data you are working with and how it relates to the task you want to solve. It will not be effective to randomly choose an algorithm and throw your data at it. It is necessary to understand what is going on in your dataset before you begin building a model. Each algorithm is different in terms of what kind of data and what problem setting it works best for. While you are building a machine learning solution, you should answer, or at least keep in mind, the following questions:

What question(s) am I trying to answer? Do I think the data collected can answer that question?

What is the best way to phrase my question(s) as a machine learning problem?

Have I collected enough data to represent the problem I want to solve?

What features of the data did I extract, and will these enable the right predictions?

How will I measure success in my application?

How will the machine learning solution interact with other parts of my research or business product?
In a larger context, the algorithms and methods in machine learning are only one part of a greater process to solve a particular problem, and it is good to keep the big picture in mind at all times. Many people spend a lot of time building complex machine learning solutions, only to find out they don’t solve the right problem.
When going deep into the technical aspects of machine learning (as we will in this book), it is easy to lose sight of the ultimate goals. While we will not discuss the questions listed here in detail, we still encourage you to keep in mind all the assumptions that you might be making, explicitly or implicitly, when you start building machine learning models.
1.2 Why Python?
Python has become the lingua franca for many data science applications. It combines the power of generalpurpose programming languages with the ease of use of domainspecific scripting languages like MATLAB or R. Python has libraries for data loading, visualization, statistics, natural language processing, image processing, and more. This vast toolbox provides data scientists with a large array of general and specialpurpose functionality. One of the main advantages of using Python is the ability to interact directly with the code, using a terminal or other tools like the Jupyter Notebook, which we’ll look at shortly. Machine learning and data analysis are fundamentally iterative processes, in which the data drives the analysis. It is essential for these processes to have tools that allow quick iteration and easy interaction.
As a generalpurpose programming language, Python also allows for the creation of complex graphical user interfaces (GUIs) and web services, and for integration into existing systems.
1.3 scikitlearn
scikitlearn
is an open source project, meaning that it is free to use
and distribute, and anyone can easily obtain the source code to see what
is going on behind the scenes. The scikitlearn
project is constantly
being developed and improved, and it has a very active user community.
It contains a number of stateoftheart machine learning algorithms, as
well as comprehensive documentation
about each algorithm. scikitlearn
is a very
popular tool, and the most prominent Python library for machine
learning. It is widely used in industry and academia, and a wealth of
tutorials and code snippets are available online. scikitlearn
works
well with a number of other scientific Python tools, which we will
discuss later in this chapter.
While reading this, we recommend that you also browse the scikitlearn
user guide and API documentation for additional details on and many more
options for each algorithm. The online documentation is very thorough,
and this book will provide you with all the prerequisites in machine
learning to understand it in detail.
1.3.1 Installing scikitlearn
scikitlearn
depends on two other Python packages, NumPy and
SciPy. For plotting and interactive development, you should also
install matplotlib
, IPython, and the Jupyter Notebook. We recommend
using one of the following prepackaged Python distributions, which will
provide the necessary packages:
 Anaconda

A Python distribution made for largescale data processing, predictive analytics, and scientific computing. Anaconda comes with NumPy, SciPy,
matplotlib
,pandas
, IPython, Jupyter Notebook, andscikitlearn
. Available on Mac OS, Windows, and Linux, it is a very convenient solution and is the one we suggest for people without an existing installation of the scientific Python packages.  Enthought Canopy

Another Python distribution for scientific computing. This comes with NumPy, SciPy,
matplotlib
,pandas
, and IPython, but the free version does not come withscikitlearn
. If you are part of an academic, degreegranting institution, you can request an academic license and get free access to the paid subscription version of Enthought Canopy. Enthought Canopy is available for Python 2.7.x, and works on Mac OS, Windows, and Linux.  Python(x,y)

A free Python distribution for scientific computing, specifically for Windows. Python(x,y) comes with NumPy, SciPy,
matplotlib
,pandas
, IPython, andscikitlearn
.
If you already have a Python installation set up, you can use pip
to
install all of these packages:
$ pip install numpy scipy matplotlib ipython scikitlearn pandas pillow
For the tree visualizations in Chapter 2, you also need the
graphviz
packages; see the accompanying code for instructions. For
Chapter 7, you will also need the nltk
and spacy
libraries; see the instructions in that chapter.
1.4 Essential Libraries and Tools
Understanding what scikitlearn
is and how to use it is important, but
there are a few other libraries that will enhance your experience.
scikitlearn
is built on top of the NumPy and SciPy scientific Python
libraries. In addition to NumPy and SciPy, we will be using pandas
and
matplotlib
. We will also introduce the Jupyter Notebook, which is a
browserbased interactive programming environment. Briefly, here is what
you should know about these tools in order to get the most out of
scikitlearn
.^{1}
1.4.1 Jupyter Notebook
The Jupyter Notebook is an interactive environment for running code in the browser. It is a great tool for exploratory data analysis and is widely used by data scientists. While the Jupyter Notebook supports many programming languages, we only need the Python support. The Jupyter Notebook makes it easy to incorporate code, text, and images, and all of this book was in fact written as a Jupyter Notebook. All of the code examples we include can be downloaded from https://github.com/amueller/introduction_to_ml_with_python.
1.4.2 NumPy
NumPy is one of the fundamental packages for scientific computing in Python. It contains functionality for multidimensional arrays, highlevel mathematical functions such as linear algebra operations and the Fourier transform, and pseudorandom number generators.
In scikitlearn
, the NumPy array is the fundamental data structure.
scikitlearn
takes in data in the form of NumPy arrays. Any data
you’re using will have to be converted to a NumPy array. The core
functionality of NumPy is the ndarray
class, a multidimensional
(ndimensional) array. All elements of the array must be of the same
type. A NumPy array looks like this:
In[1]:
import
numpy
as
np
x
=
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
(
"x:
\n
{}"
.
format
(
x
))
Out[1]:
x: [[1 2 3] [4 5 6]]
We will be using NumPy a lot in this book, and we will refer to
objects of the NumPy ndarray
class as “NumPy arrays” or just “arrays.”
1.4.3 SciPy
SciPy is a collection of functions for scientific computing in Python.
It provides, among other functionality, advanced linear algebra
routines, mathematical function optimization, signal processing, special
mathematical functions, and statistical distributions. scikitlearn
draws from SciPy’s collection of functions for implementing its
algorithms. The most important part of SciPy for us is scipy.sparse
:
this provides sparse matrices, which are another representation that
is used for data in scikitlearn
. Sparse matrices are used whenever we
want to store a 2D array that contains mostly zeros:
In[2]:
from
scipy
import
sparse
# Create a 2D NumPy array with a diagonal of ones, and zeros everywhere else
eye
=
np
.
eye
(
4
)
(
"NumPy array:
\n
"
,
eye
)
Out[2]:
NumPy array: [[1. 0. 0. 0.] [0. 1. 0. 0.] [0. 0. 1. 0.] [0. 0. 0. 1.]]
In[3]:
# Convert the NumPy array to a SciPy sparse matrix in CSR format
# Only the nonzero entries are stored
sparse_matrix
=
sparse
.
csr_matrix
(
eye
)
(
"
\n
SciPy sparse CSR matrix:
\n
"
,
sparse_matrix
)
Out[3]:
SciPy sparse CSR matrix: (0, 0) 1.0 (1, 1) 1.0 (2, 2) 1.0 (3, 3) 1.0
Usually it is not possible to create dense representations of sparse data (as they would not fit into memory), so we need to create sparse representations directly. Here is a way to create the same sparse matrix as before, using the COO format:
In[4]:
data
=
np
.
ones
(
4
)
row_indices
=
np
.
arange
(
4
)
col_indices
=
np
.
arange
(
4
)
eye_coo
=
sparse
.
coo_matrix
((
data
,
(
row_indices
,
col_indices
)))
(
"COO representation:
\n
"
,
eye_coo
)
Out[4]:
COO representation: (0, 0) 1.0 (1, 1) 1.0 (2, 2) 1.0 (3, 3) 1.0
More details on SciPy sparse matrices can be found in the SciPy Lecture Notes.
1.4.4 matplotlib
matplotlib
is the primary scientific plotting library in Python. It
provides functions for making publicationquality visualizations such as
line charts, histograms, scatter plots, and so on. Visualizing your data
and different aspects of your analysis can give you important insights,
and we will be using matplotlib
for all our visualizations. When
working inside the Jupyter Notebook, you can show figures directly in
the browser by using the %matplotlib notebook
and %matplotlib inline
commands. We recommend using %matplotlib notebook
, which provides an
interactive environment (though we are using %matplotlib inline
to
produce this book). For example, this code produces the plot in Figure 11:
In[5]:
%
matplotlib
inline
import
matplotlib.pyplot
as
plt
# Generate a sequence of numbers from 10 to 10 with 100 steps in between
x
=
np
.
linspace
(

10
,
10
,
100
)
# Create a second array using sine
y
=
np
.
sin
(
x
)
# The plot function makes a line chart of one array against another
plt
.
plot
(
x
,
y
,
marker
=
"x"
)
1.4.5 pandas
pandas
is a Python library for data wrangling and analysis. It is
built around a data structure called the DataFrame
that is modeled
after the R DataFrame
. Simply put, a pandas
DataFrame
is a table,
similar to an Excel spreadsheet. pandas
provides a great range of
methods to modify and operate on this table; in particular, it allows
SQLlike queries and joins of tables. In contrast to NumPy, which
requires that all entries in an array be of the same type, pandas
allows each column to have a separate type (for example, integers, dates,
floatingpoint numbers, and strings). Another valuable tool provided by
pandas
is its ability to ingest from a great variety of file formats
and databases, like SQL, Excel files, and commaseparated values (CSV)
files. Going into detail about the functionality of pandas
is out of
the scope of this book.
However, Python for Data Analysis by Wes McKinney (O’Reilly, 2012) provides a great
guide. Here is a small example of creating a DataFrame
using a
dictionary:
In[6]:
import
pandas
as
pd
# create a simple dataset of people
data
=
{
'Name'
:
[
"John"
,
"Anna"
,
"Peter"
,
"Linda"
],
'Location'
:
[
"New York"
,
"Paris"
,
"Berlin"
,
"London"
],
'Age'
:
[
24
,
13
,
53
,
33
]
}
data_pandas
=
pd
.
DataFrame
(
data
)
# IPython.display allows "pretty printing" of dataframes
# in the Jupyter notebook
display
(
data_pandas
)
This produces the following output:
Age  Location  Name  

0 
24 
New York 
John 
1 
13 
Paris 
Anna 
2 
53 
Berlin 
Peter 
3 
33 
London 
Linda 
There are several possible ways to query this table. For example:
In[7]:
# Select all rows that have an age column greater than 30
display
(
data_pandas
[
data_pandas
.
Age
>
30
])
This produces the following result:
Age  Location  Name  

2 
53 
Berlin 
Peter 
3 
33 
London 
Linda 
1.4.6 mglearn
This book comes with accompanying code, which you can find on
https://github.com/amueller/introduction_to_ml_with_python. The
accompanying code includes not only all the examples shown in this book,
but also the mglearn library. This is a library of utility functions we
wrote for this book, so that we don’t clutter up our code listings with
details of plotting and data loading. If you’re interested, you can look
up all the functions in the repository, but the details of the mglearn
module are not really important to the material in this book. If you see
a call to mglearn
in the code, it is usually a way to make a pretty
picture quickly, or to get our hands on some interesting data. If you
run the notebooks published on GitHub, the mglearn
package is already
in the right place and you don’t have to worry about it. If you want to
call mglearn
functions from any other place, the easiest way to
install it is by calling pip install mglearn
.
Note
Throughout the book we make ample use of NumPy, matplotlib
and
pandas
. All the code will assume the following imports:
import numpy as np import matplotlib.pyplot as plt import pandas as pd import mglearn from IPython.display import display
We also assume that you will run the code in a Jupyter Notebook with the
%matplotlib notebook
or %matplotlib inline
magic enabled to show
plots. If you are not using the notebook or these magic commands, you
will have to call plt.show
to actually show any of the figures.
1.5 Python 2 Versus Python 3
There are two major versions of Python that are widely used at the
moment: Python 2 (more precisely, 2.7) and Python 3 (with the latest
release being 3.7 at the time of writing). This sometimes leads to some
confusion. Python 2 is no longer actively developed, but because Python 3
contains major changes, Python 2 code usually does not run on Python 3.
If you are new to Python, or are starting a new project from scratch, we
highly recommend using the latest version of Python 3. If you have a
large codebase that you rely on that is written for Python 2, you are
excused from upgrading for now. However, you should try to migrate to
Python 3 as soon as possible. When writing any new code, it is for the
most part quite easy to write code that runs under Python 2 and Python 3.
^{2} If you don’t have to interface with legacy
software, you should definitely use Python 3. All the code in this book
is written in a way that works for both versions. However, the exact
output might differ slightly under Python 2. You should also note that
many packages such as matplotlib
, numpy
, and scikitlearn
are no longer
releasing new features under Python 2.7; you need to upgrade to
Python 3.7 to get the benefit of the improvements that come with newer
versions.
1.6 Versions Used in this Book
We are using the following versions of the previously mentioned libraries in this book:
In[8]:
import
sys
(
"Python version:"
,
sys
.
version
)
import
pandas
as
pd
(
"pandas version:"
,
pd
.
__version__
)
import
matplotlib
(
"matplotlib version:"
,
matplotlib
.
__version__
)
import
numpy
as
np
(
"NumPy version:"
,
np
.
__version__
)
import
scipy
as
sp
(
"SciPy version:"
,
sp
.
__version__
)
import
IPython
(
"IPython version:"
,
IPython
.
__version__
)
import
sklearn
(
"scikitlearn version:"
,
sklearn
.
__version__
)
Out[8]:
Python version: 3.7.0 (default, Jun 28 2018, 13:15:42) [GCC 7.2.0] pandas version: 0.23.4 matplotlib version: 3.0.0 NumPy version: 1.15.2 SciPy version: 1.1.0 IPython version: 6.4.0 scikitlearn version: 0.20.0
While it is not important to match these versions exactly, you should
have a version of scikitlearn
that is as least as recent as the one
we used.
Note
When using the code in this book, you might sometimes see
DeprecationWarnings or FutureWarnings from scikitlearn
. These
inform you about behavior in scikitlearn
that will change in the
future or will be removed. While going through this book, you can safely
ignore these. If you are running a machine learning algorithm in
production, you should carefully consider each warning, as they might
inform you about functionality being removed in the future or outcomes
of predictions changing.
Now that we have everything set up, let’s dive into our first application of machine learning.
1.7 A First Application: Classifying Iris Species
In this section, we will go through a simple machine learning application and create our first model. In the process, we will introduce some core concepts and terms.
Let’s assume that a hobby botanist is interested in distinguishing the species of some iris flowers that she has found. She has collected some measurements associated with each iris: the length and width of the petals and the length and width of the sepals, all measured in centimeters (see Figure 12).
She also has the measurements of some irises that have been previously identified by an expert botanist as belonging to the species setosa, versicolor, or virginica. For these measurements, she can be certain of which species each iris belongs to. Let’s assume that these are the only species our hobby botanist will encounter in the wild.
Our goal is to build a machine learning model that can learn from the measurements of these irises whose species is known, so that we can predict the species for a new iris.
Because we have measurements for which we know the correct species of iris, this is a supervised learning problem. In this problem, we want to predict one of several options (the species of iris). This is an example of a classification problem. The possible outputs (different species of irises) are called classes. Every iris in the dataset belongs to one of three classes, so this problem is a threeclass classification problem.
The desired output for a single data point (an iris) is the species of this flower. For a particular data point, the species it belongs to is called its label.
1.7.1 Meet the Data
The data we will use for this example is the Iris dataset, a classical
dataset in machine learning and statistics. It is included in
scikitlearn
in the dataset
module. We can load it by calling the
load_iris
function:
In[9]:
from
sklearn.datasets
import
load_iris
iris_dataset
=
load_iris
()
The iris object that is returned by load_iris
is a Bunch
object,
which is very similar to a dictionary. It contains keys and values:
In[10]:
(
"Keys of iris_dataset:
\n
"
,
iris_dataset
.
keys
())
Out[10]:
Keys of iris_dataset: dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
The value of the key DESCR
is a short description of the dataset. We
show the beginning of the description here (feel free to look up the
rest yourself):
In[11]:
(
iris_dataset
[
'DESCR'
][:
193
]
+
"
\n
..."
)
Out[11]:
Iris Plants Database ==================== Notes  Data Set Characteristics: :Number of Instances: 150 (50 in each of three classes) :Number of Attributes: 4 numeric, predictive att ... 
The value of the key target_names
is an array of strings, containing
the species of flower that we want to predict:
In[12]:
(
"Target names:"
,
iris_dataset
[
'target_names'
])
Out[12]:
Target names: ['setosa' 'versicolor' 'virginica']
The value of feature_names
is a list of strings, giving the
description of each feature:
In[13]:
(
"Feature names:
\n
"
,
iris_dataset
[
'feature_names'
])
Out[13]:
Feature names: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
The data itself is contained in the target
and data
fields. data
contains the numeric measurements of sepal length, sepal width, petal
length, and petal width in a NumPy array:
In[14]:
(
"Type of data:"
,
type
(
iris_dataset
[
'data'
]))
Out[14]:
Type of data: <class 'numpy.ndarray'>
The rows in the data
array correspond to flowers, while the columns
represent the four measurements that were taken for each flower:
In[15]:
(
"Shape of data:"
,
iris_dataset
[
'data'
]
.
shape
)
Out[15]:
Shape of data: (150, 4)
We see that the array contains measurements for 150 different flowers.
Remember that the individual items are called samples in machine
learning, and their properties are called features. The shape of the
data
array is the number of samples times the number of features. This
is a convention in scikitlearn
, and your data will always be assumed
to be in this shape. Here are the feature values for the first five
samples:
In[16]:
(
"First five rows of data:
\n
"
,
iris_dataset
[
'data'
][:
5
])
Out[16]:
First five rows of data: [[5.1 3.5 1.4 0.2] [4.9 3. 1.4 0.2] [4.7 3.2 1.3 0.2] [4.6 3.1 1.5 0.2] [5. 3.6 1.4 0.2]]
From this data, we can see that all of the first five flowers have a petal width of 0.2 cm and that the first flower has the longest sepal, at 5.1 cm.
The target
array contains the species of each of the flowers that were
measured, also as a NumPy array:
In[17]:
(
"Type of target:"
,
type
(
iris_dataset
[
'target'
]))
Out[17]:
Type of target: <class 'numpy.ndarray'>
target
is a onedimensional array, with one entry per flower:
In[18]:
(
"Shape of target:"
,
iris_dataset
[
'target'
]
.
shape
)
Out[18]:
Shape of target: (150,)
The species are encoded as integers from 0 to 2:
In[19]:
(
"Target:
\n
"
,
iris_dataset
[
'target'
])
Out[19]:
Target: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
The meanings of the numbers are given by the iris['target_names']
array: 0 means setosa, 1 means versicolor, and 2 means virginica.
1.7.2 Measuring Success: Training and Testing Data
We want to build a machine learning model from this data that can predict the species of iris for a new set of measurements. But before we can apply our model to new measurements, we need to know whether it actually works—that is, whether we should trust its predictions.
Unfortunately, we cannot use the data we used to build the model to evaluate it. This is because our model can always simply remember the whole training set, and will therefore always predict the correct label for any point in the training set. This “remembering” does not indicate to us whether our model will generalize well (in other words, whether it will also perform well on new data).
To assess the model’s performance, we show it new data (data that it hasn’t seen before) for which we have labels. This is usually done by splitting the labeled data we have collected (here, our 150 flower measurements) into two parts. One part of the data is used to build our machine learning model, and is called the training data or training set. The rest of the data will be used to assess how well the model works; this is called the test data, test set, or holdout set.
scikitlearn
contains a function that shuffles the dataset and splits
it for you: the train_test_split
function. This function extracts 75%
of the rows in the data as the training set, together with the
corresponding labels for this data. The remaining 25% of the data,
together with the remaining labels, is declared as the test set.
Deciding how much data you want to put into the training and the test
set respectively is somewhat arbitrary, but using a test set containing
25% of the data is a good rule of thumb.
In scikitlearn
, data is usually denoted with a capital X
, while
labels are denoted by a lowercase y
. This is inspired by the standard
formulation f(x)=y in mathematics, where x
is the input to a function and y is the output. Following
more conventions from mathematics, we use a capital X
because the data
is a twodimensional array (a matrix) and a lowercase y
because the
target is a onedimensional array (a vector).
Let’s call train_test_split
on our data and assign the outputs using
this nomenclature:
In[20]:
from
sklearn.model_selection
import
train_test_split
X_train
,
X_test
,
y_train
,
y_test
=
train_test_split
(
iris_dataset
[
'data'
],
iris_dataset
[
'target'
],
random_state
=
0
)
Before making the split, the train_test_split
function shuffles the
dataset using a pseudorandom number generator. If we just took the last
25% of the data as a test set, all the data points would have the label
2
, as the data points are sorted by the label (see the output for
iris['target']
shown earlier). Using a test set containing only one
of the three classes would not tell us much about how well our model
generalizes, so we shuffle our data to make sure the test data contains
data from all classes.
To make sure that we will get the same output if we run the same
function several times, we provide the pseudorandom number generator
with a fixed seed using the random_state
parameter. This will make the
outcome deterministic, so this line will always have the same outcome.
We will always fix the random_state
in this way when using randomized
procedures in this book.
The output of the train_test_split
function is X_train
, X_test
,
y_train
, and y_test
, which are all NumPy arrays. X_train
contains
75% of the rows of the dataset, and X_test
contains the remaining 25%:
In[21]:
(
"X_train shape:"
,
X_train
.
shape
)
(
"y_train shape:"
,
y_train
.
shape
)
Out[21]:
X_train shape: (112, 4) y_train shape: (112,)
In[22]:
(
"X_test shape:"
,
X_test
.
shape
)
(
"y_test shape:"
,
y_test
.
shape
)
Out[22]:
X_test shape: (38, 4) y_test shape: (38,)
1.7.3 First Things First: Look at Your Data
Before building a machine learning model it is often a good idea to inspect the data, to see if the task is easily solvable without machine learning, or if the desired information might not be contained in the data.
Additionally, inspecting your data is a good way to find abnormalities and peculiarities. Maybe some of your irises were measured using inches and not centimeters, for example. In the real world, inconsistencies in the data and unexpected measurements are very common.
One of the best ways to inspect data is to visualize it. One way to do this is by using a scatter plot. A scatter plot of the data puts one feature along the xaxis and another along the yaxis, and draws a dot for each data point. Unfortunately, computer screens have only two dimensions, which allows us to plot only two (or maybe three) features at a time. It is difficult to plot datasets with more than three features this way. One way around this problem is to do a pair plot, which looks at all possible pairs of features. If you have a small number of features, such as the four we have here, this is quite reasonable. You should keep in mind, however, that a pair plot does not show the interaction of all of features at once, so some interesting aspects of the data may not be revealed when visualizing it this way.
Figure 13 is a pair plot of the features in the training set. The data points
are colored according to the species the iris belongs to. To create the
plot, we first convert the NumPy array into a pandas
DataFrame
.
pandas
has a function to create pair plots called scatter_matrix
. The
diagonal of this matrix is filled with histograms of each feature:
In[23]:
# create dataframe from data in X_train
# label the columns using the strings in iris_dataset.feature_names
iris_dataframe
=
pd
.
DataFrame
(
X_train
,
columns
=
iris_dataset
.
feature_names
)
# create a scatter matrix from the dataframe, color by y_train
pd
.
plotting
.
scatter_matrix
(
iris_dataframe
,
c
=
y_train
,
figsize
=
(
15
,
15
),
marker
=
'o'
,
hist_kwds
=
{
'bins'
:
20
},
s
=
60
,
alpha
=.
8
,
cmap
=
mglearn
.
cm3
)
From the plots, we can see that the three classes seem to be relatively well separated using the sepal and petal measurements. This means that a machine learning model will likely be able to learn to separate them.
1.7.4 Building Your First Model: kNearest Neighbors
Now we can start building the actual machine learning model. There are
many classification algorithms in scikitlearn
that we could use. Here
we will use a knearest neighbors classifier, which is easy to
understand. Building this model only consists of storing the training
set. To make a prediction for a new data point, the algorithm finds the
point in the training set that is closest to the new point. Then it
assigns the label of this training point to the new data point.
The k in knearest neighbors signifies that instead of using only
the closest neighbor to the new data point, we can consider any fixed
number k
of neighbors in the training (for example, the closest three
or five neighbors). Then, we can make a prediction using the majority
class among these neighbors. We will go into more detail about this in
Chapter 2; for now, we’ll use only a single neighbor.
All machine learning models in scikitlearn
are implemented in their
own classes, which are called Estimator
classes. The knearest
neighbors classification algorithm is implemented in the
KNeighborsClassifier
class in the neighbors
module. Before we can
use the model, we need to instantiate the class into an object. This is
when we will set any parameters of the model. The most important
parameter of KNeighborsClassifier
is the number of neighbors, which we
will set to 1:
In[24]:
from
sklearn.neighbors
import
KNeighborsClassifier
knn
=
KNeighborsClassifier
(
n_neighbors
=
1
)
The knn
object encapsulates the algorithm that will be used to build
the model from the training data, as well the algorithm to make
predictions on new data points. It will also hold the information that
the algorithm has extracted from the training data. In the case of
KNeighborsClassifier
, it will just store the training set.
To build the model on the training set, we call the fit
method of the
knn
object, which takes as arguments the NumPy array X_train
containing the training data and the NumPy array y_train
of the
corresponding training labels:
In[25]:
knn
.
fit
(
X_train
,
y_train
)
Out[25]:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=None, n_neighbors=1, p=2, weights='uniform')
The fit
method returns the knn
object itself (and modifies it in
place), so we get a string representation of our classifier. The
representation shows us which parameters were used in creating the
model. Nearly all of them are the default values, but you can also find
n_neighbors=1
, which is the parameter that we passed. Most models in
scikitlearn
have many parameters, but the majority of them are either
speed optimizations or for very special use cases. You don’t have to
worry about the other parameters shown in this representation. Printing
a scikitlearn
model can yield very long strings, but don’t be
intimidated by these. We will cover all the important parameters in
Chapter 2. In the remainder of this book, we will not usually show the output of fit
because it doesn’t contain any new information.
1.7.5 Making Predictions
We can now make predictions using this model on new data for which we might not know the correct labels. Imagine we found an iris in the wild with a sepal length of 5 cm, a sepal width of 2.9 cm, a petal length of 1 cm, and a petal width of 0.2 cm. What species of iris would this be? We can put this data into a NumPy array, again by calculating the shape—that is, the number of samples (1) multiplied by the number of features (4):
In[26]:
X_new
=
np
.
array
([[
5
,
2.9
,
1
,
0.2
]])
(
"X_new.shape:"
,
X_new
.
shape
)
Out[26]:
X_new.shape: (1, 4)
Note that we made the measurements of this single flower into a row in a
twodimensional NumPy array, as scikitlearn
always expects
twodimensional arrays for the data.
To make a prediction, we call the predict
method of the knn
object:
In[27]:
prediction
=
knn
.
predict
(
X_new
)
(
"Prediction:"
,
prediction
)
(
"Predicted target name:"
,
iris_dataset
[
'target_names'
][
prediction
])
Out[27]:
Prediction: [0] Predicted target name: ['setosa']
Our model predicts that this new iris belongs to the class 0, meaning its species is setosa. But how do we know whether we can trust our model? We don’t know the correct species of this sample, which is the whole point of building the model!
1.7.6 Evaluating the Model
This is where the test set that we created earlier comes in. This data was not used to build the model, but we do know what the correct species is for each iris in the test set.
Therefore, we can make a prediction for each iris in the test data and compare it against its label (the known species). We can measure how well the model works by computing the accuracy, which is the fraction of flowers for which the right species was predicted:
In[28]:
y_pred
=
knn
.
predict
(
X_test
)
(
"Test set predictions:
\n
"
,
y_pred
)
Out[28]:
Test set predictions: [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0 2]
In[29]:
(
"Test set score: {:.2f}"
.
format
(
np
.
mean
(
y_pred
==
y_test
)))
Out[29]:
Test set score: 0.97
We can also use the score
method of the knn
object, which will
compute the test set accuracy for us:
In[30]:
(
"Test set score: {:.2f}"
.
format
(
knn
.
score
(
X_test
,
y_test
)))
Out[30]:
Test set score: 0.97
For this model, the test set accuracy is about 0.97, which means we made the right prediction for 97% of the irises in the test set. Under some mathematical assumptions, this means that we can expect our model to be correct 97% of the time for new irises. For our hobby botanist application, this high level of accuracy means that our model may be trustworthy enough to use. In later chapters we will discuss how we can improve performance, and what caveats there are in tuning a model.
1.8 Summary and Outlook
Let’s summarize what we learned in this chapter. We started with a brief introduction to machine learning and its applications, then discussed the distinction between supervised and unsupervised learning and gave an overview of the tools we’ll be using in this book. Then, we formulated the task of predicting which species of iris a particular flower belongs to by using physical measurements of the flower. We used a dataset of measurements that was annotated by an expert with the correct species to build our model, making this a supervised learning task. There were three possible species, setosa, versicolor, or virginica, which made the task a threeclass classification problem. The possible species are called classes in the classification problem, and the species of a single iris is called its label.
The Iris dataset consists of two NumPy arrays: one containing the data,
which is referred to as X
in scikitlearn
, and one containing the
correct or desired outputs, which is called y
. The array X
is a
twodimensional array of features, with one row per data point and one
column per feature. The array y
is a onedimensional array, which here
contains one class label, an integer ranging from 0 to 2, for each of
the samples.
We split our dataset into a training set, to build our model, and a test set, to evaluate how well our model will generalize to new, previously unseen data.
We chose the knearest neighbors classification algorithm, which makes
predictions for a new data point by considering its closest neighbor(s)
in the training set. This is implemented in the KNeighborsClassifier
class, which contains the algorithm that builds the model as well as
the algorithm that makes a prediction using the model. We instantiated
the class, setting parameters. Then we built the model by calling the
fit
method, passing the training data (X_train
) and training outputs
(y_train
) as parameters. We evaluated the model using the score
method, which computes the accuracy of the model. We applied the score
method to the test set data and the test set labels and found that our
model is about 97% accurate, meaning it is correct 97% of the time on
the test set.
This gave us the confidence to apply the model to new data (in our example, new flower measurements) and trust that the model will be correct about 97% of the time.
Here is a summary of the code needed for the whole training and evaluation procedure:
In[31]:
X_train
,
X_test
,
y_train
,
y_test
=
train_test_split
(
iris_dataset
[
'data'
],
iris_dataset
[
'target'
],
random_state
=
0
)
knn
=
KNeighborsClassifier
(
n_neighbors
=
1
)
knn
.
fit
(
X_train
,
y_train
)
(
"Test set score: {:.2f}"
.
format
(
knn
.
score
(
X_test
,
y_test
)))
Out[31]:
Test set score: 0.97
This snippet contains the core code for applying any machine learning
algorithm using scikitlearn
. The fit
, predict
, and score
methods are the common interface to supervised models in scikitlearn
,
and with the concepts introduced in this chapter, you can apply these
models to many machine learning tasks. In the next chapter, we will go
into more depth about the different kinds of supervised models in
scikitlearn
and how to apply them successfully.
^{1} If you are unfamiliar with NumPy or matplotlib
, we recommend reading the first chapter of the SciPy Lecture Notes.
^{2} The six
package can be very handy for that.
Get Introduction to Machine Learning with Python now with the O’Reilly learning platform.
O’Reilly members experience live online training, plus books, videos, and digital content from nearly 200 publishers.