检查PyTorch图的梯度流

2018-09-25  本文已影响0人  药柴

项目中用到了自定义的损失函数,但是在训练过程中发现损失保持不变,说明可能梯度的传导存在问题。在PyTorch论坛中的How to check for vanishing/exploding gradients发现了一个由Adam Paszke给出的较好的小程序bad_grad_viz.py,特别摘录如下:

from graphviz import Digraph
import torch
from torch.autograd import Variable, Function

def iter_graph(root, callback):
    queue = [root]
    seen = set()
    while queue:
        fn = queue.pop()
        if fn in seen:
            continue
        seen.add(fn)
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                queue.append(next_fn)
        callback(fn)

def register_hooks(var):
    fn_dict = {}
    def hook_cb(fn):
        def register_grad(grad_input, grad_output):
            fn_dict[fn] = grad_input
        fn.register_hook(register_grad)
    iter_graph(var.grad_fn, hook_cb)

    def is_bad_grad(grad_output):
        grad_output = grad_output.data
        return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any()

    def make_dot():
        node_attr = dict(style='filled',
                        shape='box',
                        align='left',
                        fontsize='12',
                        ranksep='0.1',
                        height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

        def size_to_str(size):
            return '('+(', ').join(map(str, size))+')'

        def build_graph(fn):
            if hasattr(fn, 'variable'):
                u = fn.variable
                node_name = 'Variable\n ' + size_to_str(u.size())
                dot.node(str(id(u)), node_name, fillcolor='lightblue')
            else:
                assert fn in fn_dict, fn
                fillcolor = 'white'
                if any(is_bad_grad(gi) for gi in fn_dict[fn]):
                    fillcolor = 'red'
                dot.node(str(id(fn)), str(type(fn).__name__), fillcolor=fillcolor)
            for next_fn, _ in fn.next_functions:
                if next_fn is not None:
                    next_id = id(getattr(next_fn, 'variable', next_fn))
                    dot.edge(str(next_id), str(id(fn)))
        iter_graph(var.grad_fn, build_graph)

        return dot

    return make_dot

if __name__ == '__main__':
    x = Variable(torch.randn(10, 10), requires_grad=True)
    y = Variable(torch.randn(10, 10), requires_grad=True)

    z = x / (y * 0)
    z = z.sum() * 2
    get_dot = register_hooks(z)
    z.backward()
    dot = get_dot()
    dot.save('tmp.dot')

例程运行得到一个tmp.dot文件,可视化效果如下:

tmp.png
可以看到由于计算式中出现了x / (y * 0),梯度出现了问题,这两个function被标为红色。将x / (y * 0)改为x / (y * 1)后,生成的图就变成了
tmp.png
在本人的例子中,出现了‘NoneType' object has no attribute 'data'的问题。
这里需要注意的是,这段代码假设了图的所有输入都是设置了requires_grad=True的,然而很多时候这种情况并不满足,例如在简单的图像分类问题中,我们对于输入图像并不要求梯度,因为我们不需要对其进行修改。因此,为了使这段代码能够直接适用于大型的模型,可以修改register_hooks(var)内的is_bad_grad(grad_output),以修正这一个错误,如下:
def is_bad_grad(grad_output):
        try:
            grad_output = grad_output.data
        except:
            print('Fail to get grad')
            return True
        return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any()

不过,这个修改只是简单为了让这段代码能够工作,实际上时不符合道理的。事实上,这段代码更适合小的单元模块测试,例如自定义的Loss函数。

上一篇下一篇

猜你喜欢

热点阅读