基本图形绘制数据可视化

python3画hist图和曲线图(tSNE原理)

2020-08-19  本文已影响0人  RedStones

# coding:utf-8

import numpy as np

from numpy.linalg import norm

from matplotlib import pyplot as plt

plt.style.use('ggplot')

def sne_crowding():

    npoints = 1000 # 抽取1000个m维球内均匀分布的点

    plt.figure(figsize=(20, 5))

    for i, m in enumerate((2, 3, 5, 8)):

        # 这里模拟m维球中的均匀分布用到了拒绝采样,

        # 即先生成m维立方中的均匀分布,再剔除m维球外部的点

        accepts = []

        while len(accepts) < 1000:

            points = np.random.rand(500, m)

            accepts.extend([d for d in norm(points, axis=1)

                            if d <= 1.0]) # 拒绝采样

        accepts = accepts[:npoints]

        ax = plt.subplot(1, 4, i+1)

        if i == 0:

            ax.set_ylabel('count')

        if i == 2:

            ax.set_xlabel('distance')

        ax.hist(accepts, bins=np.linspace(0., 1., 50))

        ax.set_title('m=%s' %m)

    plt.savefig("sne_crowding.png")

def sne_norm_t_dist_cost():

    x = np.linspace(0, 4, 100)

    ta = 1 / (1 + np.square(x))

    tb = np.sum(ta) - 1

    qa = np.exp(-np.square(x))

    qb = np.sum(qa) - 1

    #

    plt.figure(figsize=(8, 5))

    plt.plot(qa/qb, c="b", label="normal-dist")

    plt.plot(ta/tb, c="g", label="t-dist")

    plt.plot((0, 20), (0.025, 0.025), 'r--')

    plt.text(10, 0.022, r'$q_{ij}$')

    plt.text(20, 0.026, r'$p_{ij}$')

    plt.plot((0, 55), (0.005, 0.005), 'r--')

    plt.text(36, 0.003, r'$q_{ij}$')

    plt.text(55, 0.007, r'$p_{ij}$')

    plt.title("probability of distance")

    plt.xlabel("distance")

    plt.ylabel("probability")

    plt.legend()

    plt.savefig("sne_norm_t_dist_cost.png")

if __name__ == '__main__':

    sne_crowding()

    sne_norm_t_dist_cost()

    print('==end==')

ref: https://www.jianshu.com/p/631d6529e0df

上一篇下一篇

猜你喜欢

热点阅读