torchvision.models.segmentation.
2021-03-11 本文已影响0人
blair_liu
在任意位置(例如一个空白的 .py 文件中)写入下面这行导入语句:
from torchvision.models.segmentation.segmentation import fcn_resnet50
然后在 IDE 中跳转到 fcn_resnet50 的定义,即可看到以下源码:
def fcn_resnet50(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs):
    """Construct a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
            which contains the same classes as Pascal VOC. When True, the
            auxiliary classifier is always built (see ``_load_model``).
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes. Default: 21.
        aux_loss (bool, optional): Whether to build the auxiliary classifier.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_load_model``.

    Returns:
        The FCN model built by ``_load_model('fcn', 'resnet50', ...)``.
    """
    return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
_load_model:加载模型(构建模型,并按需加载预训练权重)
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load COCO-pretrained weights.

    Args:
        arch_type: Architecture key, e.g. ``'fcn'``.
        backbone: Backbone name, e.g. ``'resnet50'``.
        pretrained: If True, download the ``<arch_type>_<backbone>_coco``
            checkpoint listed in ``model_urls`` and load it into the model.
        progress: Whether to show a download progress bar.
        num_classes: Number of output classes.
        aux_loss: Whether to build the auxiliary classifier. Forced to True
            when ``pretrained`` is True.
        **kwargs: Extra keyword arguments forwarded to ``_segm_resnet``.

    Returns:
        The model returned by ``_segm_resnet``, with pretrained weights loaded
        when requested.

    Raises:
        NotImplementedError: If ``pretrained`` is True and ``model_urls`` has
            no URL (missing key or ``None`` entry) for this combination.
    """
    model_url = None
    if pretrained:
        # The checkpoint is loaded with load_state_dict's default strict=True,
        # so the model must contain every module stored in it.
        # NOTE(review): presumably the COCO checkpoints include aux_classifier
        # weights, which is why the aux head is forced on here - confirm.
        aux_loss = True
        arch = arch_type + '_' + backbone + '_coco'
        # .get(): an unregistered combination must raise the same
        # NotImplementedError as an explicit None entry, not a bare KeyError.
        # Checked before building the model so we fail before constructing
        # (and possibly downloading weights for) the backbone.
        model_url = model_urls.get(arch)
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
    return model
_segm_resnet:构建分割模型(ResNet 主干 + 分类头)
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    """Assemble a ResNet-based segmentation model.

    Args:
        name: Architecture key in ``model_map`` below (``'fcn'`` or ``'deeplabv3'``).
        backbone_name: Name of a constructor in the ``resnet`` module,
            e.g. ``'resnet50'``.
        num_classes: Number of output channels of the classifier head(s).
        aux: If True, also expose layer3 from the backbone and attach an
            ``FCNHead`` auxiliary classifier to it.
        pretrained_backbone: Passed as ``pretrained`` to the ResNet constructor.

    Returns:
        ``model_cls(backbone, classifier, aux_classifier)``, where ``model_cls``
        is the model class registered for ``name`` (e.g. ``FCN``).

    Raises:
        KeyError: If ``name`` or ``backbone_name`` is unknown.
    """
    # architecture name -> (head class, model class)
    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    # Resolve the architecture first so an unknown name fails before the
    # backbone (and possibly its pretrained weights) is built.
    head_cls, model_cls = model_map[name]

    # resnet.__dict__ is the resnet module's namespace, so for
    # backbone_name='resnet50' this is equivalent to
    #   resnet.resnet50(pretrained=pretrained_backbone,
    #                   replace_stride_with_dilation=[False, True, True])
    # Per torchvision's ResNet docs, [False, True, True] replaces the stride-2
    # downsampling of layer3 and layer4 with dilation. Both layers therefore
    # stay at 1/8 of the input resolution (output stride 8).
    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True])

    # For resnet50 (channel counts match the head widths used below):
    #   layer4 -> (B, 2048, H/8, W/8), layer3 -> (B, 1024, H/8, W/8)
    return_layers = {'layer4': 'out'}
    if aux:
        return_layers['layer3'] = 'aux'
    # Trim the ResNet so its forward returns a dict:
    # {'out': layer4 output, 'aux': layer3 output (only when aux is True)}.
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    if aux:
        # The auxiliary head is always an FCNHead on layer3's 1024 channels,
        # regardless of the main architecture.
        aux_inplanes = 1024
        aux_classifier = FCNHead(aux_inplanes, num_classes)

    # Main head maps layer4's 2048 channels to num_classes.
    inplanes = 2048
    classifier = head_cls(inplanes, num_classes)
    return model_cls(backbone, classifier, aux_classifier)
FCNHead:FCN 的分类头
class FCNHead(nn.Sequential):
    """Convolutional segmentation head.

    Layers, in order: 3x3 conv (``in_channels`` -> ``in_channels // 4``, no
    bias, padding 1), BatchNorm, ReLU, Dropout(0.1), then a 1x1 conv
    projecting to ``channels`` outputs (e.g. the number of classes).
    """

    def __init__(self, in_channels, channels):
        hidden = in_channels // 4  # bottleneck width of the head
        super().__init__(
            nn.Conv2d(in_channels, hidden, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(hidden, channels, 1),
        )
FCN 即 _SimpleSegmentationModel:
class _SimpleSegmentationModel(nn.Module):
__constants__ = ['aux_classifier']
def __init__(self, backbone, classifier, aux_classifier=None):
super(_SimpleSegmentationModel, self).__init__()
self.backbone = backbone # 主干网络输出
self.classifier = classifier # 分割网络输出
self.aux_classifier = aux_classifier # 辅助分割网络输出
def forward(self, x):
input_shape = x.shape[-2:] # H W
# contract: features is a dict of tensors
features = self.backbone(x) # 输出{'out':layer4输出, 'aux':layer3输出}
result = OrderedDict()
x = features["out"]
x = self.classifier(x)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result["out"] = x
if self.aux_classifier is not None:
x = features["aux"]
x = self.aux_classifier(x)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result["aux"] = x
return result # 字典