How to add padding mask to nn.Tr
I think that when using src_mask, we need to provide a matrix of shape (S, S), where S is the source sequence length. For example:
import torch
import torch.nn as nn
q = torch.randn(3, 1, 10) # source sequence length 3, batch size 1, embedding size 10
attn = nn.MultiheadAttention(10, 1) # embedding size 10, one head
attn(q, q, q) # self attention
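nn.MultiheadAttention returns a tuple of (attention output, attention weights). The later snippets index it with [1] to get the weights. Here is a quick shape check, assuming the default batch_first=False layout (sequence, batch, embedding) used above. The seed is only there to make the run repeatable:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
q = torch.randn(3, 1, 10)            # (S=3, N=1, E=10), batch_first=False layout
attn = nn.MultiheadAttention(10, 1)  # embed_dim=10, num_heads=1

output, weights = attn(q, q, q)      # self-attention: query = key = value
print(output.shape)   # torch.Size([3, 1, 10]) -> same layout as the query
print(weights.shape)  # torch.Size([1, 3, 3]) -> (N, L, S), averaged over heads
```

Each row of the weights is a softmax distribution over the 3 key positions, so every row sums to 1.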
For attn_mask, we need a matrix of shape (S, S):
def src_mask(sz):
    # Causal mask: 0.0 where a query may attend (key index <= query index), -inf elsewhere
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
print(src_mask(3))
# Output
tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])
out = attn(q, q, q, attn_mask=src_mask(3))[1] # attention output weights
print(out)
# Output
tensor([[[1.0000, 0.0000, 0.0000],
         [0.2497, 0.7503, 0.0000],
         [0.1139, 0.2764, 0.6097]]], grad_fn=<DivBackward0>)
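The float mask uses 0 for allowed positions and -inf for blocked ones. nn.MultiheadAttention also accepts a boolean attn_mask, where True marks a position that may not be attended to. The sketch below compares the two. It redefines src_mask so the block runs on its own, and bool_causal_mask is a hypothetical helper name used only for illustration:

```python
import torch
import torch.nn as nn

def src_mask(sz):
    # Float causal mask: 0.0 on and below the diagonal, -inf above it
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)

def bool_causal_mask(sz):
    # Hypothetical helper. Boolean causal mask: True strictly above the diagonal = "do not attend"
    return torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)

torch.manual_seed(0)
q = torch.randn(3, 1, 10)
attn = nn.MultiheadAttention(10, 1)

w_float = attn(q, q, q, attn_mask=src_mask(3))[1]
w_bool = attn(q, q, q, attn_mask=bool_causal_mask(3))[1]
print(torch.allclose(w_float, w_bool))  # True
```

Both masks block the same positions, so the attention weights match.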
If we look at F.multi_head_attention_forward, this is what attn_mask does:
if attn_mask is not None:
    attn_mask = attn_mask.unsqueeze(0)  # (S, S) -> (1, S, S), broadcast over batch * heads
    attn_output_weights += attn_mask    # added to the raw scores, before the softmax
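The same additive step can be reproduced by hand for a single head. This is a sketch under a few assumptions: the module's packed in_proj_weight/in_proj_bias hold the query, key and value projections stacked in that order, and scores are scaled by 1/sqrt(head_dim). Internal code paths differ across PyTorch versions, so treat this as an illustration rather than the library's exact implementation. With no batch dimension in the hand-computed scores, the unsqueeze(0) from the excerpt is not needed here:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
E = 10
q = torch.randn(3, 1, E)                 # (S, N, E)
attn = nn.MultiheadAttention(E, 1)       # one head, so head_dim == E

# Causal float mask: 0 where attention is allowed, -inf where it is blocked
mask = torch.triu(torch.full((3, 3), float('-inf')), diagonal=1)

# Project query and key with the module's packed input-projection parameters
w_q, w_k, _ = attn.in_proj_weight.chunk(3)
b_q, b_k, _ = attn.in_proj_bias.chunk(3)
x = q[:, 0, :]                           # drop the batch dim: (S, E)
Q = x @ w_q.T + b_q
K = x @ w_k.T + b_k

scores = Q @ K.T / math.sqrt(E)          # raw attention scores, shape (S, S)
scores = scores + mask                   # the additive masking step
manual = F.softmax(scores, dim=-1)

expected = attn(q, q, q, attn_mask=mask)[1][0]  # weights for batch element 0
print(torch.allclose(manual, expected, atol=1e-5))  # True
```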
Because we add float('-inf') to some of the raw scores, the softmax turns those positions into zero. For example:
a = nn.Softmax(dim=-1)
b = torch.tensor([3., 4., float('-inf')])
print(a(b))
# Output
tensor([0.2689, 0.7311, 0.0000])
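Adding -inf before the softmax is equivalent to dropping that entry and renormalizing over the remaining ones. The plain-Python sketch below makes this visible. It uses a hand-written softmax as a stand-in for nn.Softmax:

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability; exp(-inf) evaluates to 0.0
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

masked = softmax([3.0, 4.0, float('-inf')])
unmasked = softmax([3.0, 4.0])

print([round(p, 4) for p in masked])    # [0.2689, 0.7311, 0.0]
print([round(p, 4) for p in unmasked])  # [0.2689, 0.7311]
```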
This means that some words are ignored when computing the representation of a given word. For example, when computing attn_weights for the first word of the source sentence, we do not want to consider any of the following words. For the second word, we want to consider only the first and second words, not the third.
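To make the per-word reasoning concrete, the sketch below reads off which positions each query may see from a causal float mask. The token names w0, w1, w2 are placeholders for illustration only:

```python
import torch

# Float causal mask for a 3-word sentence: 0 = allowed, -inf = blocked
mask = torch.triu(torch.full((3, 3), float('-inf')), diagonal=1)

words = ["w0", "w1", "w2"]  # placeholder tokens, purely illustrative
visible = {
    words[i]: [words[j] for j in range(len(words)) if mask[i, j] == 0]
    for i in range(len(words))
}
for word, seen in visible.items():
    print(word, "->", seen)
# w0 -> ['w0']
# w1 -> ['w0', 'w1']
# w2 -> ['w0', 'w1', 'w2']
```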
As for src_key_padding_mask, it has to be of shape (N, S), where N is the batch size and S is the source sentence length. I think its purpose is to keep padded words from being considered when computing the representations of the other words. For a boolean mask, True marks a key position to ignore.
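In practice, the (N, S) padding mask is usually derived from the true sentence lengths, or alternatively from a padding token id, for example tokens == pad_idx. It is then passed to nn.TransformerEncoder as src_key_padding_mask, alongside the causal mask. The sketch below assumes made-up lengths, layer sizes and depth. Both masks are boolean (True = ignore):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
S, N, E = 3, 2, 10                     # sequence length, batch size, embedding size

# Hypothetical batch: sentence 0 has 3 real tokens, sentence 1 has 2 (+1 pad)
lengths = torch.tensor([3, 2])
src = torch.randn(S, N, E)             # (S, N, E): batch_first=False layout

# (N, S) padding mask: True where the position is >= the sentence's real length
src_key_padding_mask = torch.arange(S).unsqueeze(0) >= lengths.unsqueeze(1)
# row 0: [False, False, False]; row 1: [False, False, True]

# (S, S) causal mask: True strictly above the diagonal (future positions)
causal_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)

layer = nn.TransformerEncoderLayer(d_model=E, nhead=1)
encoder = nn.TransformerEncoder(layer, num_layers=2)

out = encoder(src, mask=causal_mask, src_key_padding_mask=src_key_padding_mask)
print(out.shape)  # torch.Size([3, 2, 10])
```

With right padding, every query still has at least one unmasked key (position 0). A row where every key is masked would give a softmax over only -inf values, which produces NaN.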
For example, to ignore the third word of the source sequence when computing the attention weights (with a batch size of 1):
src_key_padding_mask = torch.tensor([[0, 0, 1]]).bool()
out = attn(q, q, q, attn_mask=src_mask(3), key_padding_mask=src_key_padding_mask)[1]
print(out)
# Output
tensor([[[1.0000, 0.0000, 0.0000],
         [0.2497, 0.7503, 0.0000],
         [0.2919, 0.7081, 0.0000]]], grad_fn=<DivBackward0>)
The third column is zero in every row, because we no longer consider the impact of the third word on the representation of the other words. The remaining weights in each row are renormalized so that they still sum to 1.
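Blocking a key with key_padding_mask has the same effect as blocking that key's entire column in attn_mask. The sketch below checks this equivalence. It uses boolean masks for both arguments (True = blocked), because recent PyTorch versions emit a deprecation warning when a float attn_mask is mixed with a boolean key_padding_mask:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
q = torch.randn(3, 1, 10)
attn = nn.MultiheadAttention(10, 1)

causal = torch.triu(torch.ones(3, 3, dtype=torch.bool), diagonal=1)
key_padding_mask = torch.tensor([[False, False, True]])  # ignore key position 2

# Route 1: causal attn_mask plus a separate key_padding_mask
w1 = attn(q, q, q, attn_mask=causal, key_padding_mask=key_padding_mask)[1]

# Route 2: fold the padding into attn_mask by blocking column 2 for every query
combined = causal.clone()
combined[:, 2] = True
w2 = attn(q, q, q, attn_mask=combined)[1]

print(torch.allclose(w1, w2))  # True
print(w1[0, :, 2])             # third column: all zeros
```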
Reference:
https://discuss.pytorch.org/t/how-to-add-padding-mask-to-nn-transformerencoder-module/63390