迁移学习pytorch

2020-11-28  本文已影响0人  HelloSam

以VGG网络为例:
1、只调整一层:以后禁止使用这种写法

model = torchvision.models.vgg16(pretrained=False)
vgg16pth = 'G:/other/data/pytorch-dataset/model_weight/vgg16-397923af.pth'
model.load_state_dict(torch.load(vgg16pth))
# 冻结卷积层的参数
for params in model.features.parameters():
    params.requires_grad = False
# 微调model.classifier部分
model.classifier[-1].out_features = 5
model = model.to(device)

2、调整整个classifier层:要调整把整个分类层都要调整一下

# 微调model.classifier部分
fc_inputs_num = model.classifier[0].in_features
fc_inputs_num

model.classifier = nn.Sequential(
    nn.Linear(fc_inputs_num, 4096),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 1024),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),
    nn.Linear(1024, 5),
)
model = model.to(device)
上一篇 下一篇

猜你喜欢

热点阅读