
88
|
第
3
章
图 3-1:MNIST 数据集中的数字
同样,我们先将训练集数据混洗,这样能保证交叉验证时所有的折叠都差不多(你肯定
不希望某个折叠丢失一些数字)。此外,有些机器学习算法对训练实例的顺序敏感,如
果连续输入许多相似的实例,可能导致执行性能不佳。给数据集混洗正是为了确保这种
情况不会发生
注 2
。
1
3.2 训练二元分类器
现在先简化问题,只尝试识别一个数字,比如数字 5。那么这个“数字 5 检测器”就
是一个二元分类器的示例,它只能区分两个类别:5 和非 5。先为此分类任务创建目标
向量:
y_train_5 = (y_train == 5)
# True for all 5s, False for all other digits
y_test_5 = (y_test == 5)
接着挑选一个分类器并开始训练。一个好的初始选择是随机梯度下降(SGD)分类器,
使用 Scikit-Learn 的 SGDClassifier 类即可。这个分类器的优势是能够有效处理非常
大型的数据集。这部分是因为 SGD 独立处理训练实例,一次一个(这也使得 SGD 非常
适合在线学习),稍后我们将会看到。此时先创建一个 SGDClassifier 并在整个训练
集上进行训练:
注 2 :在某些情况下,例如,如果你正在处理时间序列数据(例如股市价格或天气状况),则混洗可能不
是一个好主意。我们将在下一章中对此进行探讨。