Getting gradients of intermediate variables in PyTorch
2019-10-04
顾北向南
https://oldpan.me/archives/pytorch-autograd-hook
https://blog.csdn.net/loseinvain/article/details/99172594
This post is shared for academic purposes only; it will be removed if it infringes any rights.
1. Computing gradients of non-leaf nodes in PyTorch
- In PyTorch, gradients are normally only retained for leaf nodes, i.e. the tensors created directly by the user (parameters and inputs with requires_grad=True). Non-leaf nodes, the intermediate results of the computation graph, do not keep their gradients after backward(), because in general only leaf nodes need to be updated, and discarding intermediate gradients saves a large amount of memory. During debugging, however, we sometimes want to monitor the gradients of intermediate variables to verify that the network is working correctly, which means printing the gradients of non-leaf nodes. There are two ways to do this, described below.
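Before going into the two approaches, here is a minimal sketch (the variable names are chosen purely for illustration) of the leaf/non-leaf distinction and of the fact that non-leaf gradients are discarded by default:

import torch

w = torch.ones(2, 2, requires_grad=True)   # leaf node: created directly by the user
h = w * 3                                  # non-leaf node: result of an operation
loss = h.sum()
loss.backward()

print(w.is_leaf, h.is_leaf)   # True False
print(w.grad)                 # tensor([[3., 3.], [3., 3.]]) -- kept for the leaf
print(h.grad)                 # None -- the non-leaf gradient was not retained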
2. retain_grad()
- Tensor.retain_grad() explicitly keeps the gradient of a non-leaf node, at the cost of extra memory. A hook function, by contrast, lets us inspect the gradient directly while the backward pass is running, so it does not increase memory consumption, but retain_grad() is more convenient to use. For example:
import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2          # non-leaf node
y.retain_grad()    # explicitly keep y's gradient
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)
> tensor([[4.5000, 4.5000],
>         [4.5000, 4.5000]])
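For comparison, without retain_grad() the intermediate gradient is simply not kept (recent PyTorch versions also emit a UserWarning when .grad of a non-leaf tensor is accessed); a minimal sketch:

import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2                     # non-leaf node, no retain_grad()
out = (y * y * 3).mean()
out.backward()
print(y.grad)                 # None -- the intermediate gradient was discarded
print(x.grad)                 # tensor([[4.5000, 4.5000], [4.5000, 4.5000]]) -- the leaf gradient is kept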
3. hook
- Tensor.register_hook() attaches a function that is called with the tensor's gradient during the backward pass; here we use a closure to stash the gradients of interest in a dictionary:

import torch

grads = {}

def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

x = torch.randn(1, 1, requires_grad=True)
y = 3 * x
z = y ** 2

# save_grad('y') returns a hook (a function) that remembers the name 'y'
y.register_hook(save_grad('y'))
z.register_hook(save_grad('z'))
z.backward()

print(grads['y'])
print(grads['z'])
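register_hook() also returns a handle whose remove() method detaches the hook again, which is handy when the gradient should only be monitored temporarily; a minimal sketch:

import torch

x = torch.randn(1, 1, requires_grad=True)
y = 3 * x

handle = y.register_hook(lambda grad: print('grad of y:', grad))
y.sum().backward()   # the hook fires here and prints y's gradient

handle.remove()      # detach the hook; later backward passes will no longer print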
- The same idea also works at the module level: nn.Module.register_backward_hook() is called with the gradients flowing into and out of a module during the backward pass, and a tensor hook can be registered on the input at the same time. A complete example:

import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class MyMul(nn.Module):   # custom multiplication module (defined here but not used below)
    def forward(self, input):
        out = input * 2
        return out

class MyMean(nn.Module):  # custom division module
    def forward(self, input):
        out = input / 4
        return out

def tensor_hook(grad):
    print('tensor hook')
    print('grad:', grad)
    return grad

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.f1 = nn.Linear(4, 1, bias=True)
        self.f2 = MyMean()
        self.weight_init()

    def forward(self, input):
        self.input = input
        output = self.f1(input)   # run operation 1 first, then operation 2
        output = self.f2(output)
        return output

    def weight_init(self):
        self.f1.weight.data.fill_(8.0)   # set the Linear weights to 8
        self.f1.bias.data.fill_(2.0)     # set the Linear bias to 2

    def my_hook(self, module, grad_input, grad_output):
        print('doing my_hook')
        print('original grad:', grad_input)
        print('original outgrad:', grad_output)
        # grad_input = grad_input[0] * self.input  # the modification of grad_input is commented out here;
        # grad_input = tuple([grad_input])         # the returned grad_input must be a tuple, so it is wrapped again
        # print('now grad:', grad_input)
        return grad_input

if __name__ == '__main__':
    input = torch.tensor([1, 2, 3, 4], dtype=torch.float32, requires_grad=True).to(device)
    net = MyNet()
    net.to(device)

    # Register the hooks before result = net(input) is executed, because they have to be
    # in place when the graph is built during the forward pass. (Recent PyTorch versions
    # deprecate register_backward_hook in favour of register_full_backward_hook.)
    net.register_backward_hook(net.my_hook)
    input.register_hook(tensor_hook)

    result = net(input)
    print('result =', result)

    result.backward()

    print('input.grad:', input.grad)
    for param in net.parameters():
        print('{}:grad->{}'.format(param, param.grad))
Output:

result = tensor([ 20.5000], device='cuda:0')
doing my_hook
original grad: (tensor([ 0.2500], device='cuda:0'),)
original outgrad: (tensor([ 1.], device='cuda:0'),)
tensor hook
grad: tensor([ 2., 2., 2., 2.], device='cuda:0')
input.grad: None
Parameter containing:
tensor([[ 8., 8., 8., 8.]], device='cuda:0'):grad->tensor([[ 0.2500, 0.5000, 0.7500, 1.0000]], device='cuda:0')
Parameter containing:
tensor([ 2.], device='cuda:0'):grad->tensor([ 0.2500], device='cuda:0')
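Note that input.grad is None in the output above: calling .to(device) on the original tensor returns a new, non-leaf tensor, so its gradient is not retained after backward() and only the tensor hook gets to see it. A minimal sketch of one way around this (assuming the same device setup as above) is to create the leaf tensor directly on the target device; calling input.retain_grad() before backward() would also work:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create the leaf tensor directly on the device instead of moving it with .to(device)
input = torch.tensor([1, 2, 3, 4], dtype=torch.float32,
                     device=device, requires_grad=True)
print(input.is_leaf)   # True: the gradient will be retained for this tensor

out = (input * 2).sum()
out.backward()
print(input.grad)      # tensor([2., 2., 2., 2.])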