01 手工就近原则实现一个简单的鸢尾花分类器

2018-08-16  本文已影响45人  夏威夷的芒果
基础知识 问题描述
任务描述 数据集 思路

人工智能数据源下载地址,下载压缩包后解压即可.
小脚本下载地址

原理

数据集分开,一部分用来训练集合,一部分作为测试集合,测试集合里面每一条用来与训练集合中的元素比对,就近标类。

代码

import pandas as pd
import ai_utils
from sklearn.model_selection import train_test_split
from scipy.spatial.distance import euclidean
import numpy as np

#读取文件
data_file = '/Users/miraco/PycharmProjects/ai/data_ai/Iris.csv'

#种类
species = ['Iris-setosa',
           'Iris-versicolor',
           'Iris-virginica'
           ]
#特征
feat_cols = ['SepalLengthCm','SepalWidthCm','PetalLengthCm','PetalWidthCm']

def get_pred_label(test_sample_feat, train_data):
    #近朱者赤,,找最近距离的样本,取其标签作为预测样本的标签
    dis_list = []
    for idx, row in train_data.iterrows():
        #训练样本特征
        train_sample_feat = row[feat_cols].values
        #计算当前条目和样本集合之间的距离
        dis = euclidean(test_sample_feat, train_sample_feat)
        dis_list.append(dis)

    #最小距离对应的位置
    pos = np.argmin(dis_list)
    #离谁最近就算成谁
    pred_label = train_data.iloc[pos]['Species']
    return pred_label


#读取数据

iris_data = pd.read_csv(data_file, index_col = 'Id')

#eda

ai_utils.do_eda_plot_for_iris(iris_data)

# 划分数据集
#三分之一作为训练集
train_data, test_data = train_test_split(iris_data, test_size= 1/3 , random_state= 10)

# 预测对的个数
acc_count = 0

# 分类器

for idx,row in test_data.iterrows():
    # 测试样本特征
    test_sample_feat = row[feat_cols].values

    # 预测值
    pred_label = get_pred_label(test_sample_feat, train_data)

    # 真实值
    true_label = row['Species']
    print(f'样本{idx}的真实标签是{true_label},预测标签是{pred_label}')
    if true_label == pred_label:
        acc_count += 1


# 准确率
accuracy  = acc_count / test_data.shape[0]
print('预测准确率{:2f}%'.format(accuracy*100))

运行结果

样本88的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本112的真实标签是Iris-virginica,预测标签是Iris-virginica
样本11的真实标签是Iris-setosa,预测标签是Iris-setosa
样本92的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本50的真实标签是Iris-setosa,预测标签是Iris-setosa
样本61的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本73的真实标签是Iris-versicolor,预测标签是Iris-virginica
样本68的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本40的真实标签是Iris-setosa,预测标签是Iris-setosa
样本56的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本67的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本143的真实标签是Iris-virginica,预测标签是Iris-virginica
样本54的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本2的真实标签是Iris-setosa,预测标签是Iris-setosa
样本20的真实标签是Iris-setosa,预测标签是Iris-setosa
样本113的真实标签是Iris-virginica,预测标签是Iris-virginica
样本86的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本39的真实标签是Iris-setosa,预测标签是Iris-setosa
样本22的真实标签是Iris-setosa,预测标签是Iris-setosa
样本36的真实标签是Iris-setosa,预测标签是Iris-setosa
样本103的真实标签是Iris-virginica,预测标签是Iris-virginica
样本133的真实标签是Iris-virginica,预测标签是Iris-virginica
样本127的真实标签是Iris-virginica,预测标签是Iris-virginica
样本25的真实标签是Iris-setosa,预测标签是Iris-setosa
样本62的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本3的真实标签是Iris-setosa,预测标签是Iris-setosa
样本96的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本91的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本77的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本118的真实标签是Iris-virginica,预测标签是Iris-virginica
样本59的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本98的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本130的真实标签是Iris-virginica,预测标签是Iris-virginica
样本115的真实标签是Iris-virginica,预测标签是Iris-virginica
样本147的真实标签是Iris-virginica,预测标签是Iris-virginica
样本48的真实标签是Iris-setosa,预测标签是Iris-setosa
样本125的真实标签是Iris-virginica,预测标签是Iris-virginica
样本121的真实标签是Iris-virginica,预测标签是Iris-virginica
样本119的真实标签是Iris-virginica,预测标签是Iris-virginica
样本142的真实标签是Iris-virginica,预测标签是Iris-virginica
样本27的真实标签是Iris-setosa,预测标签是Iris-setosa
样本44的真实标签是Iris-setosa,预测标签是Iris-setosa
样本60的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本42的真实标签是Iris-setosa,预测标签是Iris-setosa
样本57的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本33的真实标签是Iris-setosa,预测标签是Iris-setosa
样本53的真实标签是Iris-versicolor,预测标签是Iris-versicolor
样本71的真实标签是Iris-versicolor,预测标签是Iris-virginica
样本122的真实标签是Iris-virginica,预测标签是Iris-virginica
样本145的真实标签是Iris-virginica,预测标签是Iris-virginica
预测准确率96.000000%
运行的图

复习需要注意的地方:

from sklearn.model_selection import train_test_split
train_test_split(train_data,train_target,test_size=0.3, random_state=0)

参数解释:
train_data:被划分的样本特征集
train_target:被划分的样本标签
test_size:如果是浮点数,在0-1之间,表示样本占比;如果是整数,就是样本的数量
random_state:是随机数的种子。随机数种子其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:

  1. 种子不同,产生不同的随机数;
  2. 种子相同,即使实例不同也产生相同的随机数。
for idx,row in test_data.iterrows():
    # 测试样本特征
    test_sample_feat = row[feat_cols].values
from scipy.spatial.distance import euclidean
dis = euclidean(test_sample_feat, train_sample_feat) 

练习:手工实现一个简单的水果识别器

参考答案

import pandas as pd
from sklearn.model_selection import train_test_split
from scipy.spatial.distance import euclidean
import numpy as np


#特征文字

feat_cols =['mass','width','height','color_score']

#读取数据

data = pd.read_csv('/Users/miraco/PycharmProjects/ai/data_ai/fruit_data.csv')

#划分数据

train_set, test_set = train_test_split(data, random_state = 10, test_size= 0.4)

#计算结果

acc_count = 0  # 预测对的个数

for idx, row in test_set.iterrows():
    #提取每一行的各特征的值
    test_sample_feat = row[feat_cols].values  #多维的一定写value

    #预测值

    pos = np.argmin([euclidean(test_sample_feat,train_row[feat_cols].values) for idx2, train_row in train_set.iterrows()])
    pred_label = train_set.iloc[pos]['fruit_name']

    #实际值
    real_label = row['fruit_name']

    print(f'样本{idx}的真实标签是{real_label},预测标签是{pred_label}')

    if real_label == pred_label:
        acc_count += 1

# 准确率
accuracy  = acc_count / test_set.shape[0]
print('预测准确率{:2f}%'.format(accuracy*100))

运行结果

/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
  return f(*args, **kwds)
样本31的真实标签是orange,预测标签是lemon
样本3的真实标签是mandarin,预测标签是mandarin
样本38的真实标签是orange,预测标签是orange
样本27的真实标签是orange,预测标签是lemon
样本21的真实标签是apple,预测标签是apple
样本17的真实标签是apple,预测标签是apple
样本46的真实标签是lemon,预测标签是lemon
样本2的真实标签是apple,预测标签是apple
样本23的真实标签是apple,预测标签是apple
样本26的真实标签是orange,预测标签是orange
样本35的真实标签是orange,预测标签是apple
样本39的真实标签是orange,预测标签是orange
样本20的真实标签是apple,预测标签是orange
样本37的真实标签是orange,预测标签是orange
样本7的真实标签是mandarin,预测标签是mandarin
样本6的真实标签是mandarin,预测标签是mandarin
样本45的真实标签是lemon,预测标签是orange
样本56的真实标签是lemon,预测标签是lemon
样本47的真实标签是lemon,预测标签是lemon
样本10的真实标签是apple,预测标签是orange
样本44的真实标签是lemon,预测标签是lemon
样本54的真实标签是lemon,预测标签是lemon
样本18的真实标签是apple,预测标签是apple
样本4的真实标签是mandarin,预测标签是mandarin
预测准确率75.000000%

Process finished with exit code 0

这个警告的原因是是各种库之间的版本不匹配,只需要把numpy的版本降到1.14.5就可以了。

sudo pip uninstall numpy
sudo pip install numpy==1.14.5

我懒得理他,就这样吧。

上一篇下一篇

猜你喜欢

热点阅读