attention pytorch实现学习
attention pytorch实现学习
关于global attention概述见:
https://www.jianshu.com/p/841557506ab5
本文基于《dive into deep learning》-pytorch
attention原理图(图源《dive into deep learning》).pngAdditive Attention
addtive attention.png如果key和query是不同长度的向量,一般方法是,将两者拼接起来,然后过一个线性层。
这也是常用的concat attention方法
公式也可以写成
实现方式
class AdditiveAttention(nn.Module):
def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):#转换为num_hiddens维度,词向量长度
#假设:query:(2, 1, 20), key:(2, 10, 2), value: (2, 10, 4)
#batch seq word_embedding, key和value seq_len是一样的,query是一个单独的向量,1×20
super(AdditiveAttention,self).__init__(**kwargs)
self.W_k=nn.Linear(key_size,num_hiddens,bias=False)
self.W_q=nn.Linear(query_size,num_hiddens,bias=False)
self.w_v=nn.Linear(num_hiddens,1,bias=False)#
self.dropout=nn.Dropout(dropout)
def forward(self,queries,keys,values,valid_lens):
queries,keys = self.W_q(queries),self.W_k(keys)#映射到相同维度 [2,1,8] [2,10,8]
#query增加一个维度为了方便和key相加。key增加一个维度后面需要
features = queries.unsqueeze(2)+keys.unsqueeze(1) #torch.Size([2, 1, 1, 8]) torch.Size([2, 1, 10, 8])
print(queries.unsqueeze(2).shape,keys.unsqueeze(1).shape)
print(features.shape)#torch.Size([2, 1, 10, 8])
features = torch.tanh(features)
scores = self.w_v(features)#8 *1
print(scores.shape)#torch.Size([2, 1, 10, 1])
scores=scores.squeeze(-1)# w_v消掉最后隐藏层维,因此把这一维去掉,这里就得到了
print(scores.shape)# 2,1,10 把seq中不需要的部分隐藏掉
self.attention_weigths = masked_softmax(scores,valid_lens)#结果取softmax
print(self.attention_weigths)
print(self.attention_weigths.shape)#2,1,10
# attention weights和values加权相加
return torch.bmm(self.dropout(self.attention_weigths),values)#2,1,10 2*10*4 ->2*1*4,10个value的权重加和
pytorch 知识
1.bmm
计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,h),注意两个tensor的维度必须为3. https://blog.csdn.net/qq_40178291/article/details/100302375
2.None
[:,None]
None表示该维不进行切片,而是将该维整体作为数组元素处理。
所以,[:,None]的效果就是将二维数组按每行分割,最后形成一个三维数组
torch.tensor([2,3])[:, None]
tensor([[2],
[3]])
print(torch.arange((10), dtype=torch.float32)[None, :] )
print(torch.tensor([2,3])[:, None])
print(torch.arange((10), dtype=torch.float32)[None, :]<torch.tensor([2,3])[:, None])
------------------------------
tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])
tensor([[2],
[3]])
tensor([[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False]]
repeat_interleave repeat_interleave(self: Tensor, repeats: _int, dim: Optional[_int]=None) 参数说明: self: 传入的数据为tensor repeats: 复制的份数 dim: 要复制的维度,可设定为0/1/2.....
sequece_mask
def sequence_mask(X, valid_len, value=0):
"""Mask irrelevant entries in sequences."""
#X size=2,10
maxlen = X.size(1)#10
mask = torch.arange((maxlen), dtype=torch.float32,
device=X.device)[None, :] < valid_len[:, None] #index比大小,比这个index小的,都保留为true
X[~mask] = value#则不保留的部分赋值为value。
return X
mask_softmax
用于去除不需要的padding部分,mask部分的attention score可以忽视。
def masked_softmax(X,valid_lens):
if valid_lens is None:
return nn.functional.softmax(X,dim=-1)
else:
shape=X.shape
if valid_lens.dim()==1:
valid_lens = torch.repeat_interleave(valid_lens,shape[1])
else:
valid_lens=valid_lens.reshape(-1)
X=sequence_mask(X.reshape(-1,shape[-1]),valid_lens,value=-1e6)#2,1,10转换成2,10mask,value复制一个极小值
return nn.functional.softmax(X.reshape(shape),dim=-1)#mask后再softmax
Additive Attention
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,dropout=0.1)
attention.eval()
res =attention(queries, keys, values, valid_lens)
#weight
attention.attention_weigths
#torch.Size([2, 1, 10])# 10个值代表十个weight
show_heatmap
show_heatmaps(attention.attention_weigths.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')
# 2×10
#query:(2, 1, 20), key:(2, 10, 2), value: (2, 10, 4)
#batch1 query 和10个key的交互值
#batch2 query 和10个key的交互值
import torch
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
import random
def show_heatmaps(matrices,xlabel,ylabel,titles=None,figsize=(2.5,2.5),cmap='Reds'):
display.set_matplotlib_formats('svg')
num_rows,num_cols=matrices.shape[0],matrices.shape[1]
print(num_rows,num_cols)
fig,axes=plt.subplots(num_rows,num_cols,figsize=figsize,sharex=True,sharey=True,squeeze=False)#sharex,sharey共享x,y axes,返回各个子图
for i,(row_axes,row_matrices) in enumerate(zip(axes,matrices)):#数据
for j,(ax,matrix) in enumerate(zip(row_axes,row_matrices)):
print(i,j)
pcm = ax.imshow(matrix.detach().numpy(),cmap=cmap)
if i==num_rows-1:
ax.set_xlabel(xlabel)
if j==0:
ax.set_ylabel(ylabel)
if titles:
ax.set_title(titles[j])
fig.colorbar(pcm,ax=axes,shrink=0.6)
Scaled Dot-Product Attention
如果query和key的维度相同,可以用点乘注意力。
Scaled Dot-Product Attention 《dive into DL》.png
class DotProductAttention(nn.Module):
"""Scaled dot product attention."""
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
# Shape of `queries`: (`batch_size`, no. of queries, `d`)
# Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
# Shape of `values`: (`batch_size`, no. of key-value pairs, value dimension)
# Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# Set `transpose_b=True` to swap the last two dimensions of `keys`
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
import math
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.normal(0, 1, (2, 10, 2))
values = torch.normal(0, 1, (2, 10, 6))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)
在学习文本分类中,我们对余弦相似度会比较熟悉,余弦相似度是用夹角来衡量向量之间相似性的一种方法,
公式:https://www.jianshu.com/p/a894ebba4a1a
余弦相似度.png某种程度上,和这种注意力的公式有相似之处,也可以认为点积式注意力机制以这种方法衡量相似度。
点积attention有一个scaled的操作,这个操作的原因可以参考:
https://www.zhihu.com/question/339723385 transformer中的attention为什么scaled?
总结:
数量级对softmax得到的分布影响非常大。在数量级较大时,softmax将几乎全部的概率分布都分配给了最大值对应的标签。
也就是说,在输入的数量级很大时,梯度消失为0,造成参数更新困难。
transformer中的attention为什么scaled? - TniL的回答 - 知乎 https://www.zhihu.com/question/339723385/answer/782509914
attention知识.pngtransformer中的attention为什么scaled? - 小莲子的回答 - 知乎 https://www.zhihu.com/question/339723385/answer/811341890
即如果不scale的话,容易造成梯度消失,给参数更新造成困难。