few-shot miniImageNet 实现 t-SNE 可视化

2020-07-25  本文已影响0人  vieo

参考

# main
# Visualize one few-shot episode: split off the support set, embed it, average
# the embeddings into one prototype per class, then plot both with tsne().
# NOTE(review): batch, p_support, way, shot, model and torch are not defined in
# this snippet; presumably they come from the surrounding episode loop.
# NOTE(review): tsne() is defined further down in this listing; run top to
# bottom as one script, the call at the end would raise NameError, so define
# tsne() before this code.
# Move each batch tensor to the GPU. The unpacking requires the batch to hold
# exactly two tensors; the second (presumably the labels) is discarded.
data, _ = [_.cuda() for _ in batch]
# The first p_support rows form the support set, the rest the query set.
# Shape per the original author: [150,3,84,84]; presumably p_support == shot * way.
data_support, data_query = data[:p_support], data[p_support:]
# Episode-local labels [0, 1, ..., way-1] repeated shot times,
# i.e. [0,1,...,way-1, 0,1,...,way-1, ...] (not dataset class ids).
labels_support = torch.arange(way).repeat(shot)

emb_support = model(data_support)  # e.g. [150,1600]; feature size depends on the model
# View the embeddings as (shot, way, D) and average over the shot axis, giving
# one prototype per class with shape (way, D). Row i is treated as class
# i % way, which matches labels_support; this assumes the data loader orders
# the support set that way (TODO confirm).
proto_support = emb_support.reshape(shot, way, -1).mean(0)
labels_proto = torch.arange(way)  # one label per prototype: [0, 1, ..., way-1]

# Plot supports and prototypes in one shared 2-D t-SNE embedding.
tsne(emb_support, proto_support, labels_support, labels_proto)
# tsne
def tsne(training_feature, proto_feature, train_label, proto_label,
         save_path='./home/...', title='title', perplexity=50):
    """Plot support embeddings and class prototypes in one shared 2-D t-SNE map.

    Both feature sets are embedded by a single t-SNE fit, so their positions
    are directly comparable. Support samples are drawn as '+' markers and
    prototypes as thick 'o' markers, coloured by class label. The figure is
    saved as ``<save_path>/<n>.png``, where ``<n>`` is the smallest
    non-negative integer whose file does not exist yet, so earlier plots are
    never overwritten.

    :param training_feature: torch tensor of support embeddings, one row per
        sample (e.g. ``[shot * way, feature_dim]``).
    :param proto_feature: torch tensor of prototype embeddings, one row per
        class, with the same feature dimension as ``training_feature``.
    :param train_label: 1-D torch tensor of class labels for the rows of
        ``training_feature``.
    :param proto_label: 1-D torch tensor of class labels for the rows of
        ``proto_feature``.
    :param save_path: output directory, created if missing. The default is
        the placeholder from the original snippet; pass a real directory.
    :param title: figure title.
    :param perplexity: t-SNE perplexity. Clamped to ``n_samples - 1`` because
        scikit-learn rejects ``perplexity >= n_samples``.
    :return: path of the written PNG file.
    :raises ValueError: if fewer than two samples are given in total.
    """
    import inspect
    import itertools
    import os

    import numpy as np
    from matplotlib import pyplot as plt
    from sklearn.manifold import TSNE

    # Bring everything to host memory as NumPy arrays: detach() drops the
    # autograd graph and cpu() makes GPU tensors convertible.
    training_feature = training_feature.detach().cpu().numpy()
    proto_feature = proto_feature.detach().cpu().numpy()
    train_label = train_label.detach().cpu().numpy()
    proto_label = proto_label.detach().cpu().numpy()
    size_train = training_feature.shape[0]

    # A single fit keeps supports and prototypes in the same coordinate system.
    all_features = np.concatenate((training_feature, proto_feature))
    n_samples = all_features.shape[0]
    if n_samples < 2:
        raise ValueError('t-SNE needs at least 2 samples, got %d' % n_samples)
    # scikit-learn raises ValueError unless perplexity < n_samples, which a
    # small episode (e.g. 5-way 1-shot: 10 points) easily violates.
    perplexity = min(perplexity, n_samples - 1)
    # `n_iter` was renamed `max_iter` in scikit-learn 1.5 and later removed;
    # use whichever name the installed version accepts.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    embedded = TSNE(n_components=2, perplexity=perplexity, learning_rate=200,
                    n_iter_without_progress=10,
                    **{iter_kw: 1000}).fit_transform(all_features)
    training_2d = embedded[:size_train]
    proto_2d = embedded[size_train:]

    # Sorted union of labels, so no class present in either set is dropped.
    classes = np.union1d(train_label, proto_label)
    # The fixed red/green/blue palette only covers 3 classes; zip() would
    # silently skip the rest, so fall back to a 20-colour map for larger episodes.
    colors = ['r', 'g', 'b']
    if len(classes) > len(colors):
        cmap = plt.get_cmap('tab20')
        colors = [cmap(i % cmap.N) for i in range(len(classes))]

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        # Supports first, then prototypes, so prototypes are drawn on top.
        for class_ix, color in zip(classes, colors):
            mask = train_label == class_ix
            ax.scatter(training_2d[mask, 0], training_2d[mask, 1],
                       marker='+', color=color, linewidths=1, alpha=0.9,
                       label=str(class_ix))
        for class_ix, color in zip(classes, colors):
            mask = proto_label == class_ix
            ax.scatter(proto_2d[mask, 0], proto_2d[mask, 1],
                       marker='o', color=color, linewidths=5, alpha=0.9)
        ax.set_title(title)

        os.makedirs(save_path, exist_ok=True)
        # Take the first unused "<n>.png" rather than a random 0-200 name,
        # which could collide with and overwrite an earlier plot.
        for item in itertools.count():
            file_path = os.path.join(save_path, '{}.png'.format(item))
            if not os.path.exists(file_path):
                break
        fig.savefig(file_path)
    finally:
        plt.close(fig)  # release the figure even if plotting or saving fails
    return file_path
上一篇下一篇

猜你喜欢

热点阅读