用pytorch实现LeNet

2020-03-16  本文已影响0人  不分享的知识毫无意义

这篇文章是简单了解pytorch的工作流,发现pytorch和tensorflow不一样的地方,具体的小知识我会列出来,大家可以有针对性的看一下。

1.LeNet简介

这是一个简单的基于卷积神经网络的“深层”网络,网络结构如下:


LeNet网络结构

看见这个图是不是很眼熟啊,任何一本教科书上应该都有,网络其实由五层组成:

2.pytorch知识点

3.LeNet的简单实现

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V

class LeNet(nn.Module):
    #必须继承自nn.moddule类,是nn最基本的类,可以是一个tensor或者是一个tensor的集合
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1,6,(5, 5))#输入通道数,输出通道数
        self.conv2 = nn.Conv2d(6,16,(5, 5))
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)),(2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)),(2, 2))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return x
        
def output_name_and_params(net):
    for name, parameters in net.named_parameters():#是个字典存名字,返回参数名和参数取值
        print('name:{},param:{}'.format(name,parameters))
        
        
if __name__ == '__main__':
    net = LeNet()
    print('net:{}'.format(net))
    params = net.parameters()
    print('params:{}'.format(params))
    output_name_and_params(net)
    input_image = t.FloatTensor(10,1,28,28)
    input_image = V(input_image)
    output = net(input_image)
    print('output:{}'.format(output))
    print('output.size:{}'.format(output.size()))
上一篇下一篇

猜你喜欢

热点阅读