[Machine Learning] Getting Started
Prerequisites for machine learning:
- Python (preferably Python 3)
- Jupyter Notebook
- Working knowledge of NumPy/SciPy/pandas/matplotlib
- scikit-learn, the main machine learning framework used here
In addition, the mglearn module is used to help present the data. There is no need to study it in depth; just know that it helps beautify plots and visualize data.
# Before getting started, import these commonly used modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mglearn
Building a simple machine learning application
Suppose we have already collected measurements of irises: petal length and width, and sepal length and width. The flowers come from three species: setosa, versicolor, and virginica, and every set of measurements has already been matched to its species.
Now, given a new batch of iris measurements with no species assigned, can we predict each flower's species from its petal length and width and its sepal length and width?
This is a classification problem; the possible outputs are called classes.
Because the existing iris data has already been labeled, and we use the experience drawn from it to classify new data, this way of learning is called supervised learning: from known input-output pairs, we infer the likely output for new data.
Step 1: Get the data
The iris dataset is included in scikit-learn's datasets module and can be loaded by calling the load_iris function:
# Import load_iris
from sklearn.datasets import load_iris
# Load the dataset
iris_dataset = load_iris()
# Show the result
iris_dataset
{'data': array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.1, 1.5, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]]),
'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
'DESCR': 'Iris Plants Database\n====================\n\nNotes\n-----\nData Set Characteristics:\n :Number of Instances: 150 (50 in each of three classes)\n :Number of Attributes: 4 numeric, predictive attributes and the class\n :Attribute Information:\n - sepal length in cm\n - sepal width in cm\n - petal length in cm\n - petal width in cm\n - class:\n - Iris-Setosa\n - Iris-Versicolour\n - Iris-Virginica\n :Summary Statistics:\n\n ============== ==== ==== ======= ===== ====================\n Min Max Mean SD Class Correlation\n ============== ==== ==== ======= ===== ====================\n sepal length: 4.3 7.9 5.84 0.83 0.7826\n sepal width: 2.0 4.4 3.05 0.43 -0.4194\n petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n ============== ==== ==== ======= ===== ====================\n\n :Missing Attribute Values: None\n :Class Distribution: 33.3% for each of 3 classes.\n :Creator: R.A. Fisher\n :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n :Date: July, 1988\n\nThis is a copy of UCI ML iris datasets.\nhttp://archive.ics.uci.edu/ml/datasets/Iris\n\nThe famous Iris database, first used by Sir R.A Fisher\n\nThis is perhaps the best known database to be found in the\npattern recognition literature. Fisher\'s paper is a classic in the field and\nis referenced frequently to this day. (See Duda & Hart, for example.) The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant. One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\nReferences\n----------\n - Fisher,R.A. "The use of multiple measurements in taxonomic problems"\n Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n Mathematical Statistics" (John Wiley, NY, 1950).\n - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.\n (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n - Dasarathy, B.V. 
(1980) "Nosing Around the Neighborhood: A New System\n Structure and Classification Rule for Recognition in Partially Exposed\n Environments". IEEE Transactions on Pattern Analysis and Machine\n Intelligence, Vol. PAMI-2, No. 1, 67-71.\n - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions\n on Information Theory, May 1972, 431-433.\n - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II\n conceptual clustering system finds 3 classes in the data.\n - Many, many more ...\n',
'feature_names': ['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']}
load_iris returns a Bunch object, which is very similar to a dictionary and contains keys and values:
# Inspect the keys
iris_dataset.keys()
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
- data: the iris measurements
- target: the class of each sample
- target_names: the names of the classes
- DESCR: a description of the dataset
- feature_names: the list of feature names
# Inspect the data array
iris_dataset.data
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       ...,
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]])
Each row of data has four columns: sepal length, sepal width, petal length, and petal width, stored as a NumPy array.
# Check the shape of the data
iris_dataset.data.shape
(150, 4)
The data has 150 rows and 4 columns.
# Inspect the class of each sample
iris_dataset.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
Each row represents one flower and corresponds to one class; here 0, 1, and 2 stand for the three species.
Key terms:
The individual items in a dataset are called samples, their properties are called features, and the expected outputs are called labels.
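These terms can be checked directly with NumPy. The sketch below (assuming the same load_iris data loaded above) counts how many samples carry each label:

```python
import numpy as np
from sklearn.datasets import load_iris

iris_dataset = load_iris()

# target holds one label per sample; bincount tallies samples per class
counts = np.bincount(iris_dataset.target)
print(counts)  # -> [50 50 50]: 50 samples for each of the three species
```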
# Inspect the feature names
iris_dataset.feature_names
['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']
Step 2: Prepare the training data
We will not use all of the data for training; part of it is held out for testing. Testing means that after training produces a model, we feed it data it has never seen and check whether it predicts the correct classes.
scikit-learn's train_test_split function shuffles the dataset and splits it; by default, 75% of the data becomes the training set and 25% becomes the test set.
By convention, the data is written as a capital X and the labels as a lowercase y: uppercase for a two-dimensional matrix, lowercase for a one-dimensional vector.
# Import the train_test_split function
from sklearn.model_selection import train_test_split
# Split the data
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)
train_test_split is called here with three arguments: the dataset, the label set, and a random seed.
Because iris_dataset is a Bunch object, its values can be accessed either as attributes or with square brackets.
random_state seeds the pseudorandom number generator used to shuffle the dataset, so the split is reproducible.
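To confirm the default 75%/25% split, the shapes of the four arrays can be inspected; a minimal sketch, repeating the same split call as above:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)

# 150 samples split 75%/25%: 112 for training, 38 for testing
print(X_train.shape, y_train.shape)  # (112, 4) (112,)
print(X_test.shape, y_test.shape)    # (38, 4) (38,)
```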
# Inspect the training data
X_train
array([[5.9, 3. , 4.2, 1.5],
[5.8, 2.6, 4. , 1.2],
[6.8, 3. , 5.5, 2.1],
[4.7, 3.2, 1.3, 0.2],
[6.9, 3.1, 5.1, 2.3],
[5. , 3.5, 1.6, 0.6],
[5.4, 3.7, 1.5, 0.2],
[5. , 2. , 3.5, 1. ],
[6.5, 3. , 5.5, 1.8],
[6.7, 3.3, 5.7, 2.5],
[6. , 2.2, 5. , 1.5],
[6.7, 2.5, 5.8, 1.8],
[5.6, 2.5, 3.9, 1.1],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.3, 4.7, 1.6],
[5.5, 2.4, 3.8, 1.1],
[6.3, 2.7, 4.9, 1.8],
[6.3, 2.8, 5.1, 1.5],
[4.9, 2.5, 4.5, 1.7],
[6.3, 2.5, 5. , 1.9],
[7. , 3.2, 4.7, 1.4],
[6.5, 3. , 5.2, 2. ],
[6. , 3.4, 4.5, 1.6],
[4.8, 3.1, 1.6, 0.2],
[5.8, 2.7, 5.1, 1.9],
[5.6, 2.7, 4.2, 1.3],
[5.6, 2.9, 3.6, 1.3],
[5.5, 2.5, 4. , 1.3],
[6.1, 3. , 4.6, 1.4],
[7.2, 3.2, 6. , 1.8],
[5.3, 3.7, 1.5, 0.2],
[4.3, 3. , 1.1, 0.1],
[6.4, 2.7, 5.3, 1.9],
[5.7, 3. , 4.2, 1.2],
[5.4, 3.4, 1.7, 0.2],
[5.7, 4.4, 1.5, 0.4],
[6.9, 3.1, 4.9, 1.5],
[4.6, 3.1, 1.5, 0.2],
[5.9, 3. , 5.1, 1.8],
[5.1, 2.5, 3. , 1.1],
[4.6, 3.4, 1.4, 0.3],
[6.2, 2.2, 4.5, 1.5],
[7.2, 3.6, 6.1, 2.5],
[5.7, 2.9, 4.2, 1.3],
[4.8, 3. , 1.4, 0.1],
[7.1, 3. , 5.9, 2.1],
[6.9, 3.2, 5.7, 2.3],
[6.5, 3. , 5.8, 2.2],
[6.4, 2.8, 5.6, 2.1],
[5.1, 3.8, 1.6, 0.2],
[4.8, 3.4, 1.6, 0.2],
[6.5, 3.2, 5.1, 2. ],
[6.7, 3.3, 5.7, 2.1],
[4.5, 2.3, 1.3, 0.3],
[6.2, 3.4, 5.4, 2.3],
[4.9, 3. , 1.4, 0.2],
[5.7, 2.5, 5. , 2. ],
[6.9, 3.1, 5.4, 2.1],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.6, 1.4, 0.2],
[7.2, 3. , 5.8, 1.6],
[5.1, 3.5, 1.4, 0.3],
[4.4, 3. , 1.3, 0.2],
[5.4, 3.9, 1.7, 0.4],
[5.5, 2.3, 4. , 1.3],
[6.8, 3.2, 5.9, 2.3],
[7.6, 3. , 6.6, 2.1],
[5.1, 3.5, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.2, 3.4, 1.4, 0.2],
[5.7, 2.8, 4.5, 1.3],
[6.6, 3. , 4.4, 1.4],
[5. , 3.2, 1.2, 0.2],
[5.1, 3.3, 1.7, 0.5],
[6.4, 2.9, 4.3, 1.3],
[5.4, 3.4, 1.5, 0.4],
[7.7, 2.6, 6.9, 2.3],
[4.9, 2.4, 3.3, 1. ],
[7.9, 3.8, 6.4, 2. ],
[6.7, 3.1, 4.4, 1.4],
[5.2, 4.1, 1.5, 0.1],
[6. , 3. , 4.8, 1.8],
[5.8, 4. , 1.2, 0.2],
[7.7, 2.8, 6.7, 2. ],
[5.1, 3.8, 1.5, 0.3],
[4.7, 3.2, 1.6, 0.2],
[7.4, 2.8, 6.1, 1.9],
[5. , 3.3, 1.4, 0.2],
[6.3, 3.4, 5.6, 2.4],
[5.7, 2.8, 4.1, 1.3],
[5.8, 2.7, 3.9, 1.2],
[5.7, 2.6, 3.5, 1. ],
[6.4, 3.2, 5.3, 2.3],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 4.9, 1.5],
[6.7, 3. , 5. , 1.7],
[5. , 3. , 1.6, 0.2],
[5.5, 2.4, 3.7, 1. ],
[6.7, 3.1, 5.6, 2.4],
[5.8, 2.7, 5.1, 1.9],
[5.1, 3.4, 1.5, 0.2],
[6.6, 2.9, 4.6, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.9, 3.2, 4.8, 1.8],
[6.3, 2.3, 4.4, 1.3],
[5.5, 3.5, 1.3, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.9, 3.1, 1.5, 0.1],
[6.3, 2.9, 5.6, 1.8],
[5.8, 2.7, 4.1, 1. ],
[7.7, 3.8, 6.7, 2.2],
[4.6, 3.2, 1.4, 0.2]])
# Inspect the training labels
y_train
array([1, 1, 2, 0, 2, 0, 0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 1, 2,
1, 0, 2, 1, 1, 1, 1, 2, 0, 0, 2, 1, 0, 0, 1, 0, 2, 1, 0, 1, 2, 1,
0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 2, 0, 2, 2, 0, 0, 2, 0, 0, 0, 1, 2,
2, 0, 0, 0, 1, 1, 0, 0, 1, 0, 2, 1, 2, 1, 0, 2, 0, 2, 0, 0, 2, 0,
2, 1, 1, 1, 2, 2, 1, 1, 0, 1, 2, 2, 0, 1, 1, 1, 1, 0, 0, 0, 2, 1,
2, 0])
# Inspect the test data
X_test
array([[5.8, 2.8, 5.1, 2.4],
[6. , 2.2, 4. , 1. ],
[5.5, 4.2, 1.4, 0.2],
[7.3, 2.9, 6.3, 1.8],
[5. , 3.4, 1.5, 0.2],
[6.3, 3.3, 6. , 2.5],
[5. , 3.5, 1.3, 0.3],
[6.7, 3.1, 4.7, 1.5],
[6.8, 2.8, 4.8, 1.4],
[6.1, 2.8, 4. , 1.3],
[6.1, 2.6, 5.6, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.5, 2.8, 4.6, 1.5],
[6.1, 2.9, 4.7, 1.4],
[4.9, 3.1, 1.5, 0.1],
[6. , 2.9, 4.5, 1.5],
[5.5, 2.6, 4.4, 1.2],
[4.8, 3. , 1.4, 0.3],
[5.4, 3.9, 1.3, 0.4],
[5.6, 2.8, 4.9, 2. ],
[5.6, 3. , 4.5, 1.5],
[4.8, 3.4, 1.9, 0.2],
[4.4, 2.9, 1.4, 0.2],
[6.2, 2.8, 4.8, 1.8],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.8, 1.9, 0.4],
[6.2, 2.9, 4.3, 1.3],
[5. , 2.3, 3.3, 1. ],
[5. , 3.4, 1.6, 0.4],
[6.4, 3.1, 5.5, 1.8],
[5.4, 3. , 4.5, 1.5],
[5.2, 3.5, 1.5, 0.2],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.2],
[5.2, 2.7, 3.9, 1.4],
[5.7, 3.8, 1.7, 0.3],
[6. , 2.7, 5.1, 1.6]])
# Inspect the test labels
y_test
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 1])
Inspecting the data
The plot below shows the pairwise relationships between the four features. (We will not go into the details here.)
# Convert the training data into a DataFrame
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# Draw a scatter matrix with pandas.plotting.scatter_matrix
grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', hist_kwds={'bins': 20}, s=60, alpha=.8, cmap=mglearn.cm3)
Building the first model: k-nearest neighbors
Training requires an algorithm; here we choose the k-nearest neighbors classifier.
In a k-nearest neighbors classifier, k refers to the k closest neighbors in the training set: a new data point is assigned the label shared by the training points it lies closest to.
In scikit-learn, every machine learning model is implemented in its own class. The k-nearest neighbors algorithm is implemented in the KNeighborsClassifier class of the neighbors module; we need to instantiate this class into an object before we can use the model.
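The idea behind the nearest-neighbor rule can be sketched in a few lines of plain NumPy. This is only an illustration of the k=1 case, not scikit-learn's actual implementation:

```python
import numpy as np

def predict_1nn(X_train, y_train, x_new):
    # Euclidean distance from the new point to every training sample
    distances = np.sqrt(((X_train - x_new) ** 2).sum(axis=1))
    # the new point takes the label of the closest training point
    return y_train[np.argmin(distances)]

# Toy data: one point of class 0 and one of class 1
X_demo = np.array([[1.0, 1.0], [5.0, 5.0]])
y_demo = np.array([0, 1])
print(predict_1nn(X_demo, y_demo, np.array([1.2, 0.9])))  # -> 0
```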
# Import KNeighborsClassifier
from sklearn.neighbors import KNeighborsClassifier
# Instantiate the classifier
knn = KNeighborsClassifier(n_neighbors=1)
The n_neighbors parameter is k; 1 means a new point is classified according to its single nearest neighbor.
To build the model on the training set, call the knn object's fit method with X_train and y_train as arguments.
# Fit the model to the training data
knn.fit(X_train,y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=1, n_neighbors=1, p=2,
weights='uniform')
The fit method returns the knn object itself, so here we get a string representation of that object.
Step 3: Make predictions
# Suppose we have measurements for a new flower
X_new = np.array([[5,2.9,1,0.2]])
Note that the input must be a two-dimensional array.
Call knn's predict method to make a prediction:
# Call predict to classify the new sample
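scikit-learn expects a 2-D array of shape (n_samples, n_features) even when predicting a single sample. A one-dimensional array can be wrapped with reshape (a small sketch, independent of the model above):

```python
import numpy as np

x = np.array([5, 2.9, 1, 0.2])   # shape (4,): one-dimensional
X_new = x.reshape(1, -1)         # shape (1, 4): one sample, four features
print(X_new.shape)  # -> (1, 4)
```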
prediction = knn.predict(X_new)
# Inspect the returned class
prediction
array([0])
iris_dataset['target_names'][prediction]
array(['setosa'], dtype='<U10')
The predict method returns a label value, which can be mapped to the corresponding species name.
Step 4: Evaluate the model
Using the test set, we predict the species of every iris in the test data and compare the predictions against the known labels. The model's quality can be measured by its accuracy: the fraction of flowers whose species is predicted correctly.
y_pred = knn.predict(X_test)
y_pred
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 2])
Do the predicted classes match the original ones? Compare y_pred with y_test:
np.mean(y_pred==y_test)
0.9736842105263158
Alternatively, call knn's score method to compute the accuracy directly:
knn.score(X_test,y_test)
0.9736842105263158
The model predicts the correct species for about 97% of the samples in the test set.
That is the basic workflow of machine learning. O(∩_∩)
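The choice n_neighbors=1 was arbitrary; the same score method lets us compare a few values of k. A sketch, repeating the split from above (the exact scores depend on the split):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)

# Train and score a model for several values of k
for k in (1, 3, 5):
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    print(k, knn.score(X_test, y_test))
```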