4. PyTorch - 简单分类
2018-07-01 本文已影响0人
FantDing
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
class Net(torch.nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    The network returns raw, unnormalized scores (no softmax is applied),
    which is the input form ``torch.nn.CrossEntropyLoss`` takes.

    Args:
        in_features: size of each input sample (default 2).
        hidden: width of the hidden layer (default 10).
        n_classes: number of output scores, one per class (default 2).
    """

    def __init__(self, in_features=2, hidden=10, n_classes=2):
        super().__init__()
        self.h1 = torch.nn.Linear(in_features, hidden)
        self.h2 = torch.nn.Linear(hidden, n_classes)

    def forward(self, x):
        """Map a batch of shape (batch, in_features) to (batch, n_classes) scores."""
        # ReLU is stateless, so the functional form is used. Unlike a
        # torch.nn.ReLU() submodule, it does not show up in print(net).
        x = F.relu(self.h1(x))
        return self.h2(x)
def make_data(n_per_class=100):
    """Build a toy 2-D, two-class dataset.

    Class 0 samples are drawn from a normal distribution centred at (1, 1)
    and class 1 samples from one centred at (-2, -2), both with std 1.

    Args:
        n_per_class: number of samples generated for each class.

    Returns:
        (x, y): ``x`` is a float tensor of shape (2 * n_per_class, 2) and
        ``y`` is a 1-D long tensor of class indices (0s first, then 1s).
        That is the target form CrossEntropyLoss takes (class indices, not
        one-hot vectors).
    """
    x0 = torch.normal(torch.ones(n_per_class, 2) * 1, 1)
    x1 = torch.normal(torch.ones(n_per_class, 2) * -2, 1)
    x = torch.cat((x0, x1), dim=0)
    y = torch.cat(
        (
            torch.zeros(n_per_class, dtype=torch.long),
            torch.ones(n_per_class, dtype=torch.long),
        ),
        dim=0,
    )
    return x, y


def predict_and_score(logits, y):
    """Return ``(pred_cls, accuracy)`` for a batch of raw scores.

    Softmax is monotonic, so the argmax of the raw scores equals the argmax
    of the probabilities. No softmax is needed just to pick the class.

    Args:
        logits: tensor of shape (batch, n_classes).
        y: 1-D tensor of true class indices, length ``batch``.

    Returns:
        pred_cls: 1-D long tensor of predicted class indices.
        accuracy: Python float in [0, 1]. It is NaN for an empty batch.
    """
    pred_cls = torch.argmax(logits, dim=1)
    # Comparison yields a bool tensor; cast to float to average it.
    accuracy = (pred_cls == y).float().mean().item()
    return pred_cls, accuracy


def main(n_iters=100, lr=0.02):
    """Train ``Net`` on the toy dataset and plot predictions as it learns.

    Every other iteration, prints the loss and accuracy and redraws a scatter
    plot coloured by the predicted class.

    Args:
        n_iters: number of full-batch SGD steps.
        lr: SGD learning rate.
    """
    # 1. Data preparation
    x, y = make_data()
    # 2. Define the network
    net = Net()
    # 3. Training
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    plt.ion()  # enable interactive plotting once, not on every iteration
    for step in range(n_iters):
        logits = net(x)
        # CrossEntropyLoss takes raw (pre-softmax) 2-D scores plus 1-D
        # class-index targets (not one-hot).
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 2 == 0:
            # Metrics use the scores computed before this step's update.
            pred_cls, accuracy = predict_and_score(logits, y)
            print("*" * 8, "iter:", step, "*" * 8)
            print("loss: {:.4f}".format(loss.item()))
            print("accuracy:", accuracy)
            # Plotting
            plt.cla()
            plt.scatter(x[:, 0].numpy(), x[:, 1].numpy(), c=pred_cls.numpy())
            plt.text(1, -3.5, "accuracy:{}".format(accuracy))
            plt.pause(0.2)
    plt.ioff()
    plt.show()


if __name__ == "__main__":
    main()
image.png