k-means算法
2020-03-05 本文已影响0人
就是果味熊
#%%
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
#%%
def assignment(df, centroids, colmap):
    """Assign every point in ``df`` to its nearest centroid.

    For each centroid id a ``distance_from_<id>`` column holding the
    Euclidean distance from each point to that centroid is written. Two more
    columns are then derived: ``closest`` (the id of the nearest centroid)
    and ``color`` (the ``colmap`` entry for that id).

    Args:
        df: DataFrame with numeric ``x`` and ``y`` columns. It is modified
            in place.
        centroids: Mapping of centroid id -> indexable ``[x, y]`` position.
        colmap: Mapping of centroid id -> color.

    Returns:
        The same ``df`` object, with the distance, ``closest`` and ``color``
        columns added or overwritten.

    Raises:
        KeyError: If a centroid id selected as closest is missing from
            ``colmap``.
    """
    # Remember which centroid id each distance column belongs to, so the id
    # never has to be parsed back out of the column label. (str.lstrip strips
    # a *set of characters*, not a prefix, and would mangle ids such as 'a'.)
    column_to_id = {}
    for centroid_id, position in centroids.items():
        column = 'distance_from_{}'.format(centroid_id)
        dx = df['x'] - position[0]
        dy = df['y'] - position[1]
        df[column] = np.sqrt(dx ** 2 + dy ** 2)
        column_to_id[column] = centroid_id

    # idxmin(axis=1) yields, for every row, the label of the column holding
    # the smallest distance.
    nearest_column = df.loc[:, list(column_to_id)].idxmin(axis=1)
    df['closest'] = nearest_column.map(column_to_id)
    df['color'] = df['closest'].map(lambda centroid_id: colmap[centroid_id])
    return df
def update(df, centroids):
    """Move each centroid to the mean of the points assigned to it.

    Args:
        df: DataFrame with ``x``, ``y`` and ``closest`` columns, where
            ``closest`` holds the id of the centroid each row belongs to.
        centroids: Mapping of centroid id -> mutable ``[x, y]`` position.
            The positions are updated in place.

    Returns:
        The same ``centroids`` mapping that was passed in.
    """
    for centroid_id, position in centroids.items():
        members = df[df['closest'] == centroid_id]
        if members.empty:
            # The mean of no points is NaN, which would destroy this
            # centroid's position; leave it where it is instead.
            continue
        position[0] = np.mean(members['x'])
        position[1] = np.mean(members['y'])
    return centroids
#%%
def _plot_clusters(df, centroids, colmap):
    """Show the points in their assigned colors with the centroids on top."""
    plt.scatter(df['x'], df['y'], color=df['color'], alpha=0.5, edgecolors='k')
    for centroid_id, position in centroids.items():
        plt.scatter(*position, color=colmap[centroid_id], linewidths=6)
    plt.xlim(0, 80)
    plt.ylim(0, 80)
    plt.show()


def main():
    """Run a small k-means demo on a fixed 2-D data set, plotting each step.

    Three centroids are drawn with ``np.random.randint(0, 80)`` for each
    coordinate. The points are assigned to them and plotted, then up to ten
    update/assign rounds are run. The loop stops early once a round leaves
    the ``closest`` column unchanged.
    """
    df = pd.DataFrame({
        'x': [12, 20, 28, 18, 10, 29, 33, 24, 45, 45, 52,
              51, 52, 55, 53, 55, 61, 64, 69, 72, 23],
        'y': [39, 36, 30, 52, 54, 20, 46, 55, 59, 63, 70,
              66, 63, 58, 23, 14, 8, 19, 7, 24, 77],
    })
    k = 3
    # Randomly choose the initial centroids.
    centroids = {
        i: [np.random.randint(0, 80), np.random.randint(0, 80)]
        for i in range(k)
    }
    colmap = {0: 'r', 1: 'g', 2: 'b'}

    df = assignment(df, centroids, colmap)
    _plot_clusters(df, centroids, colmap)

    for _ in range(10):
        plt.close()
        previous_closest = df['closest'].copy(deep=True)
        centroids = update(df, centroids)
        # Plot before re-assigning: this shows the moved centroids against
        # the colors from the previous assignment.
        _plot_clusters(df, centroids, colmap)
        df = assignment(df, centroids, colmap)
        if previous_closest.equals(df['closest']):
            # No point changed cluster, so the algorithm has converged.
            break
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()