PyTorch中如何加载子模块的权重

2021-01-02  本文已影响0人  WritingHere

假设我们在深度学习模型中有一个这样的需求:主要模型A中包含子模块B,而模型B可以通过一定的方式得到一个预训练的权重,模型A需要利用B模型的权重,在此基础上继续训练。
首先我们到官网上去寻找,PyTorch官网上给出了一些保存和加载模型的示例,可以说非常全面总结了模型保存和加载的方法和主义事项,https://pytorch.org/tutorials/beginner/saving_loading_models.html。但是这里的方案都是针对一个完整模型的保存和加载的,不能满足我们这个需求。
因此需要基于此做一些改进,具体如代码所示:

import torch
import torch.nn as nn

class ModelA(nn.Module):
    def __init__(self):
        super(ModelA, self).__init__()
        self.A = nn.Linear(2, 3)
    
    def forward(self, A):
        pass


class ModelB(nn.Module):
    def __init__(self):
        super(ModelB, self).__init__()
        self.model_a = ModelA()
        self.A = nn.Linear(2, 3)
    
    def forward(self, x):
        pass

print("Model")
modelA = ModelA()
modelA_dict = modelA.state_dict()
print('-' * 80)
for key in sorted(modelA_dict.keys()):
    parameter = modelA_dict[key]
    print(key)
    print(parameter.size())
    print(parameter)
modelB = ModelB()
modelB_dict = modelB.state_dict()
print('-'*80)
for key in sorted(modelB_dict.keys()):
    print('-'*20)
    parameter = modelB_dict[key]
    print(type(key), key)
    print(parameter.size())
    print(parameter)
    print('-'*20)
print('-'*80)

pretrained_dict = modelA_dict
model_dict = modelB_dict

pretrained_dict = {'model_a.' + k: v for k, v in pretrained_dict.items() if 'model_a.' + k in model_dict}
model_dict.update(pretrained_dict)

modelB.load_state_dict(model_dict)
modelB_dict = modelB.state_dict()
for key in sorted(modelB_dict.keys()):
    parameter = modelB_dict[key]
    print(key)
    print(parameter.size())
    print(parameter)
上一篇下一篇

猜你喜欢

热点阅读