Multi-head attention 多头注意力机制

2021-06-20  本文已影响0人  锦绣拾年

Multi-head attention

本文基于《dive into deep learning》-pytorch

代码参考 《dive into deep learning》-pytorch

multi-head attention

基本信息

我们可以会希望注意力机制可以联合使用不同子空间的key,value,query的表示。因此,不是只用一个attention pooling,query、key、value可以被h个独立学到的线性映射转换。最后,h个attention pooling输出concat 并且再次通过一个线性映射得到最后的输出。

这种设计就是multi-head attention, h个attention pooling输出中的每一个就是一个头。使用全连接层来实现线性转换。

multi-attention1.png

理解纠错

【我过去有一个误解,就是multi-head是和CNN类似的机制,用多个的W降维,之后再计算多个注意力分数,再concat。直到我用pytorch中自带的multi-head attention,要求num_heads是hidden层维度可以整除的数,才发现这里的multi-head是针对子空间的】
【但是这里可以理解,用同样的维度,训练多个空间,可以更好地增强表达能力】

这部分解答可以参考:

https://www.zhihu.com/question/350369171 -transformer中multi-head attention中每个head为什么要进行降维?(实际上用切割来表示更为准确)

https://www.zhihu.com/question/446385446 - BERT中,multi-head 768*64*12与直接使用768*768矩阵统一计算,有什么区别?

对于 Multi-Head Attention,简单来说就是多个 Self-Attention 的组合,但多头的实现不是循环的计算每个头,而是通过 transposes and reshapes,用矩阵乘法来完成的。

In practice, the multi-headed attention are done with transposes and reshapes rather than actual separate tensors. —— 来自 google BERT 源代码注释

Transformer中把 d ,也就是hidden_size/embedding_size这个维度做了reshape拆分,具体可以看对应的 pytorch 代码

hidden_size (d) = num_attention_heads (m) * attention_head_size (a),也即 d=m*a

【↑作者:海晨威
链接:https://www.zhihu.com/question/350369171/answer/1718672303
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。】

正如回答中有写:

transformer中multi-head attention中每个head为什么要进行降维? - LooperXX的回答 - 知乎 https://www.zhihu.com/question/350369171/answer/860552006

回到题主的问题上来,如果只使用 one head 并且维度为 d_model ,相较于 8 head 并且维度为d_model/8,存在高维空间下学习难度较大的问题,文中实验也证实了这一点,于是将原有的高维空间转化为多个低维子空间并再最后进行拼接,取得了更好的效果,十分巧妙。

在实现的时候,multi-head把维度从[batch, len, embeding]变为[batch, len, head, embeding/head], 然后head就是多头,对每一个 embeding/head部分计算对应的attention。

pytorch实现

class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size,num_hiddens,bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)#映射到numhiddens
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
    def forward(self,queries,keys,values,valid_lens):
        #注意最后的 [batch_size` * `num_heads`,number of  key-value pairs,num_hiddens` / `num_heads]
        #这里涉及到reshape操作
        queries = transpose_qkv(self.W_q(queries),self.num_heads)#batch,seq,embed -> batch*num_head,seq,embed/num_head
        keys = transpose_qkv(self.W_k(keys),self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:#相当于每个batch扩充num_heads遍
            valid_lens = torch.repeat_interleave(valid_lens,repeats=self.num_heads,dim=0)
        print(queries.shape)#10,4,20
        print(values.shape)#10,6,20
        print(")*&^%$^&*()")
        output = self.attention(queries, keys, values, valid_lens)#attention计算是transpose之后的向量
        #得到,batch×head, seq,embed/head的矩阵,每一个embed/head是这一部分词向量子空间的attention加权和。
        weights= self.attention.attention_weights
        print(weights.shape)#10,4,6  query: 2 4 100 key: 2,6,100 ,一共10组,每组 4×6,query和key的交互值
        
        output_concat = transpose_output(output, self.num_heads)#transpose的逆运算
        return self.W_o(output_concat)#最后做一次线性变换 #2,4,100
        
def transpose_qkv(X, num_heads):
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)#batch seq head embed/head

    X = X.permute(0, 2, 1, 3) # batch head seq embed/head

    return X.reshape(-1, X.shape[2], X.shape[3])# batch×head, seq,embed/head
def transpose_output(X, num_heads):# batch×head, seq,embed/head
    """Reverse the operation of `transpose_qkv`"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])# batch,head, seq,embed/head
    X = X.permute(0, 2, 1, 3)## batch,seq,head, embed/head
    return X.reshape(X.shape[0], X.shape[1], -1)#batch,seq,embed
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))#2,4,100
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))#2,6,100
attention(X, Y, Y, valid_lens).shape #2,4,100 query有4个,得到4个对应的结果
#中间 attention weight大小是 10,4,6

self-attention

输入和输出大小一样

query,key,value一样

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))#2,4,100
attention(X, X, X, valid_lens).shape #2,4,100

补充
Transformer/CNN/RNN的对比(时间复杂度,序列操作数,最大路径长度) - Gordon Lee的文章 - 知乎
https://zhuanlan.zhihu.com/p/264749298
https://spaces.ac.cn/archives/4765
↑对self-attention的分析也很好,self-attention有不能充分编入位置信息的硬伤等

上一篇下一篇

猜你喜欢

热点阅读