检查PyTorch图的梯度流
2018-09-25 本文已影响0人
药柴
项目中用到了自定义的损失函数,但是在训练过程中发现损失保持不变,说明可能梯度的传导存在问题。在PyTorch论坛中的How to check for vanishing/exploding gradients发现了一个由Adam Paszke给出的较好的小程序bad_grad_viz.py,特别摘录如下:
from graphviz import Digraph
import torch
from torch.autograd import Variable, Function
def iter_graph(root, callback):
queue = [root]
seen = set()
while queue:
fn = queue.pop()
if fn in seen:
continue
seen.add(fn)
for next_fn, _ in fn.next_functions:
if next_fn is not None:
queue.append(next_fn)
callback(fn)
def register_hooks(var):
fn_dict = {}
def hook_cb(fn):
def register_grad(grad_input, grad_output):
fn_dict[fn] = grad_input
fn.register_hook(register_grad)
iter_graph(var.grad_fn, hook_cb)
def is_bad_grad(grad_output):
grad_output = grad_output.data
return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any()
def make_dot():
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
def size_to_str(size):
return '('+(', ').join(map(str, size))+')'
def build_graph(fn):
if hasattr(fn, 'variable'):
u = fn.variable
node_name = 'Variable\n ' + size_to_str(u.size())
dot.node(str(id(u)), node_name, fillcolor='lightblue')
else:
assert fn in fn_dict, fn
fillcolor = 'white'
if any(is_bad_grad(gi) for gi in fn_dict[fn]):
fillcolor = 'red'
dot.node(str(id(fn)), str(type(fn).__name__), fillcolor=fillcolor)
for next_fn, _ in fn.next_functions:
if next_fn is not None:
next_id = id(getattr(next_fn, 'variable', next_fn))
dot.edge(str(next_id), str(id(fn)))
iter_graph(var.grad_fn, build_graph)
return dot
return make_dot
if __name__ == '__main__':
x = Variable(torch.randn(10, 10), requires_grad=True)
y = Variable(torch.randn(10, 10), requires_grad=True)
z = x / (y * 0)
z = z.sum() * 2
get_dot = register_hooks(z)
z.backward()
dot = get_dot()
dot.save('tmp.dot')
例程运行得到一个tmp.dot文件,可视化效果如下:
可以看到由于计算式中出现了
x / (y * 0)
,梯度出现了问题,这两个function被标为红色。将x / (y * 0)
改为x / (y * 1)
后,生成的图就变成了tmp.png
在本人的例子中,出现了
‘NoneType' object has no attribute 'data'
的问题。这里需要注意的是,这段代码假设了图的所有输入都是设置了
requires_grad=True
的,然而很多时候这种情况并不满足,例如在简单的图像分类问题中,我们对于输入图像并不要求梯度,因为我们不需要对其进行修改。因此,为了使这段代码能够直接适用于大型的模型,可以修改register_hooks(var)
内的is_bad_grad(grad_output)
,以修正这一个错误,如下:
def is_bad_grad(grad_output):
try:
grad_output = grad_output.data
except:
print('Fail to get grad')
return True
return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any()
不过,这个修改只是简单为了让这段代码能够工作,实际上时不符合道理的。事实上,这段代码更适合小的单元模块测试,例如自定义的Loss函数。