Getting the gradients of intermediate variables in PyTorch

2019-10-04  顾北向南

https://oldpan.me/archives/pytorch-autograd-hook
https://blog.csdn.net/loseinvain/article/details/99172594
This article is shared for academic purposes only; if it infringes any rights, it will be taken down.

1. Computing gradients of non-leaf variables in PyTorch

By default, autograd only fills in .grad for leaf tensors (those created directly with requires_grad=True). Intermediate (non-leaf) variables have their gradients used during backward() and then freed, so reading their .grad gives None. The two common ways to keep them are retain_grad() and hooks.
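A minimal sketch of the default behavior:

import torch

x = torch.ones(2, 2, requires_grad=True)   # leaf tensor
y = x + 2                                  # non-leaf, intermediate tensor
out = (y * y * 3).mean()
out.backward()

print(x.grad)   # populated: leaf tensors keep their gradients
print(y.grad)   # None (plus a warning): non-leaf gradients are freed during backward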

2. retain_grad()

Calling retain_grad() on a non-leaf tensor before backward() tells autograd to keep that tensor's gradient in its .grad field:

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
y.retain_grad()        # ask autograd to keep y's gradient
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)

> tensor([[4.5000, 4.5000],
>         [4.5000, 4.5000]])

Sanity check: every element of y is 3, and out = mean(3 * y^2) over 4 elements, so d(out)/dy_i = 6 * y_i / 4 = 1.5 * 3 = 4.5.

3. hook

register_hook attaches a function to a tensor; during backward(), autograd calls it with that tensor's gradient, which the hook can stash away:

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

x = torch.randn(1, 1, requires_grad=True)
y = 3 * x
z = y ** 2

# In here, save_grad('y') returns a hook (a function) that keeps 'y' as name
y.register_hook(save_grad('y'))
z.register_hook(save_grad('z'))
z.backward()

print(grads['y'])   # dz/dy = 2*y
print(grads['z'])   # dz/dz = tensor([[1.]])
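A hook may also return a tensor, which then replaces the gradient flowing further back. A minimal sketch (the zeroing hook here is purely illustrative):

import torch

x = torch.randn(1, 1, requires_grad=True)
y = 3 * x
y.register_hook(lambda grad: grad * 0)   # replace y's gradient with zeros
(y ** 2).backward()
print(x.grad)   # tensor([[0.]]): the zeroed gradient propagated down to x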
The same idea extends to modules: nn.Module.register_backward_hook calls its hook with (module, grad_input, grad_output) during the backward pass. (On PyTorch >= 1.8 this API is deprecated in favor of register_full_backward_hook.)

import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class MyMul(nn.Module):             # custom multiply module (defined but not used below)
    def forward(self, input):
        out = input * 2
        return out

class MyMean(nn.Module):            # custom division module
    def forward(self, input):
        out = input/4
        return out

def tensor_hook(grad):
    print('tensor hook')
    print('grad:', grad)
    return grad

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.f1 = nn.Linear(4, 1, bias=True)    
        self.f2 = MyMean()
        self.weight_init()

    def forward(self, input):
        self.input = input
        output = self.f1(input)       # run op 1 first, then op 2
        output = self.f2(output)      
        return output

    def weight_init(self):
        self.f1.weight.data.fill_(8.0)    # set the Linear weights to 8
        self.f1.bias.data.fill_(2.0)      # set the Linear bias to 2

    def my_hook(self, module, grad_input, grad_output):
        print('doing my_hook')
        print('original grad:', grad_input)
        print('original outgrad:', grad_output)
        # grad_input = grad_input[0]*self.input   # the modification of grad_input is commented out here;
        # grad_input = tuple([grad_input])        # the returned grad_input must be a tuple, hence the wrapping.
        # print('now grad:', grad_input)

        return grad_input

if __name__ == '__main__':

    input = torch.tensor([1, 2, 3, 4], dtype=torch.float32, requires_grad=True).to(device)

    net = MyNet()
    net.to(device)

    net.register_backward_hook(net.my_hook)   # register both hooks before result = net(input) runs, because the hooks are attached during the forward pass
    input.register_hook(tensor_hook)
    result = net(input)

    print('result =', result)

    result.backward()

    print('input.grad:', input.grad)
    for param in net.parameters():
        print('{}:grad->{}'.format(param, param.grad))
Output (running on a CUDA device):

> result = tensor([ 20.5000], device='cuda:0')
> doing my_hook
> original grad: (tensor([ 0.2500], device='cuda:0'),)
> original outgrad: (tensor([ 1.], device='cuda:0'),)
> tensor hook
> grad: tensor([ 2., 2., 2., 2.], device='cuda:0')
> input.grad: None
> Parameter containing:
> tensor([[ 8., 8., 8., 8.]], device='cuda:0'):grad->tensor([[ 0.2500, 0.5000, 0.7500, 1.0000]], device='cuda:0')
> Parameter containing:
> tensor([ 2.], device='cuda:0'):grad->tensor([ 0.2500], device='cuda:0')

Note that input.grad is None: .to(device) returns a new, non-leaf tensor, so the gradient is never stored on input itself. The tensor hook still sees it (weight 8 divided by 4 gives 2 per element).
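A minimal sketch of the usual fix, assuming you want the gradient of the original input: keep a handle to the leaf tensor before calling .to(device) (the names here are illustrative):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

x_cpu = torch.tensor([1., 2., 3., 4.], requires_grad=True)   # leaf tensor on CPU
x = x_cpu.to(device)                 # non-leaf copy: its .grad stays None
(2 * x).sum().backward()

print(x.grad)       # None (plus a warning): x is not a leaf
print(x_cpu.grad)   # tensor([2., 2., 2., 2.]): the gradient lands on the leaf

Alternatively, calling x.retain_grad() before backward() keeps the gradient on the non-leaf copy itself.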