PyTorch 实现简单的二分类器

2022-04-21  本文已影响0人  刘小白DOER

    今天测试一个使用PyTorch 来完成sklearn的moon数据分类的案例。

    本文基于MyML/pytorch moons.py at master · prudvinit/MyML (github.com)

1、生成moon数据,一共 200 个样本,并使用scatter画出散点图

2、将样本数据从 numpy 转成 tensor

3、构建全连接的神经网络,网络包含一个输入层,一个中间层,一个输出层。中间层包含 3 个神经元,使用的激活函数是 tanh,softmax函数计算概率得分,根据大小判断为0或者1。

整个网络连接情况如下:

4、损失函数用 CrossEntropyLoss,梯度优化器使用 Adam

5、开始训练及计算training error,accuracy_score得分0.97

6、根据loss画出曲线

    plt.plot(losses,linewidth=1)

7、更直观地展示分类结果,将结果可视化

上一篇下一篇

猜你喜欢

热点阅读