Pytorch: 加载部分权重

2019-11-21  本文已影响0人  wzNote

直接加载所有权重

适用于直接使用别人的模型和权重

model=CNN()
model.load_state_dict(torch.load('cnn.pth'))

加载部分权重

适用于对别人的模型做了适当的修改

model=CNN()
pretrained_dict = torch.load('cnn.pth'))
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #用于过滤掉修改结构处的权重
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
上一篇 下一篇

猜你喜欢

热点阅读