pytorch如何保存与加载训练模型
2020-05-07 本文已影响0人
一位学有余力的同学
- 保存网络结构及参数
torch.save(model,'model.pth') # 保存
model = torch.load("model.pth") # 加载
- 只保存模型参数
使用这种方法时需要创建一个和原来一模一样的模型
torch.save(model.state_dict(),"model.pth") # 保存参数
model = model() # 代码中创建网络结构
params = torch.load("model.pth") # 加载参数
model.load_state_dict(params) # 应用到网络结构中
pytorch加载官方提供预训练模型的方法请参考博客