Transformer in Practice
2020-11-08
习惯了千姿百态
Define the inputs and outputs
import numpy
import torch

inputs = torch.FloatTensor([[[1.4, 2.1, 3.7],
                             [4.1, 5.2, 6.3],
                             [4.1, 5.5, 3.5],
                             [2.3, 4.5, 6.4],
                             [0.3, 4.2, 6.4],
                             [3.1, 4.5, 6.4],
                             ],
                            [[1.4, 2.1, 3.7],
                             [4.1, 5.2, 6.3],
                             [4.1, 5.5, 3.5],
                             [2.3, 4.5, 6.4],
                             [0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0],
                             ],
                            [[1.4, 2.1, 3.7],
                             [4.1, 5.2, 6.3],
                             [4.1, 5.5, 3.5],
                             [0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0],
                             ],
                            [[1.4, 2.1, 3.7],
                             [4.1, 5.2, 6.3],
                             [4.1, 5.5, 3.5],
                             [2.3, 4.5, 6.4],
                             [1.1, 2.2, 3.5],
                             [0.0, 0.0, 0.0],
                             ],
                            [[1.4, 2.3, 3.7],
                             [4.1, 5.3, 6.3],
                             [4.1, 5.1, 3.5],
                             [2.3, 4.6, 6.4],
                             [1.3, 5.2, 2.6],
                             [0.0, 0.0, 0.0],
                             ],
                            ])
# torch.Size([5, 6, 3]); all-zero rows are padding
input_length = [6, 4, 3, 5, 5]
targets = torch.IntTensor([[1, 2, 3, 0], [4, 5, 6, 4], [1, 2, 0, 0], [3, 0, 0, 0], [3, 5, 6, 0]])
# torch.Size([5, 4]); 0 is the padding id
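In practice a padded batch like inputs is usually not written out by hand. A minimal sketch using torch.nn.utils.rnn.pad_sequence (here the unpadded sequences are reconstructed from inputs purely for demonstration):
from torch.nn.utils.rnn import pad_sequence
raw_seqs = [inputs[i, :l] for i, l in enumerate(input_length)]  # list of tensors with shapes [6,3], [4,3], [3,3], [5,3], [5,3]
padded = pad_sequence(raw_seqs, batch_first=True)               # torch.Size([5, 6, 3]), zero-padded
lengths = [s.size(0) for s in raw_seqs]                         # [6, 4, 3, 5, 5]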
Encoder
Compute enc_mask
enc_mask = torch.sum(inputs, dim=-1).ne(0).unsqueeze(-2)
# torch.Size([5, 1, 6]). Sum over the last dimension of inputs; positions whose sum is 0 are padding
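The same mask can also be derived from input_length instead of the summed features; a small equivalent sketch:
lens = torch.tensor(input_length)                                                     # torch.Size([5])
enc_mask_from_len = (torch.arange(inputs.size(1)).unsqueeze(0) < lens.unsqueeze(1)).unsqueeze(-2)
assert torch.equal(enc_mask_from_len, enc_mask)                                       # same b x 1 x L boolean mask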

Compute self-attention
import torch.nn as nn
# torch.Size([5, 6, 3])
query = inputs
key = inputs
value = inputs
nfeat = inputs.shape[-1]
linear_q = nn.Linear(nfeat, nfeat)
linear_k = nn.Linear(nfeat, nfeat)
linear_v = nn.Linear(nfeat, nfeat)
# First project q, k, v
query = linear_q(query)
key = linear_k(key)
value = linear_v(value)
# Compute the attention scores (a standard Transformer also divides by sqrt(d_k); omitted here for simplicity)
score = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([5, 6, 6])

# Equivalent to negating enc_mask: masked_fill below only fills the True positions, and right now the padding positions are False
mask = enc_mask.eq(0)
print('mask:', mask, mask.shape)
# A very small value that acts as negative infinity, so the softmax output at these positions becomes 0
min_value = float(numpy.finfo(torch.tensor(0, dtype=score.dtype).numpy().dtype).min)
score = score.masked_fill(mask, min_value)
print(score)
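The numpy.finfo trick can also be done natively in PyTorch; an equivalent one-liner, assuming a PyTorch version that provides torch.finfo:
min_value_torch = torch.finfo(score.dtype).min  # same "effectively -inf" fill value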


# Compute the attention output
attn = torch.softmax(score, dim=-1).masked_fill(mask, 0.0)
print(attn)
enc_output = torch.matmul(attn, value)
print('enc_output:', enc_output.shape, enc_output)

enc_output.masked_fill_(~enc_mask.transpose(1, 2), 0.0) # ~ is logical negation: zero out the padding positions of the output
print('enc_output:', enc_output.shape, enc_output)
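The projection + masked softmax steps above are what a full implementation usually wraps into a reusable attention function. A minimal single-head sketch, including the 1/sqrt(d_k) scaling that the standard Transformer uses but the walkthrough above omits:
import math

def masked_attention(query, key, value, mask):
    # query: b x Lq x d, key/value: b x Lk x d, mask: broadcastable to b x Lq x Lk (True = keep)
    d_k = query.size(-1)
    score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    score = score.masked_fill(mask.eq(0), torch.finfo(score.dtype).min)
    attn = torch.softmax(score, dim=-1).masked_fill(mask.eq(0), 0.0)
    return torch.matmul(attn, value)

# e.g. masked_attention(linear_q(inputs), linear_k(inputs), linear_v(inputs), enc_mask)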

Decoder
Self-attention
- Embed the targets
output_size = 7 # vocabulary size is 7, i.e. token ids 0-6 (0 is the padding id)
d_model = 3 # keep consistent with the encoder output dimension
embedding = torch.nn.Embedding(output_size, d_model)
dec_output = embedding(targets)
# Positional encoding has the same dimension as the embedding; here the raw embedding is used as dec_output to demonstrate the decoder computation
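For completeness, a minimal sketch of the sinusoidal positional encoding mentioned above (the formula from "Attention Is All You Need"; the slicing on the cosine half only handles the odd d_model = 3 used in this demo):
import math

def sinusoidal_positional_encoding(max_len, d_model):
    # PE(pos, 2i) = sin(pos / 10000^(2i / d_model)), PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # max_len x 1
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)[:, :d_model // 2]      # trim for odd d_model
    return pe.unsqueeze(0)                                              # 1 x max_len x d_model

# If one wanted to add it here: dec_output = dec_output + sinusoidal_positional_encoding(targets.size(1), d_model)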

- Compute the self-attention mask
def get_dec_seq_mask(targets, targets_length=None):
    steps = targets.size(-1)  # maximum length of targets
    padding_mask = targets.ne(0).unsqueeze(-2)  # b x 1 x L
    # print('padding_mask:', padding_mask, padding_mask.shape)
    # Lower-triangular matrix: autoregressive decoding cannot see labels after the current position
    seq_mask = torch.ones([steps, steps], device=targets.device)
    seq_mask = torch.tril(seq_mask).bool()
    seq_mask = seq_mask.unsqueeze(0)  # 1 x L x L
    # print('seq_mask:', seq_mask, seq_mask.shape)
    # Both masks must hold, so combine them with &
    return seq_mask & padding_mask  # b x L x L

dec_mask = get_dec_seq_mask(targets)  # b x L x L, used below
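As a sanity check, for the first target sequence [1, 2, 3, 0] the combined mask should be the lower-triangular matrix with the padding column masked out:
print(dec_mask[0])
# Expected:
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True, False]])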



# Compute the decoder self-attention
query = dec_output
key = dec_output
value = dec_output
nfeat = dec_output.shape[-1]
linear_q = nn.Linear(nfeat, nfeat)
linear_k = nn.Linear(nfeat, nfeat)
linear_v = nn.Linear(nfeat, nfeat)
query = linear_q(query)
key = linear_k(key)
value = linear_v(value)
score = torch.matmul(query, key.transpose(-2, -1))
# Equivalent to negating dec_mask
mask = dec_mask.eq(0)
# print('mask:', mask, mask.shape)
# A very small value, effectively negative infinity
min_value = float(numpy.finfo(torch.tensor(0, dtype=score.dtype).numpy().dtype).min)
score = score.masked_fill(mask, min_value)
attn = torch.softmax(score, dim=-1).masked_fill(mask, 0.0)
print('attn:', attn)
dec_output = torch.matmul(attn, value)
print('dec_output:', dec_output)


- Compute encoder-decoder attention
# Q is the decoder output; K and V come from the encoder output
query = dec_output
key = enc_output
value = enc_output
nfeat = dec_output.shape[-1]
linear_q = nn.Linear(nfeat, nfeat)
linear_k = nn.Linear(nfeat, nfeat)
linear_v = nn.Linear(nfeat, nfeat)
query = linear_q(query)
key = linear_k(key)
value = linear_v(value)
score = torch.matmul(query, key.transpose(-2, -1))
# Here the mask is the encoder mask, since the keys and values come from the encoder output
mask = enc_mask.eq(0)
score = score.masked_fill(mask, min_value)
print(score)
attn = torch.softmax(score, dim=-1).masked_fill(mask, 0.0)
dec_output = torch.matmul(attn, value)
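A full decoder layer would follow this attention output with a position-wise feed-forward network plus residual connections and layer normalization, which this walkthrough skips. A minimal sketch with a hypothetical hidden size d_ff:
d_ff = 8  # hypothetical hidden size; real models use something much larger, e.g. 2048
feed_forward = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
# Applied with a residual connection (layer norm omitted): dec_output = dec_output + feed_forward(dec_output)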


- Project the decoder output to the vocabulary size
d_model = 3
output_size = 7
output_layer = nn.Linear(d_model, output_size)
logits = output_layer(dec_output)  # torch.Size([5, 4, 7])
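Before computing the loss, the logits can already be turned into greedy predictions; a quick sketch:
pred = logits.argmax(dim=-1)  # torch.Size([5, 4]), predicted token ids 0-6; padding positions are meaningless here
print(pred)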

Compute the loss
batch_size = logits.size(0)
vocab_size = 7
logits = logits.view(-1, vocab_size)
targets = targets.reshape(-1)
with torch.no_grad():
    true_dist = logits.clone()
    true_dist.fill_(0.1 / (vocab_size - 1))  # label smoothing: spread 0.1 over the other vocab_size - 1 classes
    ignore = targets == 0  # (B*L,) padding positions in targets, so positions left unmasked in the decoder computation are excluded from the loss
    total = len(targets) - ignore.sum().item()
    target = targets.masked_fill(ignore, 0)  # avoid -1 index
    true_dist.scatter_(1, target.unsqueeze(1).long(), 0.9)  # put 0.9 on the correct label; the other positions share the remaining 0.1 (scatter_ needs an int64 index)
print(true_dist)
criterion = nn.KLDivLoss(reduction='none')
kl = criterion(torch.log_softmax(logits, dim=1), true_dist)  # take log_softmax of the logits first
print(kl.masked_fill(ignore.unsqueeze(1), 0))
loss = kl.masked_fill(ignore.unsqueeze(1), 0).sum() / total  # divide by total to normalize over the non-padding tokens
print(loss)
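As a rough cross-check (not identical, since it has no label smoothing), the same masked average can be compared against plain cross-entropy with ignore_index=0; a small sketch:
import torch.nn.functional as F
# Cross-entropy baseline that also ignores the padding id 0; the value will differ from the label-smoothed KL loss above
ce = F.cross_entropy(logits, targets.long(), ignore_index=0)
print(ce)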

