PyTorch 使用问题及技巧汇总

2019-07-29  本文已影响0人  捡个七

专开此篇汇总使用 PyTorch 过程中遇到的各种问题及技巧。

one-hot encoding 的注意情况

网络中的梯度流检查

def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.
    
    Usage: Plug this function in Trainer class after loss.backwards() as 
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
    ave_grads = []
    max_grads= []
    layers = []
    for n, p in named_parameters:
        if(p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
            max_grads.append(p.grad.abs().max())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
    plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom = -0.001, top=0.02) # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])

参考:Check gradient flow in network

padding 相关问题

参考:PyTorch学习笔记(9)——nn.Conv2d和其中的padding策略

保存和提取模型参数

# 保存整个模型
torch.save(net, 'net.pkl')

# 提取模型
net = torch.load('net.pkl')
# 保存模型参数,节省空间
torch.save(net.state_dict(), 'net_params.tar')
net = xxx_model() # call model object
net.load_state_dict(torch.load('net_params.tar')) # load the params
上一篇下一篇

猜你喜欢

热点阅读