How to add padding mask to nn.Tr

2020-03-03  魏鹏飞

I think that when using src_mask, we need to provide a matrix of shape (S, S), where S is the source sequence length. For example, start with a self-attention layer:

import torch
import torch.nn as nn

q = torch.randn(3, 1, 10) # source sequence length 3, batch size 1, embedding size 10
attn = nn.MultiheadAttention(10, 1) # embedding size 10, one head
attn(q, q, q) # self attention
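Here is a small sketch, assuming a recent PyTorch release, that unpacks the tuple returned by the self-attention call and prints its shapes. The `[1]` index used later in this post selects the second element, the attention weights. The seed and the names `output` and `weights` are illustrative additions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed only so reruns are reproducible

# Default (sequence, batch, embedding) layout, as in the example above.
S, N, E = 3, 1, 10
q = torch.randn(S, N, E)
attn = nn.MultiheadAttention(E, 1)  # embedding size 10, one head

# nn.MultiheadAttention returns (attention output, attention weights).
output, weights = attn(q, q, q)

print(output.shape)   # torch.Size([3, 1, 10]) -> (S, N, E)
print(weights.shape)  # torch.Size([1, 3, 3])  -> (N, S, S)
print(weights.sum(dim=-1))  # each row of weights sums to 1
```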

For attn_mask, we need a matrix of shape (S, S). The following function builds one that stops each position from attending to later positions:

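def src_mask(sz):
    # Boolean lower-triangular matrix: True where column j <= row i
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    # Allowed positions -> 0.0, later (future) positions -> -inf
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask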

print(src_mask(3))

# Output
tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])
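The same mask can be built more directly by filling an (S, S) matrix with -inf and keeping only the part strictly above the main diagonal. The helper `causal_mask` below is hypothetical; this sketch just checks that it agrees with `src_mask`:

```python
import torch

def src_mask(sz):
    # Construction from the post: 0.0 on and below the diagonal, -inf above it.
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def causal_mask(sz):
    # Hypothetical helper: torch.triu with diagonal=1 keeps the -inf entries
    # strictly above the diagonal and sets everything else to 0.0.
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

print(torch.equal(src_mask(3), causal_mask(3)))  # True
```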
out = attn(q, q, q, attn_mask=src_mask(3))[1] # attention output weights
print(out)

# Output
tensor([[[1.0000, 0.0000, 0.0000],
         [0.2497, 0.7503, 0.0000],
         [0.1139, 0.2764, 0.6097]]], grad_fn=<DivBackward0>)

If we look at F.multi_head_attention_forward, this is what attn_mask does:

if attn_mask is not None:
    attn_mask = attn_mask.unsqueeze(0)
    attn_output_weights += attn_mask
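To see the additive mask in isolation, here is a stripped-down sketch with plain tensors. It assumes a single head and omits the learned query/key/value projections of nn.MultiheadAttention, so its numbers will not match the module's, but the masking step (add the mask to the scores, then apply softmax) is the same:

```python
import math
import torch

torch.manual_seed(0)

S, E = 3, 10
x = torch.randn(S, E)  # one sequence, no batch dimension, no learned projections

# Causal additive mask: 0.0 where attention is allowed, -inf above the diagonal.
attn_mask = torch.triu(torch.full((S, S), float('-inf')), diagonal=1)

# Scaled dot-product scores; the mask is added before the softmax,
# mirroring `attn_output_weights += attn_mask` in the excerpt.
scores = x @ x.T / math.sqrt(E)
scores = scores + attn_mask
weights = torch.softmax(scores, dim=-1)

print(weights)  # upper triangle is exactly 0 and every row sums to 1
```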

Since float('-inf') is added to some of the weights, the softmax turns those entries into zero. For example:

a = nn.Softmax(dim=-1)
b = torch.tensor([3., 4., float('-inf')])
print(a(b))

# Output
tensor([0.2689, 0.7311, 0.0000])
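Adding -inf is equivalent to taking the softmax over only the unmasked entries and putting 0 in the masked slot, because exp(-inf) = 0 contributes nothing to the normalizing sum. A quick sketch that checks this:

```python
import torch

logits = torch.tensor([3., 4., float('-inf')])

# Softmax with the -inf entry included...
masked = torch.softmax(logits, dim=-1)

# ...matches a softmax over the unmasked entries only, with 0 appended
# for the masked position.
kept = torch.softmax(torch.tensor([3., 4.]), dim=-1)
expected = torch.cat([kept, torch.zeros(1)])

print(torch.allclose(masked, expected))  # True
```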

This means some words are ignored when computing the representation of a word. For example, when computing attention weights for the first word of the source sentence, we do not want to consider any later words; for the second word, we want to consider only the first and second words, not the third.

As for src_key_padding_mask, it must have shape (N, S), where N is the batch size and S is the source sequence length.

I think its purpose is to keep padded words from being considered when computing the representations of the other words.
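In a real batch, sequences usually have different lengths and are right-padded to a common length S. A minimal sketch, assuming right padding and a hypothetical `lengths` tensor holding each sequence's true length, that builds the (N, S) boolean mask by broadcasting:

```python
import torch

# Hypothetical batch: two sequences padded to length S = 4, with true
# lengths 4 and 2 (the last two tokens of the second one are padding).
lengths = torch.tensor([4, 2])
S = 4

# True marks a padded position that attention should ignore; shape (N, S).
positions = torch.arange(S).unsqueeze(0)                  # (1, S)
src_key_padding_mask = positions >= lengths.unsqueeze(1)  # (N, 1) broadcast to (N, S)

print(src_key_padding_mask.tolist())
# [[False, False, False, False], [False, False, True, True]]
```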

For example, to ignore the third word of the source sequence when computing attention weights (batch size 1), mark it True in the mask; positions set to True are the ones attention ignores:

src_key_padding_mask = torch.tensor([[0, 0, 1]]).bool()
out = attn(q, q, q, attn_mask=src_mask(3), key_padding_mask=src_key_padding_mask)[1]
print(out)

# Output
tensor([[[1.0000, 0.0000, 0.0000],
         [0.2497, 0.7503, 0.0000],
         [0.2919, 0.7081, 0.0000]]], grad_fn=<DivBackward0>)

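The third column is now all zeros, because the third word no longer contributes to the representation of any other word. The remaining weights in each row are rescaled to sum to 1 (compare the last row with the previous output).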
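Since the question is about nn.TransformerEncoder, here is a sketch of passing both masks to it, assuming a recent PyTorch 2.x release and the default (S, N, E) layout. The sizes and lengths are illustrative; both masks are boolean (True means "do not attend") so their types match. The last check confirms that overwriting the padded tokens leaves the outputs at real positions unchanged:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

S, N, E, H = 4, 2, 16, 2  # sequence length, batch size, embedding size, heads (illustrative)

layer = nn.TransformerEncoderLayer(d_model=E, nhead=H)  # default layout is (S, N, E)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()  # disable dropout so the comparison below is deterministic

src = torch.randn(S, N, E)

# Boolean masks; True means "do not attend".
causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)  # (S, S)
lengths = torch.tensor([4, 2])
padding = torch.arange(S).unsqueeze(0) >= lengths.unsqueeze(1)       # (N, S)

out = encoder(src, mask=causal, src_key_padding_mask=padding)
print(out.shape)  # torch.Size([4, 2, 16])

# Overwrite the two padded tokens of the second sequence; its real
# positions (0 and 1) should produce the same outputs as before.
src2 = src.clone()
src2[2:, 1] = torch.randn(2, E)
out2 = encoder(src2, mask=causal, src_key_padding_mask=padding)
print(torch.allclose(out[:2, 1], out2[:2, 1], atol=1e-5))  # True
```

With `batch_first=True` on the layer, the input layout becomes (N, S, E); the padding mask keeps shape (N, S).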

Reference:
https://discuss.pytorch.org/t/how-to-add-padding-mask-to-nn-transformerencoder-module/63390
