PyTorch: Defining and Updating Custom Model Parameters

2020-12-30  廿怎么念

This problem had me stuck for a long time: I defined a parameter with nn.Parameter(), but it was never updated — its .grad was None and is_leaf was False. Very strange. It turned out the parameter simply had not been initialized correctly. My mistake.
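As a quick diagnostic (a minimal sketch of my own, not from the original post), it helps to first print what the module actually registered, together with the two attributes the symptoms above point at:

import torch
import torch.nn as nn

# Illustrative helper, not from the original post: list every registered
# parameter along with its is_leaf flag and current gradient.
def inspect_parameters(module: nn.Module):
    for name, param in module.named_parameters():
        print(name, tuple(param.shape), "is_leaf =", param.is_leaf, "grad =", param.grad)

inspect_parameters(nn.Linear(2, 3))  # weight and bias appear; is_leaf is True, grad is None until backward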

First, the correct example:
import torch
import torch.nn as nn

class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        # Assigning an nn.Parameter to an attribute makes nn.Module register it
        # in the module's parameter list.
        self.weight = nn.Parameter(data=torch.Tensor(1, 1, 1, 1), requires_grad=True)
        self.weight.data.uniform_(-1, 1)
        print(self.weight.is_leaf)
    def forward(self, x):
        # Moving only the *output* to the GPU here is fine; the registered parameter itself is untouched.
        masked_wt = self.weight.mul(1).cuda()
        return masked_wt


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.Mask = Mask()

    def forward(self, x):
        x = self.Mask(x)  # call the submodule instance, not the Mask class
        return x


model = Model()

for name, param in model.named_parameters():
    print(name, param)

The output is:

True
Mask.weight Parameter containing:
tensor([[[[0.7625]]]], requires_grad=True)
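To confirm that the registered parameter is actually trainable, here is a minimal sketch of my own (the dummy loss and SGD optimizer are illustrative assumptions, and a CUDA device is assumed since forward() above moves its output to the GPU):

import torch
import torch.nn as nn

model = Model()                      # the correct version above
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

out = model(torch.zeros(1))          # forward() ignores x and returns weight * 1
loss = out.sum()
loss.backward()

print(model.Mask.weight.grad)        # a real gradient now, not None
optimizer.step()                     # the parameter value changes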
Incorrect example A is shown below: .cuda() is called during initialization to move the parameter to the GPU.
import torch
import torch.nn as nn

class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        # BUG: .cuda() returns a plain (non-leaf) Tensor, not an nn.Parameter,
        # so nothing is registered in the module's parameter list.
        self.weight = nn.Parameter(data=torch.Tensor(1, 1, 1, 1), requires_grad=True).cuda()

        self.weight.data.uniform_(-1, 1)
        print(self.weight.is_leaf)
    def forward(self, x):
        masked_wt = self.weight.mul(1).cuda()
        return masked_wt


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.Mask = Mask()

    def forward(self, x):
        x = self.Mask(x)  # call the submodule instance, not the Mask class
        return x


model = Model()

for name, param in model.named_parameters():
    print(name, param)
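The reason is that .cuda() applied to an nn.Parameter returns an ordinary tensor produced by an autograd operation, and nn.Module only registers attributes that are nn.Parameter instances. A small check of my own (guarded so it only runs when a GPU is present):

import torch
import torch.nn as nn

if torch.cuda.is_available():
    p = nn.Parameter(torch.empty(1, 1, 1, 1), requires_grad=True)
    moved = p.cuda()                        # the output of an autograd copy op
    print(isinstance(moved, nn.Parameter))  # False -> nn.Module will not register it
    print(moved.is_leaf)                    # False -> its .grad is not retained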
Incorrect example B: view() is called on the parameter, replacing it with a new tensor.
import torch
import torch.nn as nn

class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        # BUG: .view() also returns a plain (non-leaf) Tensor, so again
        # nothing is registered as a parameter of this module.
        self.weight = nn.Parameter(data=torch.Tensor(1, 1, 1, 1), requires_grad=True).view(-1, 1)

        self.weight.data.uniform_(-1, 1)
        print(self.weight.is_leaf)
    def forward(self, x):
        masked_wt = self.weight.mul(1).cuda()
        return masked_wt


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.Mask = Mask()

    def forward(self, x):
        x = self.Mask(x)  # call the submodule instance, not the Mask class
        return x


model = Model()

for name, param in model.named_parameters():
    print(name, param)
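The same check applies to example B, and it does not even need a GPU (again a minimal sketch of my own, not from the original post):

import torch
import torch.nn as nn

p = nn.Parameter(torch.empty(1, 1, 1, 1), requires_grad=True)
reshaped = p.view(-1, 1)                   # the output of an autograd op
print(isinstance(reshaped, nn.Parameter))  # False -> nn.Module will not register it
print(reshaped.is_leaf)                    # False -> its .grad is not retained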

Both examples above initialize the parameter incorrectly: what ends up assigned to self.weight is no longer an nn.Parameter, so it is never added to the model's parameter pool. Instead it is treated as an intermediate (non-leaf) variable, whose gradient is not retained after backward, so its .grad stays None and the optimizer never updates it.
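If the parameter does need to live on the GPU or be reshaped, a pattern that keeps it registered (a sketch of my own, not from the original post) is to assign the bare nn.Parameter in __init__, move the whole module with .to() or .cuda(), and reshape inside forward:

import torch
import torch.nn as nn

class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        # Register the parameter itself; do not chain .cuda() or .view() here.
        self.weight = nn.Parameter(torch.empty(1, 1, 1, 1).uniform_(-1, 1))

    def forward(self, x):
        # Reshape on the fly; the registered parameter stays a leaf.
        return x * self.weight.view(-1, 1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Mask().to(device)              # moves the registered parameter in place
print(list(model.named_parameters()))  # weight is still listed and still a leaf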
