Torch反向传播时出错或者梯度为NaN的问题排查

2021-08-23  本文已影响0人  酌泠

师兄问到的backward时出现NaN和错误不好定位的问题,感觉有必要记录一下

解决方案
Automatic differentiation package - torch.autograd — PyTorch 1.9.0 documentation
使用with torch.autograd.detect_anomaly():包裹传播过程检测错误

import torch
from torch import autograd

class MyFunc(autograd.Function):  # A func generate NaN when backward
    @staticmethod
    def forward(ctx, inp):
        return inp.clone()

    @staticmethod
    def backward(ctx, gO):
        grad1 = torch.zeros_like(gO) / torch.zeros_like(gO)  # NaN
        return grad1


class Net(torch.nn.Module):  # toy net
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        o1 = self.l1(x)
        o2 = MyFunc.apply(o1)
        o3 = self.l2(o2)
        return o3.sum()

with autograd.detect_anomaly():  # enable Anomaly Detection
    inp = torch.rand(10, 10, requires_grad=True)
    m = Net()
    out = m(inp)
    out.backward()
    print(inp.grad)

粗暴的方法,直接查看模型各节点的梯度是否计算正确

for name, param in m.named_parameters():
    shape, c = (param.grad.shape, param.grad.sum()) if param.grad is not None else (None, None)
    print(f'{name}: {param.shape} \n\t grad: {shape} \n\t {c}')
上一篇 下一篇

猜你喜欢

热点阅读