forward函数实现

2019-08-03  本文已影响0人  钢笔先生

Time: 2019-08-03
视频地址:https://youtu.be/MasG7tZj-hw?list=PLZbbT5o_s2xrfNyHZsM6ufI0iZENK9xgG

简单网络复习

class Network(nn.Module):
  def __init__(self):
    super(Network, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
    self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=60)
    self.out = nn.Linear(in_features=10 out_features=60)
  
  def forward(self, t):
    return t

现在开始实现网络的前向函数组装。

forward函数的实现

在构造函数中定义的层,相当于函数,也就是我们定义了5层,实际上是5个函数,现在我们需要在forward函数中有机地将它们组合起来。

def forward(self, t):
  # (1) input
  t = t

  # (2) hidden conv layer
  t = self.conv1(t)
  t = F.relu(t)
  t = F.max_pool2d(t, kernel_size=2, strides=2)

  # (3) hidden conv layer
  t = self.conv2(t)
  t = F.relu(t)
  t = F.max_pool2d(t, kernel_size=2, strides=2)

  # (4) hidden linear layer
  t = t.reshape(-1, 12*4*4)
  t = self.fc1(t)
  t = F.relu(t)

  # (5) hidden linear layer
  t = self.fc2(t)
  t = F.relu(t)

  # (6) output layer
  t = self.out(t)
  t = F.softmax(t, dim=1)

  return t

这里,不含weight的用F.xxx, 其中F是import torch.nn.functional as F

END.

上一篇 下一篇

猜你喜欢

热点阅读