K 均值(K-Means)聚类实践代码

2022-06-17  本文已影响0人  万州客

不讲理论,直接干。

一、代码

import numpy as np
import matplotlib.pyplot as plt


def l2(vec_x_i, vec_x_j):
    """
    Euclidean (L2) distance between two points.

    para vec_x_i: coordinates of the first point, a numpy vector
    para vec_x_j: coordinates of the second point, same shape as vec_x_i
    return: the Euclidean distance between the two points
    """
    delta = vec_x_i - vec_x_j
    squared_total = np.sum(np.square(delta))
    return np.sqrt(squared_total)


def k_means(s, k, dist_means=l2, verbose=True):
    """
    K-means clustering: alternate nearest-center assignment and center
    re-computation until no sample changes cluster.

    para s: sample set, 2-D array-like of shape (m, n) -- m samples, n features
    para k: number of clusters, must be >= 1
    para dist_means: distance function taking two 1-D vectors and returning a
        scalar; defaults to the Euclidean distance l2
    para verbose: when True (the default, matching the previous behaviour),
        print the SSE after every assignment pass
    return cluster_cents: ndarray of shape (k, n), the cluster centers
    return sample_tag: 1-D float array of length m holding each sample's
        cluster index
    return sse: sum of squared distances of the samples to their assigned
        centers, measured against the returned cluster_cents
    raises ValueError: if s is not a non-empty 2-D array, or k < 1
    """
    s = np.asarray(s, dtype=float)
    if s.ndim != 2 or s.shape[0] == 0:
        raise ValueError("s must be a non-empty 2-D array of shape (m, n)")
    if k < 1:
        raise ValueError("k must be at least 1")
    m, n = s.shape

    # Random initial centers, uniform inside the per-feature bounding box of
    # the data. A plain ndarray is used because np.mat was removed in NumPy 2.0.
    # rand(n, k).T draws from the global RNG in the same order as one
    # rand(k, 1) column per feature, so seeded runs start from the same centers.
    mins = s.min(axis=0)
    ranges = s.max(axis=0) - mins
    cluster_cents = mins + ranges * np.random.rand(n, k).T

    # -1 is never a valid cluster index, so the first pass always registers a
    # change and the centers are recomputed at least once. With all-zero
    # initial tags, a one-pass exit would return an SSE measured against the
    # random initial centers rather than the returned ones.
    sample_tag = np.full(m, -1.0)
    while True:
        sample_tag_changed = False
        sse = 0.0
        # Assignment step: attach each sample to its nearest center
        # (strict '<' keeps the lowest index on ties).
        for i in range(m):
            min_d = np.inf
            min_index = -1
            for j in range(k):
                d = dist_means(cluster_cents[j], s[i])
                if d < min_d:
                    min_d = d
                    min_index = j
            if sample_tag[i] != min_index:
                sample_tag_changed = True
            sample_tag[i] = min_index
            sse += min_d ** 2
        if verbose:
            print(sse)

        # Converged: the centers already match this assignment, so the sse
        # just computed is consistent with the returned cluster_cents.
        if not sample_tag_changed:
            break

        # Update step: move each center to the mean of its member samples.
        for j in range(k):
            members = s[sample_tag == j]
            # An empty cluster keeps its previous center. The mean of an empty
            # selection would be NaN and would disable that cluster for good.
            if members.shape[0] > 0:
                cluster_cents[j] = members.mean(axis=0)
    return cluster_cents, sample_tag, sse


if __name__ == '__main__':
    # Cluster the sample file into 3 groups, plot it, then report the results.
    data = np.loadtxt('kmeansSamples.txt')
    centers, labels, total_sse = k_means(data, 3)
    marker_widths = np.power(labels + 0.5, 2)
    plt.scatter(data[:, 0], data[:, 1], c=labels, linewidths=marker_widths)
    plt.show()
    print(centers)
    print(total_sse)

示例数据文件 kmeansSamples.txt(每行一个样本,两列分别为二维坐标):

8.764743691132109049e+00 1.497536962729086341e+01
4.545778445909218313e+00 7.394332431706460262e+00
5.661841772908352333e+00 1.045327224311696668e+01
6.020055532521467967e+00 1.860759073162559929e+01
1.256729723000295529e+01 5.506569916803323750e+00
4.186942275051188211e+00 1.402615035721461290e+01
5.726706075832996845e+00 8.375613974148174989e+00
4.099899279500291094e+00 1.444273323355928795e+01
2.257178930021525254e+00 1.977895587652345855e+00
4.669135451288612515e+00 7.717803834787531070e-01
8.121947597697801058e+00 7.976212807755792555e-01
7.972277764807800260e-02 -1.938666197338206221e+00
8.370047062442882435e+00 1.077781799178707622e+01
6.680973199869320922e+00 1.553118858170866545e+01
5.991946943553537963e+00 1.657732863976965021e+01
5.641990155271871643e+00 1.554671013661827672e+01
-2.925147643580102041e+00 1.108844569740028163e+01
4.996949605297930752e+00 1.986732057663068707e+00
3.866584099986317025e+00 -1.752825909916766900e+00
2.626427441224858939e+00 2.208897582166075324e+01
5.656225833870900388e+00 1.477736974879376675e+01
-3.388227926726261607e-01 5.569311423852095544e+00
1.093574481611491223e+01 1.124487205516641275e+01
4.650235760178413003e+00 1.278869502885029341e+01
8.498485127403823114e+00 9.787697108749913610e+00
7.530467091751554598e+00 8.502325665434069535e+00
6.171183705302398792e+00 2.174394049079376856e+01
-9.333949569013078040e-01 1.594142490265068712e+00
-6.377004909329702542e+00 3.463894089865578341e+00
7.135980906743346175e+00 1.417794597480970609e+01

二、效果

2022-06-17 21_56_00-MessageCenterUI.png
上一篇 下一篇

猜你喜欢

热点阅读