Pytorch: 手动修改模型参数

2019-09-19  本文已影响0人  Gavin先生

在 Pytorch 框架下,如何手动修改训练的模型的参数?
我们以两个模型参数加权平均为例,步骤如下:
Step 1: Load Model

model0 = torch.load('*.pth')
model1 = torch.load('*.pth')

Step 2: Parse Parameters

param0 = model0.state_dict()
param1 = model1.state_dict()

Step 3: Modified Parameters

param_new = {}
for key in param0.keys():
    paranew[key] = (param0[key] + param1[key])/2

Step 4: Generate New Model

model2 = model1
model2.load_state_dict(param_new)
param2 =  model2.net.state_dict()
上一篇下一篇

猜你喜欢

热点阅读