4. pytorch-简单分类

2018-07-01  本文已影响0人  FantDing
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(2, 10)
        # self.a1=torch.nn.ReLU()
        self.h2 = torch.nn.Linear(10, 2)

    def forward(self, x):
        # x=self.a1(self.h1(x)) # 使用这种写法,在print(net)时,会显示relu层
        x = F.relu(self.h1(x))  # 这种写法不会显示relu层,对于无状态的函数,推荐使用F形式(简单)
        x = self.h2(x)
        return x


if __name__ == "__main__":
    # test()
    # 1. 数据准备
    x0 = torch.normal(torch.ones(100, 2) * 1, 1)
    y0 = torch.zeros(100, 1)
    x1 = torch.normal(torch.ones(100, 2) * -2, 1)
    y1 = torch.ones(100, 1)

    x = torch.cat((x0, x1), dim=0)
    y = torch.cat((y0, y1), dim=0).long().squeeze()
    # 可视化数据
    # plt.scatter(x0.numpy()[:, 0], x0.numpy()[:, 1], c="red", label="negtive")
    # plt.scatter(x1.numpy()[:, 0], x1.numpy()[:, 1], c="green", label="positive")
    # plt.legend()
    # plt.show()

    # 2. 定义网络
    net = Net()
    # 3. 训练
    optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
    loss_F = torch.nn.CrossEntropyLoss()
    for iter in range(100):
        pred = net(x)
        # arg1: 二维的原始输出(没有加softmax)
        # arg2: 一维的batch_size的真实cls label(不是二维one hot编码)
        loss = loss_F(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        plt.ion()
        # 计算精度
        if iter % 2 == 0:
            print("*" * 8, "iter:", iter, "*" * 8)
            loss_str = "loss: {:.4f}".format(loss.data.numpy())
            print(loss_str)
            probability = F.softmax(pred, dim=1)
            pred_cls = torch.argmax(probability, dim=1)
            equal = (pred_cls == y)  # 返回的torch元素为0或1,不是bool类型
            accuracy = torch.sum(equal).data.numpy() / 200
            print("accuracy:", accuracy)

            # 画图
            plt.cla()
            plt.scatter(x.numpy()[:, 0], x.numpy()[:, 1], c=pred_cls)
            plt.text(1, -3.5, "accuracy:{}".format(accuracy))
            plt.pause(0.2)

    plt.ioff()
    plt.show()
image.png
上一篇下一篇

猜你喜欢

热点阅读