Pytorch 自定义模型参数及更新
2020-12-30 本文已影响0人
廿怎么念
被这个问题困扰了很久,用nn.parameter()定义了参数,但该参数没有更新,.grad() 为none, is_leaf 为False, 其了个怪了。原来是在参数初始化的时候没有正确初始化,我好菜~~~~~。
先看正确的例子
import torch
import torch.nn as nn
class Mask(nn.Module):
def __init__(self):
super(Mask, self).__init__()
self.weight = (torch.nn.Parameter(data=torch.Tensor(1, 1, 1, 1), requires_grad=True))
self.weight.data.uniform_(-1, 1)
print(self.weight.is_leaf)
def forward(self, x):
masked_wt = (self.weight.mul(1)).cuda()
return masked_wt
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.Mask = Mask()
def forward(self, x):
x = Mask(x)
return x
model = Model()
for name, param in model.named_parameters():
print(name, param)
输出为
True
Mask.weight Parameter containing:
tensor([[[[0.7625]]]], requires_grad=True)
错误例子A如下, 在初始化的时候使用了.cuda,把参数加载到GPU
import torch
import torch.nn as nn
class Mask(nn.Module):
def __init__(self):
super(Mask, self).__init__()
self.weight = (torch.nn.Parameter(data=torch.Tensor(1, 1, 1, 1), requires_grad=True)).cuda
self.weight.data.uniform_(-1, 1)
print(self.weight.is_leaf)
def forward(self, x):
masked_wt = (self.weight.mul(1)).cuda
return masked_wt
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.Mask = Mask()
def forward(self, x):
x = Mask(x)
return x
model = Model()
for name, param in model.named_parameters():
print(name, param)
错误例子B:使用了view函数,改变了该参数
import torch
import torch.nn as nn
class Mask(nn.Module):
def __init__(self):
super(Mask, self).__init__()
self.weight = (torch.nn.Parameter(data=torch.Tensor(1, 1, 1, 1), requires_grad=True)).view(-1, 1)
self.weight.data.uniform_(-1, 1)
print(self.weight.is_leaf)
def forward(self, x):
masked_wt = (self.weight.mul(1)).cuda
return masked_wt
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.Mask = Mask()
def forward(self, x):
x = Mask(x)
return x
model = Model()
for name, param in model.named_parameters():
print(name, param)
以上两个例子均导致模型参数初始化错误,没有把改参数加到模型的parameter pool 里,而是把该参数当做一个intermediate variable,所以在更新之后,grad被清零了,变成了none.