Defining a network in PyTorch

2018-12-12  by Zeke_Wang

Declare a class for the network:

import torch.nn as nn

class NetName(nn.Module):
    def __init__(self):
        super(NetName, self).__init__()
        # register submodules as attributes of self (not of nn),
        # so that nn.Module can track their parameters
        self.module1 = ...
        self.module2 = ...
        self.module3 = ...

    def forward(self, x):
        x = self.module1(x)
        x = self.module2(x)
        x = self.module3(x)
        return x

In the constructor __init__, build the modules the network will use. For example, a max-pooling layer whose parameters never change can be declared once as a single module, or, in a CV task, the feature-extraction network and the classification network can be declared as two separate modules.
The forward function declares how the modules relate to each other, i.e., it wires up the whole network.

net = NetName().to(device) # create the network and move it to the given device
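A minimal sketch of how `device` might be defined before the line above (the usual CUDA-availability check; the input tensor `x` here is a made-up example):

```python
import torch

# pick the GPU if one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# input tensors must live on the same device as the network
x = torch.randn(4, 3, 32, 32).to(device)
```

Forgetting to move either the network or the input produces a "tensors on different devices" runtime error.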

Once the network is created, you can iterate over its modules like this:

for name, module in net._modules.items():
    print(name)   # name is each module's attribute name from __init__
    print(module) # module is the submodule itself, with its layers

Note that _modules is a private attribute; the public method net.named_children() yields the same (name, module) pairs.
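A small self-contained sketch of this kind of introspection, using the public net.named_children() together with net.parameters() to count trainable parameters (the tiny Sequential network here is hypothetical, just for illustration):

```python
import torch.nn as nn

# a tiny stand-in network, just for illustration
net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# public equivalent of net._modules.items()
for name, module in net.named_children():
    print(name, module)

# total trainable parameter count: (8*4 + 4) + (4*2 + 2) = 46
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(n_params)  # 46
```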

Example: an AlexNet-style CNN

The tensor-size annotations in the comments assume CIFAR-10 images: (channel=3, height=32, width=32).

import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), # (3,32,32) -> (64,8,8)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # (64,8,8)  -> (64,4,4)
            nn.Conv2d(64, 192, kernel_size=5, padding=2),          # (64,4,4)  -> (192,4,4)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # (192,4,4) -> (192,2,2)
            nn.Conv2d(192, 384, kernel_size=3, padding=1),         # (192,2,2) -> (384,2,2)
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),         # (384,2,2) -> (256,2,2)
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),         # (256,2,2) -> (256,2,2)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # (256,2,2) -> (256,1,1)
        )
        
        self.classifier = nn.Linear(256, 10)                       # (batch_size,256) -> (batch_size,10)
        

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1) # flatten to (batch_size, 256*1*1)
        x = self.classifier(x)
        return x
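The size annotations in the comments above follow the standard formula output_size = floor((H + 2*padding - kernel) / stride) + 1. A quick sketch verifying the first conv and pool stages against the comments:

```python
import torch
import torch.nn as nn

# conv/pool output size: floor((H + 2*padding - kernel) / stride) + 1
conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5)  # (3,32,32) -> (64,8,8)
pool1 = nn.MaxPool2d(kernel_size=2, stride=2)                  # (64,8,8)  -> (64,4,4)

x = torch.randn(1, 3, 32, 32)  # one fake CIFAR-10 image
y = pool1(conv1(x))
print(y.shape)  # torch.Size([1, 64, 4, 4])
```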