python

利用kmeans对图像色彩进行聚类

2020-08-07  本文已影响0人  ThompsonHen
原图.png 结果.png
import cv2
import numpy as np
import matplotlib.pylab as plt
import random
def get_euclidian_distance(vec1, vec2):
    """Return the Euclidean (L2) distance between two vectors.

    Both inputs are converted to float64 before subtracting. With unsigned
    integer arrays (for example uint8 pixel values) a plain ``vec1 - vec2``
    wraps around instead of going negative, which yields a wrong distance.

    Args:
        vec1: Array-like of numbers.
        vec2: Array-like of numbers, broadcast-compatible with ``vec1``.

    Returns:
        The L2 norm of ``vec1 - vec2`` as a float.
    """
    difference = (np.asarray(vec1, dtype=np.float64)
                  - np.asarray(vec2, dtype=np.float64))
    return np.linalg.norm(difference)


def get_centroids(centroids, vector):
    """Return the 1-based position of the centroid closest to ``vector``.

    Centroids are scanned in order and a later one only wins when its
    distance is strictly smaller, so ties go to the earliest centroid.
    If no centroid has a distance below infinity (for example when
    ``centroids`` is empty) the result is 0.
    """
    best_position = -1
    best_distance = float('inf')

    for position, center in enumerate(centroids):
        candidate = get_euclidian_distance(center, vector)
        if not candidate < best_distance:
            continue
        best_position, best_distance = position, candidate

    # Callers work with 1-based labels, hence the shift.
    return best_position + 1

def k_means(vector_set, centroid_num, iteration_times=20):
    """Cluster vectors with the k-means algorithm.

    Initial centroids are drawn at random (with replacement) from
    ``vector_set``. Each iteration assigns every vector to its nearest
    centroid (Euclidean distance, ties go to the lowest-numbered centroid)
    and then moves each centroid to the mean of its members.

    Args:
        vector_set: Sequence or array of equally sized numeric vectors.
        centroid_num: Number of clusters; must be at least 1.
        iteration_times: Number of assign/update rounds to run.

    Returns:
        A tuple ``(group_tag, centroids)``. ``group_tag`` is a float array
        with one 1-based cluster label per input vector, as assigned in the
        last round (all zeros if no round ran). ``centroids`` is a float
        array of shape ``(centroid_num, dim)``.

    Raises:
        ValueError: If ``vector_set`` is empty or ``centroid_num`` < 1.
    """
    if len(vector_set) == 0:
        raise ValueError("vector_set must not be empty")
    if centroid_num < 1:
        raise ValueError("centroid_num must be at least 1")

    # Work in float64: unsigned pixel data would wrap around when subtracted.
    points = np.asarray(vector_set, dtype=np.float64)
    points = points.reshape(len(points), -1)
    sample_count = points.shape[0]

    group_tag = np.zeros(sample_count)

    # Initialise the cluster centres from randomly chosen samples.
    # randrange excludes its upper bound; randint(0, n) could return n and
    # index one past the end.
    centroids = np.array([points[random.randrange(sample_count)]
                          for _ in range(centroid_num)])

    for j in range(iteration_times):
        print("开始第{}轮聚类".format(str(j)))

        # Assignment step: distance of every point to every centroid,
        # computed one centroid at a time to keep temporaries small.
        distances = np.stack(
            [np.linalg.norm(points - center, axis=1) for center in centroids],
            axis=1,
        )
        nearest = np.argmin(distances, axis=1)  # first minimum wins on ties
        group_tag = (nearest + 1).astype(float)  # labels are 1-based

        # Update step: move each centroid to the mean of its members.
        for k in range(centroid_num):
            members = points[nearest == k]
            # An empty cluster keeps its previous centre instead of
            # becoming NaN through a 0/0 division.
            if len(members):
                centroids[k] = members.mean(axis=0)

    return group_tag, centroids




def main(image, nbr_class=3, iterations=5):
    """Cluster the pixels of ``image`` by colour and plot each cluster.

    Every pixel is treated as one vector and clustered with ``k_means``.
    For each cluster an image is built that keeps the pixels of that
    cluster and is zero elsewhere; it is saved as ``"<k>.npy"`` (k counted
    from 0) in the current working directory and shown in a subplot next to
    the original image.

    Args:
        image: Array of shape (height, width, 3); values are divided by 255
            for display.
        nbr_class: Number of colour clusters.
        iterations: Number of k-means rounds.
    """
    height, width = image.shape[0], image.shape[1]

    # One row per pixel, in row-major order.
    pixel_set = image.reshape(-1, image.shape[2])

    group_tag, _ = k_means(pixel_set, nbr_class, iterations)
    group_tag = group_tag.reshape((height, width))

    # Two columns; the grid must hold nbr_class + 1 panels. Integer
    # division is required because subplot needs integer grid sizes.
    total_row = nbr_class // 2 + 1

    plt.subplot(total_row, 2, 1)
    plt.imshow(image)
    plt.axis("off")
    plt.title("original image")

    for k in range(nbr_class):
        image_matrix = np.zeros((height, width, 3))
        members = group_tag == k + 1  # k_means labels are 1-based
        image_matrix[members] = image[members]
        np.save("{}.npy".format(str(int(k))), image_matrix)

        plt.subplot(total_row, 2, k + 2)
        plt.axis("off")
        plt.title("part{}".format(str(int(k + 1))))
        plt.imshow(image_matrix / 255)
    plt.show()




if __name__ == '__main__':

    image_path = "11.png"
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError("cannot read image: {}".format(image_path))

    # OpenCV loads channels as BGR; matplotlib expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.imshow(image)
    plt.show()

    main(image, nbr_class=3, iterations=20)


上一篇下一篇

猜你喜欢

热点阅读