torchvision.models.segmentation.

2021-03-11  本文已影响0人  blair_liu

随便一个位置

from torchvision.models.segmentation.segmentation import fcn_resnet50

跳转到fcn_resnet50

def fcn_resnet50(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs):
    """
    :param pretrained: 是否下载预训练权重
    :param progress: 是否显示下载进度条
    :param num_classes: 类别数
    :param aux_loss:是否有辅助损失
    :param kwargs:额外参数
    :return:fcn_resnet50模型
    """
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)

_load_model加载模型

def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """
    :param arch_type: 模型名称 'fcn'
    :param backbone: 模型主干 resnet50
    :param pretrained: 是否下载预训练权重
    :param progress: 是否显示下载进度条
    :param num_classes: 类别数
    :param aux_loss: 是否有辅助损失
    :param kwargs: 额外参数
    :return: fcn模型
    """
    if pretrained:  # 如果下载预训练权重,就有辅助损失
        aux_loss = True
    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)  # 获取分割模型,具体见下一段_segm_resnet
    if pretrained:
        arch = arch_type + '_' + backbone + '_coco'
        model_url = model_urls[arch]
        if model_url is None:  # 如果没找到预训练权重,就报错
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        else:
            state_dict = load_state_dict_from_url(model_url, progress=progress)  # 下载预训练权重
            model.load_state_dict(state_dict)  # 模型加载预训练权重
    return model

_segm_resnet

def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    """
    :param name: 模型名称 'fcn'
    :param backbone_name: 模型主干 resnet50
    :param num_classes: 类别数
    :param aux: 是否有辅助损失
    :param pretrained_backbone: 是否有模型主干预训练权重
    :return: fcn模型
    """
    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True])
    """
        resnet.__dict__包含了resnet所有变量,函数和类,下面代码自行验证
        from torchvision.models import resnet
        for key, value in resnet.__dict__.items():
            print(key, value)
            print('-'*50)
        此处等效于:
        backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
    """
    return_layers = {'layer4': 'out'}  # B * 2048 * 7 * 7
    if aux:
        return_layers['layer3'] = 'aux'  # B * 1014 * 14 * 14
    # 将resnet50裁剪成我们需要的
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)  # 中间层获取函数
    # IntermediateLayerGetter 看它的介绍里面的Examples就很清楚了
    # 返回的是一个字典 backbone输出{'out':layer4输出, 'aux':layer3输出}

    aux_classifier = None
    if aux:
        inplanes = 1024
        aux_classifier = FCNHead(inplanes, num_classes)  # 说是头有点不合理,其实是接在resnet50主干网络后面,可以叫主体
        # FCNHead主要作用是将主干网络的输出变为类别数

    model_map = {  # 这个model_map的主要目的就是适应不同的模型
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    inplanes = 2048
    classifier = model_map[name][0](inplanes, num_classes)  # 等效于:FCNHead(inplanes, num_classes)
    base_model = model_map[name][1]  # 等效于FCN,其实FCN是_SimpleSegmentationModel

    model = base_model(backbone, classifier, aux_classifier)  # model = FCN(backbone, classifier, aux_classifier)
    return model

FCNHead

class FCNHead(nn.Sequential):  # FCN主体
    def __init__(self, in_channels, channels):
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1)
        ]

        super(FCNHead, self).__init__(*layers)

FCN即_SimpleSegmentationModel

class _SimpleSegmentationModel(nn.Module):
    __constants__ = ['aux_classifier']

    def __init__(self, backbone, classifier, aux_classifier=None):
        super(_SimpleSegmentationModel, self).__init__()
        self.backbone = backbone  # 主干网络输出
        self.classifier = classifier  # 分割网络输出
        self.aux_classifier = aux_classifier  # 辅助分割网络输出

    def forward(self, x):
        input_shape = x.shape[-2:]  # H W
        # contract: features is a dict of tensors
        features = self.backbone(x)  # 输出{'out':layer4输出, 'aux':layer3输出}

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
            result["aux"] = x

        return result  # 字典
上一篇下一篇

猜你喜欢

热点阅读