人工智能/模式识别/机器学习精华专题机器学习与数据挖掘Python语言与信息数据获取和机器学习

统计学习方法之K近邻法

2017-12-08  本文已影响110人  J_101

1.k近邻法(k-nearest neighbor,k-NN)

2.k近邻模型

2.1距离度量

2.2 k值的选择

近似误差、估计误差知乎解释

2.3分类决策规则

3.k近邻算法的实现

3.1简单实现

hei,wei,tag
1.5,40,thin
1.5,50,fat
1.5,60,fat
1.6,40,thin
1.6,50,thin
1.6,60,fat
1.6,70,fat
1.7,50,thin
1.7,60,thin
1.7,70,fat
1.7,80,fat
1.8,60,thin
1.8,70,thin
1.8,80,fat
1.8,90,fat
1.9,80,thin
1.9,90,fat
# -*- coding: utf-8 -*-
"""
Created on Fri Dec  8 17:21:14 2017

@author: jasonhaven
"""

import os
import numpy as np
import pandas as pd
import operator
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

def read_from_csv(file):
    '''
    file:文件绝对地址
    功能:读入csv文件并解析出数据集和标签集
    '''
    pwd=os.getcwd()
    os.chdir(os.path.dirname(file))
    df=pd.read_csv('data.csv')
    os.chdir(pwd)
    datas=df.iloc[:,:2].astype(np.float).values
    labels = df.iloc[0:,-1:].astype(np.str).values#加载类别标签部分
    return datas,labels


def classify(instance,datas,labels,k):
    '''
    instance:新的实例特征向量
    datas:训练数据集
    labels:标签集
    k:选择相邻的k个实例
    '''
    num=datas.shape[0]
    #tile(A, reps)返回一个shape=reps的矩阵,矩阵的每个元素是A
    diffMat = np.tile(instance, (num, 1)) - datas
    #diffMat就是输入样本与每个训练样本的差值
    square_diffMat = diffMat**2
    #然后对其每个x和y的差值进行平方运算。
    square_distances=square_diffMat.sum(axis=1)
    #开平方根求出距离
    distances=square_distances**0.5
    #argsort函数返回的是关键字(数组值)从小到大的索引值
    sorted_distances = distances.argsort()
    
    class_count = {}
    # 投票过程,就是统计前k个最近的样本所属类别包含的样本个数
    for i in range(k):
        # sortedDistIndicies[i]是第i个最相近的样本下标
        voteIlabel = str(labels[sorted_distances[i]])
        # 然后将票数增1
        class_count[voteIlabel] = class_count.get(voteIlabel, 0) + 1
    # 把分类结果进行排序,然后返回得票数最多的分类结果
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]
    

def draw(datas,labels):
    plt.figure('KNN')
    plt.title('KNN')
    plt.xlabel('height')
    plt.ylabel('weight')
    
    green_patch=mpatches.Patch(color='green', label='thin')
    red_patch=mpatches.Patch(color='red', label='fat')
    handles=[red_patch,green_patch]
    plt.legend(handles=handles)
    for i,x in enumerate(datas):
        if labels[i]=='thin':
            plt.scatter(x[0],x[1],s=100,marker='.',c='g')
        else:
            plt.scatter(x[0],x[1],s=100,marker='x',c='r')
    plt.show()


if __name__=='__main__':
    #获取数据集
    file='./data.csv'#data.csv : 身高,体重,标签
    datas,labels=read_from_csv(file)
    labels=list(labels)
    #新实例
    instance=[1.7,60]
    k=2
    #分类
    label=classify(instance,datas,labels,k)
    draw(datas,labels)
    print("knn classify : %s's label is %s"%(str(instance),label))
    

3.2运行结果


作者:Jasonhaven.D
链接:http://www.jianshu.com/u/ed031e432b82
來源:简书
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

上一篇下一篇

猜你喜欢

热点阅读