第 3 章 分类 分类
本作品已使用人工智能进行翻译。欢迎您提供反馈和意见:translation-feedback@oreilly.com
在第 1 章中,我提到最常见的监督学习任务是回归(预测值)和分类(预测类)。在第 2 章中,我们探讨了回归任务,即使用线性回归、决策树和随机森林等各种算法预测房屋价值(这些算法将在后面的章节中详细介绍)。现在,我们将把注意力转向分类系统。
MNIST
在本章中,我们将使用 MNIST 数据集,这是一组由高中生和美国人口普查局员工手写的 70,000 张小数字图像。每张图像都标有其代表的数字。人们对这个数据集进行了大量的研究,因此它经常被称为机器学习的 "hello world":每当人们提出一种新的分类算法时,他们都会好奇地想知道这种算法在 MNIST 上的表现如何,而且任何学习机器学习的人迟早都会用到这个数据集。
Scikit-Learn 提供了许多下载流行数据集的辅助函数。MNIST 就是其中之一。以下代码从 OpenML.org 获取 MNIST 数据集。1
fromsklearn.datasetsimportfetch_openmlmnist=fetch_openml('mnist_784',as_frame=False)
sklearn.datasets 软件包主要包含三类函数:fetch_* 函数(如fetch_openml() ),用于下载现实生活中的数据集;load_* 函数,用于加载与 Scikit-Learn 绑定的小型玩具数据集(因此无需通过互联网下载);以及make_* 函数,用于生成假数据集,对测试非常有用。生成的数据集通常以(X, y) 元组的形式返回,其中包含输入数据和目标数据,两者都是 NumPy 数组。其他数据集以sklearn.utils.Bunch 对象的形式返回,这是一个字典,其条目也可以作为属性访问。它们通常包含以下条目:
"DESCR"-
数据集说明
"data"-
输入数据,通常是二维 NumPy 数组
"target"-
标签,通常为一维 NumPy 数组
fetch_openml() 函数有点不同寻常,因为它默认以 Pandas DataFrame 的形式返回输入,以 Pandas Series 的形式返回标签(除非数据集很稀疏)。但 MNIST 数据集包含图像,而 DataFrame 并不适合图像,因此最好设置as_frame=False ,以 NumPy 数组的形式获取数据。让我们来看看这些数组:
>>>X,y=mnist.data,mnist.target>>>Xarray([[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]])>>>X.shape(70000, 784)>>>yarray(['5', '0', '4', ..., '4', '5', '6'], dtype=object)>>>y.shape(70000,)
共有 70,000 幅图像,每幅图像有 784 个特征。这是因为每幅图像都是 28 × 28 像素,每个特征只代表一个像素的强度,从 0(白色)到 ...