K-means报告(模式识别3)
2018-01-17 本文已影响59人
小火伴
K-Means-1.png
K-Means-2.png
K-Means-3.png
K-Means-4.png
程序
# coding: utf-8
# # 第三次模式识别作业
# In[1]:
get_ipython().magic('matplotlib inline')
# In[2]:
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import numpy as np
# In[ ]:
K=3
iris = load_iris()
X = iris.data
Y = iris.target
# # 随机洗牌数据
# In[13]:
shuffle_para=np.arange(Y.shape[0])
np.random.shuffle(shuffle_para)
X,Y=X[shuffle_para],Y[shuffle_para]
# # 每次随机一样
# In[ ]:
np.random.seed(980406)
# # 分类
# In[ ]:
cla=[]
for i in range(K):
cla.append(np.where(Y==i))
# # 初始点
# In[14]:
initial_point=X[np.random.randint(0,X.shape[0],(3,))]
initial_point
# In[15]:
mean_point=initial_point
# In[16]:
print(X.shape)
# # 开始迭代
# In[17]:
accu=[]
n=0
while True:
# 计算到k个中心的欧氏距离
distances=[]
for p in mean_point:
distances.append(np.linalg.norm((X-p),axis=1))
pass
distances=np.array(distances)
y=np.argmin(distances,0)
y=np.array(y,dtype=int)
# 保存上次点
last_point=mean_point
# 生成新点
mean_point=[]
for i in range(K):
mean_point.append(np.mean(X[(y==i),:],axis=0))
mean_point=np.array(mean_point)
J=np.linalg.norm(last_point-mean_point,axis=1)
# 每一个都是<0.01
if False not in list(J<0.001):
break
pass
if(n==20):
print('到达最大迭代次数')
break
# 看把原始数据的每一类还保留多少个为一类
corr=0
for c in cla:
corr+=(max(np.bincount(y[c])))
accu.append(corr/Y.shape[0])
print(accu[-1])
n+=1
pass
# # 画图
# In[18]:
plt.ylim([0.6,1])
plt.xticks(list(range(n)), rotation=20)
plt.xlabel('Interations')
plt.ylabel('Accuracy')
plt.plot(np.arange(n),accu)
# In[19]:
mean_point.shape
# In[20]:
label=(('Sepal length','Sepal width'),('Petal length','Petal width'))
def scat(i):
plt.scatter(X[:, i*2], X[:,2*(i+1)-1], c=y,marker='+')
plt.scatter(mean_point[:,i*2],mean_point[:,(i+1)*2-1],c=np.arange(K),marker='o')
plt.xlabel(label[i][0])
plt.ylabel(label[i][1])
i=0
scat(i)
# In[21]:
scat(1)