LeNet网络参数注释
2020-11-09 本文已影响0人
周周周__
LeNet网络参数注释
from torch import nn
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__() # 例:输入数据为(1,32,32)
self.conv = nn.Sequential(
nn.Conv2d(1, 6, 5), # 输入通道数1,输出通道数6,卷积核大小5 ==> (6, 28, 28) 步长默认为1
nn.Sigmoid(),
nn.MaxPool2d(2, 2), # 经过池化输出为 (6,14,14)
nn.Conv2d(6, 16, 5), # 输入通道数6,输出通道数16, 卷积核大小5 ==>(16, 10, 10)
nn.Sigmoid(),
nn.MaxPool2d(2, 2) # 经过池化输出为(16, 5, 5)
)
self.fc = nn.Sequential(
nn.Linear(16 * 5 * 5, 120), # 输入为是上边 16*5*5
nn.Sigmoid(),
nn.Linear(120, 84),
nn.Sigmoid(),
nn.Linear(84, 10)
)
def forward(self, img):
feature = self.conv(img)
output = self.fc(feature.view(img.shape[0], -1)) # 全连接需要进行平铺
return output
图片.png