pytorch模型保存和加载

2019-02-24  本文已影响0人  sheng_pan_ai

模型保存

torch.save()实现对网络结构和模型参数的保存.有两种保存方式:一是保存整个神经网络的结构信息和模型参数信息.save的对象是网络net.二是只保留神经网络的训练模型参数,save的对象是net.state_dict()

torch.save('net1','model.pkl') #保留整个神经网络的结构和模型参数
torch.save(net1.state_dict(),'model.pkl') # 只保留神经网络的模型参数

模型加载

对于两种保存方式,重载也有两种方式.
对应第一种完整网络结构信息,重载的时候通过

torch.load('model.pkl')

直接初始化新的神经网络对象即可.
对应第二种只保存模型参数信息,需要首先导入对应的网络,通过

net.load_state_dict(torch.load('model.pkl'))

完成模型参数的重载
在网络比较大时,第一种方法会花费较多的时间.

上一篇下一篇

猜你喜欢

热点阅读