Code Practice: K-means Cluster Analysis

2024-08-21  万州客

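This example does not seem to build a network with PyTorch. Instead it uses the kmeans_pytorch package, and I am not sure how that package implements the algorithm.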
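As a rough answer to how K-means can run on PyTorch tensors without any network, here is a minimal sketch of the standard Lloyd iteration written with plain tensor operations. It is not the source code of kmeans_pytorch, and the package's internals may differ. The function name `kmeans_sketch`, its parameters `num_iters`, `tol` and `seed`, and the toy data are all made up for illustration.

```python
import torch

def kmeans_sketch(X, num_clusters, num_iters=100, tol=1e-4, seed=0):
    """Plain Lloyd's algorithm on a (n_samples, n_features) tensor (illustrative only)."""
    X = X.float()
    gen = torch.Generator().manual_seed(seed)
    # Pick distinct random samples as the initial centers.
    init_idx = torch.randperm(X.shape[0], generator=gen)[:num_clusters]
    centers = X[init_idx].clone()
    for _ in range(num_iters):
        # Assignment step: each sample goes to its nearest center (Euclidean distance).
        dists = torch.cdist(X, centers)  # shape: (n_samples, num_clusters)
        cluster_ids = dists.argmin(dim=1)
        # Update step: each center moves to the mean of its members.
        new_centers = centers.clone()
        for k in range(num_clusters):
            members = X[cluster_ids == k]
            if len(members) > 0:  # keep the old center if a cluster is empty
                new_centers[k] = members.mean(dim=0)
        # Stop once the centers have essentially stopped moving.
        shift = (new_centers - centers).norm(dim=1).sum()
        centers = new_centers
        if shift < tol:
            break
    return cluster_ids, centers

# Toy data (hypothetical): two well-separated 2-D blobs of 50 points each.
gen = torch.Generator().manual_seed(0)
blob_a = torch.randn(50, 2, generator=gen) + torch.tensor([5.0, 5.0])
blob_b = torch.randn(50, 2, generator=gen) - torch.tensor([5.0, 5.0])
toy_x = torch.cat([blob_a, blob_b])
toy_ids, toy_centers = kmeans_sketch(toy_x, num_clusters=2)
print(toy_ids)
print(toy_centers)
```

The loop only needs distance computation, `argmin` and `mean`, so no gradients or `nn.Module` are involved. PyTorch is used here purely as a tensor library that can also run on a GPU.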

1. Code

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from kmeans_pytorch import kmeans

# Use the first GPU if one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

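# Load the data set and keep the four measurement columns as features.
plant = pd.read_csv('./聚类分析/plant.csv')
plant_d = plant[['Sepal_Length', 'Sepal_Width','Petal_Length', 'Petal_Width']]
plant['target'] = plant['Species']
x = torch.from_numpy(np.array(plant_d))
# Integer-encode the labels so this also works when Species holds strings
# (torch.from_numpy rejects object arrays). y is not used by the clustering below.
y = torch.from_numpy(pd.factorize(plant['target'])[0])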

# Run K-means with 3 clusters; kmeans returns each sample's cluster index and the cluster centers.
num_clusters = 3
cluster_ids_x, cluster_centers = kmeans(
    X=x,
    num_clusters=num_clusters,
    distance='euclidean',
    device=device
)

print(cluster_ids_x)
print(cluster_centers)

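# Scatter the samples on the first two features (sepal length vs. sepal width), colored by cluster.
plt.figure(figsize=(4, 3), dpi=160)
plt.scatter(x[:, 0], x[:, 1], c=cluster_ids_x, cmap='cool', marker='D')
# Overlay the cluster centers as white markers with black edges.
plt.scatter(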
    cluster_centers[:, 0], cluster_centers[:, 1],
    c='white',
    alpha=0.6,
    edgecolors='black',
    linewidths=2
)

plt.tight_layout()
plt.show()
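The script loads the species labels but never compares them with the clusters. One possible check, sketched here under assumptions, is a contingency table of species against cluster index plus a simple purity score. The label names and cluster IDs below are hypothetical stand-ins for `plant['Species']` and `cluster_ids_x.tolist()`. They are not results from the actual data.

```python
import pandas as pd

# Hypothetical stand-ins for plant['Species'] and cluster_ids_x.tolist().
species = ['species_a'] * 4 + ['species_b'] * 4 + ['species_c'] * 4
cluster_ids = [1, 1, 1, 1, 0, 0, 2, 0, 2, 2, 2, 0]

# Rows are the true species, columns are the K-means cluster indices.
table = pd.crosstab(
    pd.Series(species, name='species'),
    pd.Series(cluster_ids, name='cluster'),
)
print(table)

# Purity: share of samples that belong to the majority species of their cluster.
purity = table.max(axis=0).sum() / table.to_numpy().sum()
print(f'purity = {purity:.3f}')
```

Cluster indices from K-means are arbitrary, so cluster 0 need not correspond to the first species. The table makes the correspondence visible without assuming any particular ordering.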


2. Screenshot


[Screenshot: 2024-08-21 11_05_33-ch2 – 333.py.png]