PyTorch如何确定全连接的参数

2020-09-04  本文已影响0人  geekboys

如何确定全连接的参数

虽然目前使用全连接层的网络模型越来越少,但是仍有部分网络需要全连接层,但是如果通过CNN计算图片的输出尺寸可以说有点复杂。现在就使用PyTorch自带的功能来实现这个计算,可以说非常简单。首先,我们先定义如下的网络:

class LinearDemo(nn.Module):
    def __init__(self):
        super(LinearDemo,self).__init__()
        self.conv=nn.Sequential(
            nn.Conv2d(3,96,kernel_size=11,stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2),

            nn.Conv2d(96,256,kernel_size=5,padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2),

            nn.Conv2d(256,384,kernel_size=3,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384,384,kernel_size=3,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384,256,kernel_size=3,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2)

        )

上面代码中的基本组件这里就不多赘述了,下面正常书写全连接层如下:

self.fc=nn.Sequential(
#         nn.Linear(???,4096)
#     )

其中???就是我们需要计算的参数值,如果通过层的关系进行计算则很容易出错。这里推荐使用PyTorch自带的forward方法进行推算。我们写forward方法如下:

def forward(self,x):
        x=self.conv(x)
        print(x.size())

这里我们可以在main方法中进行调用后,就可以输出该参数。main方法如下:

net=LinearDemo()
data_input=torch.randn(1,3,80,280)
print(data_input.size())
net(data_input)

这样就将上面的参数输出了。非常的简单

上一篇下一篇

猜你喜欢

热点阅读