torch.nn.Module | 所有模型都继承自该类

2019-12-22  本文已影响0人  yuanCruise

torch.nn.Module:

所有神经网络模块的基类。您的模型也应该继承这个类。模块还可以包含其他模块,允许将它们嵌套在树结构中。可以将子模块分配为常规属性。

https://pytorch.org/docs/stable/nn.html?highlight=load_state_dict#torch.nn.Module.load_state_dict

1.state_dict
state_dict(destination=None, prefix='', keep_vars=False)

example:

>>> module.state_dict().keys()
['bias', 'weight']

返回一个包含模块完整状态的字典。包括参数和持久缓冲区(例如,运行的平均值)。键是对应的参数和缓冲区名称。
返回值:包含模块的整个状态的字典.

2.load_state_dict
load_state_dict(state_dict, strict=True)

将state_dict中的参数和缓冲区复制到此模块及其后代中。如果strict为真,则state_dict的键必须与该模块的state_dict()函数返回的键完全匹配。

上一篇下一篇

猜你喜欢

热点阅读