k-means算法

2020-03-05  本文已影响0人  就是果味熊
#%%
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
#%%

def assignment(df, centroids, colmap):
    """Assign every point in ``df`` to its nearest centroid.

    For each centroid key ``i`` a ``distance_from_<i>`` column holding the
    Euclidean distance from each ``(x, y)`` row to that centroid is written
    to ``df``.  A ``closest`` column then records the key of the nearest
    centroid and a ``color`` column the matching entry of ``colmap``.

    Args:
        df: DataFrame with numeric ``x`` and ``y`` columns; modified in place.
        centroids: Mapping of centroid key to an ``[x, y]`` pair.
        colmap: Mapping of centroid key to a colour value.

    Returns:
        The same ``df`` object, with the distance, ``closest`` and ``color``
        columns added or overwritten.

    Raises:
        KeyError: If a centroid key chosen as closest is missing from
            ``colmap``.
    """
    # Remember which centroid each distance column belongs to, so the key is
    # recovered by lookup rather than by parsing the column name.  (Parsing
    # with str.lstrip would strip a *set of characters*, not a prefix, and
    # would mangle keys made of letters found in 'distance_from_'.)
    column_to_key = {'distance_from_{}'.format(key): key for key in centroids}

    for column, key in column_to_key.items():
        cx, cy = centroids[key][0], centroids[key][1]
        df[column] = np.sqrt((df['x'] - cx) ** 2 + (df['y'] - cy) ** 2)

    # idxmin(axis=1) yields, for each row, the label of the column that holds
    # the smallest distance.
    nearest_column = df.loc[:, list(column_to_key)].idxmin(axis=1)
    df['closest'] = nearest_column.map(column_to_key)
    df['color'] = df['closest'].map(lambda key: colmap[key])
    return df


def update(df, centroids):
    """Move each centroid to the mean position of its assigned points.

    Args:
        df: DataFrame with ``x``, ``y`` and ``closest`` columns, where
            ``closest`` holds the centroid key each row is assigned to.
        centroids: Mapping of centroid key to a mutable ``[x, y]`` pair;
            updated in place.

    Returns:
        The same ``centroids`` mapping, after the update.
    """
    for key, position in centroids.items():
        members = df[df['closest'] == key]
        if members.empty:
            # A cluster with no points has no mean (it would become NaN);
            # keep the previous position so the centroid is not lost.
            continue
        position[0] = np.mean(members['x'])
        position[1] = np.mean(members['y'])
    return centroids

#%%

def _plot_clusters(df, centroids, colmap):
    """Show a scatter plot of the coloured points and the current centroids."""
    plt.scatter(df['x'], df['y'], color=df['color'], alpha=0.5, edgecolors='k')
    for key, position in centroids.items():
        plt.scatter(*position, color=colmap[key], linewidths=6)
    plt.xlim(0, 80)
    plt.ylim(0, 80)
    plt.show()


def main():
    """Run a k-means demo with k=3 on a fixed 2-D data set.

    Centroids start at random integer positions in [0, 80).  The points are
    assigned and plotted once, then the update / plot / re-assign cycle is
    repeated until the assignment stops changing or 10 iterations have run.
    """
    df = pd.DataFrame({
        'x': [12, 20, 28, 18, 10, 29, 33, 24, 45, 45, 52, 51, 52, 55, 53, 55, 61, 64, 69, 72, 23],
        'y': [39, 36, 30, 52, 54, 20, 46, 55, 59, 63, 70, 66, 63, 58, 23, 14, 8, 19, 7, 24, 77],
    })

    k = 3
    # Randomly choose the initial centroids.
    centroids = {
        key: [np.random.randint(0, 80), np.random.randint(0, 80)]
        for key in range(k)
    }
    colmap = {0: 'r', 1: 'g', 2: 'b'}

    df = assignment(df, centroids, colmap)
    _plot_clusters(df, centroids, colmap)

    max_iterations = 10
    for _ in range(max_iterations):
        plt.close()

        # Snapshot the assignment so convergence can be detected afterwards.
        previous_assignment = df['closest'].copy(deep=True)
        centroids = update(df, centroids)

        # Plot the moved centroids against the assignment they were computed
        # from, before the points are re-assigned.
        _plot_clusters(df, centroids, colmap)

        df = assignment(df, centroids, colmap)

        # Stop as soon as no point changed cluster.
        if previous_assignment.equals(df['closest']):
            break
    
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
上一篇 下一篇

猜你喜欢

热点阅读