pytorch如何保存与加载训练模型

2020-05-07  本文已影响0人  一位学有余力的同学
torch.save(model,'model.pth') # 保存
model = torch.load("model.pth") # 加载
torch.save(model.state_dict(),"model.pth") # 保存参数
model = model() # 代码中创建网络结构
params = torch.load("model.pth") # 加载参数
model.load_state_dict(params) # 应用到网络结构中

pytorch加载官方提供预训练模型的方法请参考博客

上一篇 下一篇

猜你喜欢

热点阅读