[PyTorch] register_hook

2019-06-07  本文已影响0人  VanJordan
    def _emb_hook(self, grad):
        return grad * Variable(self.grad_mask.unsqueeze(1)).type_as(grad)

    def set_grad_mask(self, mask):
        self.grad_mask = torch.from_numpy(mask)
        self.embedding.weight.register_hook(self._emb_hook)
上一篇 下一篇

猜你喜欢

热点阅读