最简单的神经网络

2021-02-04  本文已影响0人  rivrui

最简单的神经网络

准备训练数据

随机生成50个数据,用作训练数据

x = np.linspace(1, 100)

最简单的神经网络拟合,y=ax+b,所有设置y为

y = 2 * x + 3

不过,为了符合实际情况,可用适当增加一些噪声。

noise = torch.randn(50)
y = 2 * x + 3+noise.numpy()

绘制x,y的图象如下


构建神经网络

torch里面可以基于nn.Module类写自己的神经网络,这里使用最简单的线性层。

class nn(nn.Module):
    def __init__(self, in_features=1, mid_features=5, out_features=1):
        super(nn, self).__init__()
        self.layer1 = nn.Linear(in_features, mid_features)
        self.layer2 = nn.Linear(mid_features, out_features)
        self.layer = nn.Linear(mid_features, mid_features)

    def forward(self, x):
        x = self.layer2(x)
        for i in range(1):
            x = self.layer(x)
        x = self.layer2(x)
        return x

之后则是设置损失函数,优化器,依旧选择最简单的。

criterion = nn.L1Loss()
optimizer = optim.RMSprop(model.parameters())

其中L1形式的损失函数就在lasso loss,loss=(y-X\theta)+C|\theta|
RMSProp算法的全称叫 Root Mean Square Prop。
考虑到训练时1-100,那么预测则选取50-150。
迭代计算,结果如下:


全部代码
x = np.linspace(1, 100)
noise = torch.randn(50)
y = 2 * x + 3+noise.numpy()
plt.plot(x, y)
plt.show()
dataset = []
for i, j in zip(x, y):
    dataset.append([i, j])
epochs = 10
model = Nn(1, 1)
criterion = nn.L1Loss()
optimizer = optim.RMSprop(model.parameters())
dataset = torch.tensor(dataset, dtype=torch.float, requires_grad=True)
for times in range(epochs):
    for i, data in enumerate(dataset, 0):
        x, label = data
        optimizer.zero_grad()
        out = model(x.unsqueeze(dim=0))
        loss = criterion(out, label.unsqueeze(dim=0))
        print("loss:", loss.data.item())
        loss.backward()
        optimizer.step()
x = np.linspace(50, 150)
y = 2 * x + 3
dataset = []
for i, j in zip(x, y):
    dataset.append([i, j])
dataset = torch.tensor(dataset, dtype=torch.float, requires_grad=True)
pred_y = []
pred_x = x
for i, data in enumerate(dataset, 0):
    x, label = data
    optimizer.zero_grad()
    out = model(x.unsqueeze(dim=0))
    pred_y.append(out)
plt.plot(pred_x, pred_y)
print(pred_x, pred_y)
plt.show()

有时间给出最简单神经网络的解析

上一篇下一篇

猜你喜欢

热点阅读