PyTorch | Tensor attributes under the autograd mechanism

2020-05-11  yuanCruise
import torch

def test_not_leaf_get_grad(t):
    # ask autograd to also populate .grad on this non-leaf tensor
    t.retain_grad()

def print_isleaf(t):
    print('*' * 5)
    print('{} is leaf {}!'.format(t, t.is_leaf))

def print_grad(t):
    print('{} grad is {}'.format(t, t.grad))

def double_grad(grad):
    # a tensor returned from a hook replaces the gradient flowing through
    return grad * 2

if __name__ == "__main__":
    # init
    input_ = torch.randn(1, requires_grad=True)  # leaf tensor
    output = input_ * input_                     # non-leaf (has grad_fn)
    output2 = output * 2

    # test retain_grad()
    test_not_leaf_get_grad(output)
    # test register_hook()
    output2.register_hook(double_grad)

    output2.backward()

    # test leaf
    print_isleaf(input_)
    print_grad(input_)

    # test non-leaf
    print_isleaf(output)
    print_grad(output)

## result without retain_grad / register_hook
*****
tensor([], requires_grad=True) is leaf True!
tensor([], requires_grad=True) grad is tensor([2.])
*****
tensor([], grad_fn=<MulBackward0>) is leaf False!
tensor([], grad_fn=<MulBackward0>) grad is None
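By default, only leaf tensors created with `requires_grad=True` accumulate `.grad`; intermediate (non-leaf) results keep `grad` as `None`. A minimal sketch with a fixed value (hypothetical, not the article's random run):

```python
import torch

x = torch.tensor([3.0], requires_grad=True)  # leaf
y = x * x                                    # non-leaf, has grad_fn
z = y * 2
z.backward()

# z = 2 * x**2, so dz/dx = 4 * x = 12; it is stored on the leaf x,
# while the intermediate y gets no .grad by default
print(x.is_leaf, y.is_leaf)   # True False
print(x.grad)                 # tensor([12.])
print(y.grad)                 # None
```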

## result for retain_grad
*****
tensor([], requires_grad=True) is leaf True!
tensor([], requires_grad=True) grad is tensor([2.])
*****
tensor([], grad_fn=<MulBackward0>) is leaf False!
tensor([], grad_fn=<MulBackward0>) grad is tensor([1.])
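Calling `retain_grad()` on a non-leaf tensor asks autograd to keep its gradient after backward. A minimal isolated sketch (hypothetical fixed value, no hook involved):

```python
import torch

x = torch.tensor([3.0], requires_grad=True)
y = x * x
y.retain_grad()   # keep the gradient on this intermediate tensor
z = y * 2
z.backward()

# dz/dy = 2 is now retained on the non-leaf tensor y
print(y.grad)     # tensor([2.])
```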

## result for register_hook
*****
tensor([], requires_grad=True) is leaf True!
tensor([1.1], requires_grad=True) grad is tensor([8.8], grad_fn=<CloneBackward>)
*****
tensor([], grad_fn=<MulBackward0>) is leaf False!
tensor([], grad_fn=<MulBackward0>) grad is tensor([4.])
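A hook registered with `register_hook` runs when the gradient reaches that tensor, and the tensor it returns replaces the gradient for the rest of the backward pass. `register_hook` also returns a handle that can remove the hook later. A small sketch (hypothetical values):

```python
import torch

def double_grad(grad):
    return grad * 2  # replaces the gradient flowing through y

x = torch.tensor([1.0], requires_grad=True)
y = x * 3
handle = y.register_hook(double_grad)
y.backward()

# without the hook dx would be 3; the hook doubles the gradient
# arriving at y (1 -> 2), so x.grad = 2 * 3 = 6
print(x.grad)     # tensor([6.])
handle.remove()   # detach the hook when no longer needed
```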

## result for detach()
*****
tensor([], requires_grad=True) is leaf True!
tensor([1.1], requires_grad=True) grad is tensor([8.8], grad_fn=<CloneBackward>)
*****
tensor([], grad_fn=<MulBackward0>) is leaf False!
tensor([], grad_fn=<MulBackward0>) grad is tensor([4.])
*****
tensor([]) is leaf True!
tensor([]) grad is tensor([4.])
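The detach() code itself is not shown above; a minimal sketch of what such a test could look like (hypothetical values, not the article's run): `detach()` returns a new tensor that shares storage with the original but is cut out of the autograd graph, so it is a leaf with `requires_grad=False`.

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x * x
d = y.detach()   # shares data with y, but outside the graph

print(d.is_leaf, d.requires_grad)  # True False

y.backward()
# gradients still flow to x through y; d is unaffected
print(x.grad)                      # tensor([4.])
```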