计算机视觉学习分享

【NAS工具箱】Pytorch中的Buffer&Paramete

2021-05-27  本文已影响0人  pprpp

Parameter : 模型中的一种可以被反向传播更新的参数。

第一种:

def __init__(self):
    super(MyModel, self).__init__()
    self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量

或者:

def __init__(self):
    super(MyModel, self).__init__()
    param = nn.Parameter(torch.randn(3, 3))  # 普通 Parameter 对象
    self.register_parameter("param", param)

Buffer : 模型中不能被反向传播算法更新的参数。

def __init__(self):
    super(MyModel, self).__init__()
    buffer = torch.randn(2, 3)  # tensor
    self.register_buffer('my_buffer', buffer)
    self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量

总结:

上一篇下一篇

猜你喜欢

热点阅读