MAC + LINUX + VIM + GIT + Latex + Markdown

机器学习-k均值聚类算法代码笔记

2017-02-08  本文已影响248人  LEONYao

最近在学机器学习,看了好多天书发现完全看不懂。然后就跟着敲代码,居然可以理解一点了。在代码上做笔记,好像理解得更多了。我把代码笔记po上来吧
test.txt,数据集

    1.658985    4.285136  
    -3.453687   3.424321  
    4.838138    -1.151539  
    -5.379713   -3.362104  
    0.972564    2.924086  
    -3.567919   1.531611  
    0.450614    -3.302219  
    -3.487105   -1.724432  
    2.668759    1.594842  
    -3.156485   3.191137  
    3.165506    -3.999838  
    -2.786837   -3.099354  
    4.208187    2.984927  
    -2.123337   2.943366  
    0.704199    -0.479481  
    -0.392370   -3.963704  
    2.831667    1.574018  
    -0.790153   3.343144  
    2.943496    -3.357075  
    -3.195883   -2.283926  
    2.336445    2.875106  
    -1.786345   2.554248  
    2.190101    -1.906020  
    -3.403367   -2.778288  
    1.778124    3.880832  
    -1.688346   2.230267  
    2.592976    -2.054368  
    -4.007257   -3.207066  
    2.257734    3.387564  
    -2.679011   0.785119  
    0.939512    -4.023563  
    -3.674424   -2.261084  
    2.046259    2.735279  
    -3.189470   1.780269  
    4.372646    -0.822248  
    -2.579316   -3.497576  
    1.889034    5.190400  
    -0.798747   2.185588  
    2.836520    -2.658556  
    -3.837877   -3.253815  
    2.096701    3.886007  
    -2.709034   2.923887  
    3.367037    -3.184789  
    -2.121479   -4.232586  
    2.329546    3.179764  
    -3.284816   3.273099  
    3.091414    -3.815232  
    -3.762093   -2.432191  
    3.542056    2.778832  
    -1.736822   4.241041  
    2.127073    -2.983680  
    -4.323818   -3.938116  
    3.792121    5.135768  
    -4.786473   3.358547  
    2.624081    -3.260715  
    -4.009299   -2.978115  
    2.493525    1.963710  
    -2.513661   2.642162  
    1.864375    -3.176309  
    -3.171184   -3.572452  
    2.894220    2.489128  
    -2.562539   2.884438  
    3.491078    -3.947487  
    -2.565729   -2.012114  
    3.332948    3.983102  
    -1.616805   3.573188  
    2.280615    -2.559444  
    -2.651229   -3.103198  
    2.321395    3.154987  
    -1.685703   2.939697  
    3.031012    -3.620252  
    -4.599622   -2.185829  
    4.196223    1.126677  
    -2.133863   3.093686  
    4.668892    -2.562705  
    -2.793241   -2.149706  
    2.884105    3.043438  
    -2.967647   2.848696  
    4.479332    -1.764772  
    -4.905566   -2.911070  

k-means.py

from numpy import *
import math
import matplotlib.pyplot as plt

def loadDataSet(filename):
    fr = open(filename)
    lines = fr.readlines()
    dataMat = []
    for line in lines:
        result = line.strip().split('   ')
        fltline = map(float,result)
        dataMat.append(fltline)
    return dataMat

def distEclud(vecA,vecB):
    return sqrt(sum(power(vecA-vecB,2)))#欧式距离

def randCent(dataSet,k):
    n = shape(dataSet)[1]#n=列数
    
    centroids = mat(zeros((k,n)))
    for j in range(n):
        minJ = min(dataSet[:,j])#一列中最小的值
        rangeJ = float(max(dataSet[:,j])-minJ)#一列中最大减去最小,取差值
        centroids[:,j] = minJ + rangeJ*random.rand(k,1)#最小值加上(差值*(0,1))把聚类点控制在范围之内
    return centroids

def kMeans(dataSet,k,distMeas = distEclud,creatCent = randCent):
    m = shape(dataSet)[0] 
    clusterAssment = mat(zeros((m,2)))# 建立80个簇分配结果矩阵,第一列存索引,第二列存距离值
    centroids = creatCent(dataSet,k)#建立聚类点,k个聚类点
    clusterChanged = True#如果簇有改变,则为真
    while clusterChanged:#循环
        clusterChanged = False
        for i in range(m):
            minDist = inf
            minIndex = -1
            for j in range(k):
                distJI = distMeas(centroids[j,:],dataSet[i,:])#计算聚类点与数据点距离,k个centroids与80个数据样本进行比较
                #k为2时候,dataSet有80行数据。每一行dataSet都与centroids[0],centroids[1]进行比较取距离最小值,标记为一个簇
                if distJI < minDist:
                    minDist = distJI#取最小值
                    minIndex = j
            if clusterAssment[i,0] != minIndex:#对比索引,如果簇是最小值,则停止更新了
                clusterChanged = True
            clusterAssment[i,:] = minIndex,minDist**2#存最新簇
        for cent in range(k):
                ptsInClust = []
 
                ptsInClust = dataSet[nonzero(clusterAssment[:,0].A== cent)[0]] # nonzeros(a==k)返回数组a中值不为k的元素的下标
                centroids[cent,:] = mean(ptsInClust,axis=0)
                #for j in range(m):
                    #if clusterAssment[j,0]==cent: clusterAssment的第一列为index,与cent进行比较,相同则为同一簇。80行clusterAssment与cent比较
                      #ptsInClust.append(dataSet[j].tolist()[0])
                #ptsInClust = mat(ptsInClust)
    return centroids,clusterAssment



然后直接进行测试吧

由于我使用的是ipython notebook 所以下面的代码不需要import任何东西了

dataMat =mat(loadDataSet('test.txt'))
myCentroids,clustAssment = kMeans(dataMat,2)


plt.figure(2) #创建图表2

ax3 = plt.subplot() # 图表2中创建子图1
plt.title("biK-means Scatter")
plt.xlabel('x')
plt.ylabel('y')

ax3.scatter(dataMat[:,0],dataMat[:,1],color='b',marker='o',s=100)
#plt.scatter(x, y, s=area, c=colors, alpha=0.5)
ax3.scatter(myCentroids[:,0],myCentroids[:,1],color='r',marker='o',s=200,label='Cluster & K=2')


#显示label位置的函数

ax3.legend(loc='upper right')
plt.show()
index.png
思路整理一下
数据集有80行坐标,建立k个点,这里我默认两个聚类点
两个聚类点,每个都要与80行坐标进行距离比较,取最小值minDist以及索引minIndex
然后接下来的
if clusterAssment[i,0] != minIndex:
    clusterChanged
如果不是同一簇,簇变动为真。这一行有点费解,分析一下:
初始化的簇分配矩阵是80行2列,值全为0的矩阵。80行都要重新分配结果,区分簇。如果某一行索引不等于最小值索引,那就是还没有进行分配。
这个while循环系统,是在进行六次以后停止的。我前面的分析被推翻了。
为什么簇会改变呢?
终于想通了!下面的聚类点centroids是通过计算均值而变化得出的。而distJI距离是通过聚类点与80行坐标进行比较得出的。
初始的聚类点centroids是随机产生,然后它每次都会更新,产生新的minDist,minIndex。
新的minIndex会与上次循环形成的簇clusterAssment的Index比较,
不是同一个簇就会clusterChanged为真,让循环继续,直到新的minIndex与上一轮的一样,则停止循环。

80行坐标,每一行都存入两个聚类点比较之后的最小值以及索引,通过索引区分不同的簇。


上一篇下一篇

猜你喜欢

热点阅读