attention pytorch实现学习

2021-06-20  本文已影响0人  锦绣拾年

attention pytorch实现学习

关于global attention概述见:

https://www.jianshu.com/p/841557506ab5

本文基于《dive into deep learning》-pytorch

attention原理图(图源《dive into deep learning》).png

Additive Attention

addtive attention.png

如果key和query是不同长度的向量,一般方法是,将两者拼接起来,然后过一个线性层。
这也是常用的concat attention方法

公式也可以写成

a(q,k)= w_v^T tanh(W[q,k])

W_qq+W_kk = [W_q,W_k]\begin{bmatrix} q\\k \end{bmatrix}

实现方式

class AdditiveAttention(nn.Module):
    def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):#转换为num_hiddens维度,词向量长度
        #假设:query:(2, 1, 20), key:(2, 10, 2), value: (2, 10, 4) 
        #batch seq word_embedding,  key和value seq_len是一样的,query是一个单独的向量,1×20
        super(AdditiveAttention,self).__init__(**kwargs)
        self.W_k=nn.Linear(key_size,num_hiddens,bias=False)
        self.W_q=nn.Linear(query_size,num_hiddens,bias=False)
        self.w_v=nn.Linear(num_hiddens,1,bias=False)#
        self.dropout=nn.Dropout(dropout)
    def forward(self,queries,keys,values,valid_lens):
        queries,keys = self.W_q(queries),self.W_k(keys)#映射到相同维度 [2,1,8] [2,10,8]
        #query增加一个维度为了方便和key相加。key增加一个维度后面需要    
        features = queries.unsqueeze(2)+keys.unsqueeze(1) #torch.Size([2, 1, 1, 8]) torch.Size([2, 1, 10, 8])
        print(queries.unsqueeze(2).shape,keys.unsqueeze(1).shape)
        print(features.shape)#torch.Size([2, 1, 10, 8])
        features = torch.tanh(features)
        
        scores = self.w_v(features)#8 *1
        print(scores.shape)#torch.Size([2, 1, 10, 1])
        scores=scores.squeeze(-1)# w_v消掉最后隐藏层维,因此把这一维去掉,这里就得到了
        print(scores.shape)# 2,1,10  把seq中不需要的部分隐藏掉
        
        self.attention_weigths = masked_softmax(scores,valid_lens)#结果取softmax
        print(self.attention_weigths)
        print(self.attention_weigths.shape)#2,1,10
        # attention weights和values加权相加
        return torch.bmm(self.dropout(self.attention_weigths),values)#2,1,10 2*10*4 ->2*1*4,10个value的权重加和

pytorch 知识

1.bmm

计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,h),注意两个tensor的维度必须为3. https://blog.csdn.net/qq_40178291/article/details/100302375

2.None

[:,None]

None表示该维不进行切片,而是将该维整体作为数组元素处理。

所以,[:,None]的效果就是将二维数组按每行分割,最后形成一个三维数组

torch.tensor([2,3])[:, None]
tensor([[2],
        [3]])
print(torch.arange((10), dtype=torch.float32)[None, :] )
print(torch.tensor([2,3])[:, None])
print(torch.arange((10), dtype=torch.float32)[None, :]<torch.tensor([2,3])[:, None])
------------------------------
tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])
tensor([[2],
        [3]])
tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False]]
repeat_interleave repeat_interleave(self: Tensor, repeats: _int, dim: Optional[_int]=None) 参数说明: self: 传入的数据为tensor repeats: 复制的份数 dim: 要复制的维度,可设定为0/1/2.....

sequece_mask

def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    #X size=2,10
    maxlen = X.size(1)#10
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None] #index比大小,比这个index小的,都保留为true

    X[~mask] = value#则不保留的部分赋值为value。
    return X

mask_softmax

用于去除不需要的padding部分,mask部分的attention score可以忽视。

def masked_softmax(X,valid_lens):  
    if valid_lens is None:
        return nn.functional.softmax(X,dim=-1)
    else:
        shape=X.shape
        if valid_lens.dim()==1:
            valid_lens = torch.repeat_interleave(valid_lens,shape[1])
        else:
            valid_lens=valid_lens.reshape(-1)
        X=sequence_mask(X.reshape(-1,shape[-1]),valid_lens,value=-1e6)#2,1,10转换成2,10mask,value复制一个极小值
        return nn.functional.softmax(X.reshape(shape),dim=-1)#mask后再softmax

Additive Attention

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,dropout=0.1)
attention.eval()
res =attention(queries, keys, values, valid_lens)
#weight
attention.attention_weigths
#torch.Size([2, 1, 10])# 10个值代表十个weight

show_heatmap

show_heatmaps(attention.attention_weigths.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')
# 2×10  
#query:(2, 1, 20), key:(2, 10, 2), value: (2, 10, 4) 
#batch1 query 和10个key的交互值
#batch2 query 和10个key的交互值
import torch
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
import random
def show_heatmaps(matrices,xlabel,ylabel,titles=None,figsize=(2.5,2.5),cmap='Reds'):
    display.set_matplotlib_formats('svg')
    num_rows,num_cols=matrices.shape[0],matrices.shape[1]
    print(num_rows,num_cols)
    fig,axes=plt.subplots(num_rows,num_cols,figsize=figsize,sharex=True,sharey=True,squeeze=False)#sharex,sharey共享x,y axes,返回各个子图
    for i,(row_axes,row_matrices) in enumerate(zip(axes,matrices)):#数据
        for j,(ax,matrix) in enumerate(zip(row_axes,row_matrices)):
            print(i,j)
            pcm = ax.imshow(matrix.detach().numpy(),cmap=cmap)
            if i==num_rows-1:
                ax.set_xlabel(xlabel)
            if j==0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm,ax=axes,shrink=0.6)

Scaled Dot-Product Attention

如果query和key的维度相同,可以用点乘注意力。


Scaled Dot-Product Attention 《dive into DL》.png
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
    # Shape of `queries`: (`batch_size`, no. of queries, `d`) 
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Set `transpose_b=True` to swap the last two dimensions of `keys`
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
import math
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.normal(0, 1, (2, 10, 2))
values = torch.normal(0, 1, (2, 10, 6))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)

在学习文本分类中,我们对余弦相似度会比较熟悉,余弦相似度是用夹角来衡量向量之间相似性的一种方法,

公式:https://www.jianshu.com/p/a894ebba4a1a

余弦相似度.png

某种程度上,和这种注意力的公式有相似之处,也可以认为点积式注意力机制以这种方法衡量相似度。

点积attention有一个scaled的操作,这个操作的原因可以参考:

https://www.zhihu.com/question/339723385 transformer中的attention为什么scaled?

总结:

数量级对softmax得到的分布影响非常大。在数量级较大时,softmax将几乎全部的概率分布都分配给了最大值对应的标签

也就是说,在输入的数量级很大时,梯度消失为0,造成参数更新困难

transformer中的attention为什么scaled? - TniL的回答 - 知乎 https://www.zhihu.com/question/339723385/answer/782509914

attention知识.png

transformer中的attention为什么scaled? - 小莲子的回答 - 知乎 https://www.zhihu.com/question/339723385/answer/811341890

即如果不scale的话,容易造成梯度消失,给参数更新造成困难。

上一篇下一篇

猜你喜欢

热点阅读