Pytorch

Pytorch框架学习(3)——计算图与动态图机制

2020-01-18  本文已影响0人  aidanmomo

计算图与动态图机制

1. 计算图

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)

    a = torch.add(w, x) # 可以通过retain_grad()来保存相应结点的梯度
    a.retain_grad()
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    y.backward()
    print(w.grad)

    # 查看叶子结点:
    print('is_leaf:\n', w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)

    # 查看梯度:
    print('gradient:\n', w.grad, x.grad, a.grad, b.grad, y.grad) # a的梯度通过retain_grad()保存下来

执行结果如下,这里可以通过在反向传播之前,执行retain_grad()操作保存下来相应结点的梯度

tensor([5.])
is_leaf:
 True True False False False
gradient:
 tensor([5.]) tensor([2.]) tensor([2.]) None None

2. Pytorch的动态图

根据计算图搭建方式的不同,可将计算图分为动态图和静态图

上一篇下一篇

猜你喜欢

热点阅读