tf 实现 K-means

2019-08-19  本文已影响0人  cookyo
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets.samples_generator import make_blobs
from sklearn.datasets.samples_generator import make_circles

K = 4 # 类别数目
MAX_ITERS = 1000 # 最大迭代次数
N = 200 # 样本点数目

centers = [[-2, -2], [-2, 1.5], [1.5, -2], [2, 1.5]] # 簇中心

# 生成人工数据集
#data, features = make_circles(n_samples=200, shuffle=True, noise=0.1, factor=0.4)
data, features = make_blobs(n_samples=N, centers=centers, n_features = 2, cluster_std=0.8, shuffle=False, random_state=42)
print(data)
print(features)

# 计算类内平均值函数
def clusterMean(data, id, num):
    total = tf.unsorted_segment_sum(data, id, num) # 第一个参数是tensor,第二个参数是簇标签,第三个是簇数目
    count = tf.unsorted_segment_sum(tf.ones_like(data), id, num)
    return total/count

# 构建graph
points = tf.Variable(data)
cluster = tf.Variable(tf.zeros([N], dtype=tf.int64))
centers = tf.Variable(tf.slice(points.initialized_value(), [0, 0], [K, 2]))# 将原始数据前k个点当做初始中心
repCenters = tf.reshape(tf.tile(centers, [N, 1]), [N, K, 2]) # 复制操作,便于矩阵批量计算距离
repPoints = tf.reshape(tf.tile(points, [1, K]), [N, K, 2])
sumSqure = tf.reduce_sum(tf.square(repCenters-repPoints), reduction_indices=2) # 计算距离
bestCenter = tf.argmin(sumSqure, axis=1)  # 寻找最近的簇中心
change = tf.reduce_any(tf.not_equal(bestCenter, cluster)) # 检测簇中心是否还在变化
means = clusterMean(points, bestCenter, K)  # 计算簇内均值
# 将粗内均值变成新的簇中心,同时分类结果也要更新
with tf.control_dependencies([change]):
    update = tf.group(centers.assign(means),cluster.assign(bestCenter)) # 复制函数 

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    changed = True
    iterNum = 0
    while changed and iterNum < MAX_ITERS:
        iterNum += 1
        # 运行graph
        [changed, _] = sess.run([change, update])
        [centersArr, clusterArr] = sess.run([centers, cluster])
        print(clusterArr)
        print(centersArr)

        # 显示图像
        fig, ax = plt.subplots()
        ax.scatter(data.transpose()[0], data.transpose()[1], marker='o', s=100, c=clusterArr)
        plt.plot()
        plt.show()

这里需要注意的地方有:

1、unsorted_segment_sum函数是用来分割求和的,第二个参数就是分割的index,index相同的作为一个整体求和。
2、计算距离的时候使用了矩阵的批量运算,因此看起来不太直观,稍微推导一下就明白了。
3、tf.control_dependencies用来控制op运行顺序,只有检测类中心还在变化,再完成之后的更新操作。
4、tf.group是封装多个操作的函数。
5、画图函数内置在了训练过程中,因此每一轮迭代的结果都有显示,这是个很小的demo,因此迭代几轮后就可以收敛了。

上一篇下一篇

猜你喜欢

热点阅读