Pytorch 之 模型的保存与调用

2021-08-19  本文已影响0人  Allard_c205

介绍关于用pytorch搭建模型时,对模型进行保存以及再次调用模型参数的相关函数命令。

使用 torch.save(model.state_dict(), PATH)来保存模型学习到的参数,给模型恢复提供最大的灵活性。

先对模型进行实例化,再用load_state_dict()调用模型,在对模型进行推理之前,调用model.eval():

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))  #该函数只接收字典对象,而不是保存对象的路径,在这之前要反序列化保存的state_dict。

model.eval()


torch.load( f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>,  **pickle_load_args)  从文件加载用torch.save()保存的对象。  目前需要知道该函数前两个参数的正确使用即可

f: 类似于文件的对象,或包含文件名称的字符串,如:要载入的模型所在的完整路径的字符串

map_location: 一个函数,torch.device,字符串或字典,明确如何重映射存储空间位置

pickle_module:用于解开元数据和对象的模块(必须与序列化文件的pickle_module相匹配)

pickle_load_args:(只有Python3才有)可选择的关键字参数,并传递给pickle_module.load()和pickle_module.Unpickler(),比如,errors=...。

上一篇下一篇

猜你喜欢

热点阅读