Pytorch 之 模型的保存与调用
2021-08-19 本文已影响0人
Allard_c205
介绍关于用pytorch搭建模型时,对模型进行保存以及再次调用模型参数的相关函数命令。
使用 torch.save(model.state_dict(), PATH)来保存模型学习到的参数,给模型恢复提供最大的灵活性。
先对模型进行实例化,再用load_state_dict()调用模型,在对模型进行推理之前,调用model.eval():
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH)) #该函数只接收字典对象,而不是保存对象的路径,在这之前要反序列化保存的state_dict。
model.eval()
torch.load( f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args) 从文件加载用torch.save()保存的对象。 目前需要知道该函数前两个参数的正确使用即可。
f: 类似于文件的对象,或包含文件名称的字符串,如:要载入的模型所在的完整路径的字符串
map_location: 一个函数,torch.device,字符串或字典,明确如何重映射存储空间位置
pickle_module:用于解开元数据和对象的模块(必须与序列化文件的pickle_module相匹配)
pickle_load_args:(只有Python3才有)可选择的关键字参数,并传递给pickle_module.load()和pickle_module.Unpickler(),比如,errors=...。