数字图像处理与计算机视觉(python)深度学习-推荐系统-CV-NLP机器学习与计算机视觉

resnet18实现dual pooling

2019-11-28  本文已影响0人  诗人藏夜里

max pooling 更注重局部最显著的特征（保留每个通道中响应最强的位置）
average pooling 更关注全局特征（对整张特征图取平均）
将两者在通道维度上 concat 可以丰富特征层；拼接后通道数翻倍（512 → 1024），因此代码中再用 1×1 卷积降回 512 维后送入全连接层

from torchvision.models import resnet18
import torch.nn as nn
import torch

class res18(nn.Module):
    """ResNet-18 classifier with a dual (average + max) pooling head.

    The stem and the four residual stages of a torchvision ResNet-18 are
    reused as the feature extractor. The resulting feature map is pooled
    twice (adaptive average pooling and adaptive max pooling), the two pooled
    tensors are concatenated along the channel axis, reduced back to 512
    channels with a 1x1 convolution, and classified by dropout + linear.

    Args:
        num_classes: Number of output classes (width of the logits).
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        feat_channels = 512                 # channels produced by layer4
        pooled_channels = 2 * feat_channels  # avg-pooled + max-pooled, concatenated

        self.base = resnet18(pretrained=False)
        # Everything from the backbone except its own avgpool/fc head.
        self.feature = nn.Sequential(
            self.base.conv1,
            self.base.bn1,
            self.base.relu,
            self.base.maxpool,
            self.base.layer1,
            self.base.layer2,
            self.base.layer3,
            self.base.layer4,  # outputs 512 channels
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # -> (batch, 512, 1, 1)
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))  # -> (batch, 512, 1, 1)
        # 1x1 convolution reduces the concatenated 1024 channels back to 512.
        self.reduce_layer = nn.Conv2d(pooled_channels, feat_channels, 1)
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            # in_features must match the 512 channels leaving reduce_layer;
            # the previous value of 1024 made forward() fail with a shape
            # mismatch in the linear layer.
            nn.Linear(feat_channels, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Input batch fed to the ResNet-18 stem.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.feature(x)                            # 512-channel feature map
        avgpool_x = self.avg_pool(x)                   # (batch, 512, 1, 1)
        maxpool_x = self.max_pool(x)                   # (batch, 512, 1, 1)
        x = torch.cat([avgpool_x, maxpool_x], dim=1)   # (batch, 1024, 1, 1)
        x = self.reduce_layer(x)                       # (batch, 512, 1, 1)
        x = torch.flatten(x, 1)                        # (batch, 512)
        logits = self.fc(x)                            # (batch, num_classes)
        return logits

参考:https://zhuanlan.zhihu.com/p/93806755

上一篇 下一篇

猜你喜欢

热点阅读