自注意力机制和位置编码

2022-10-08  本文已影响0人  小黄不头秃

(一)自注意力机制

给定一个由词元组成的输入序列\mathbf{x}_1, \ldots, \mathbf{x}_n
其中任意\mathbf{x}_i \in \mathbb{R}^d1 \leq i \leq n)。
自注意力池化层将x_i当作是key,value和query。然后对每个序列抽取特征得到y。
该序列的自注意力输出为一个长度相同的序列
\mathbf{y}_1, \ldots, \mathbf{y}_n,其中:

\mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d

(1)比较不同的处理序列的神经网络

CNN也可以用于处理序列问题,他的实现就是说把序列当作是一个一维的输入。
RNN的特殊网络结构更适合记忆长序列。
使用自注意力机制能够预测长序列,但是随着长度的增加,计算复杂度也会成倍增加。

(二)位置编码

跟CNN/RNN不同,自注意力并没有记录序列的位置信息
位置编码将位置信息注入到输入里

假设输入表示\mathbf{X} \in \mathbb{R}^{n \times d}
(每一行是一个样本,每一列是一个特征。)
包含一个序列中n个词元的d维嵌入表示。
位置编码使用相同形状的位置嵌入矩阵(p不是概率,是position)
\mathbf{P} \in \mathbb{R}^{n \times d}输出\mathbf{X} + \mathbf{P}
矩阵第i行、第2j列和2j+1列上的元素为:

\begin{aligned} p_{i, 2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right),\\p_{i, 2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right).\end{aligned}

两个相邻的列,曲线是一样的但是会在横向移动。
相同的列,不同行之间的位置信息可以通过矩阵变换,转换过去。

这样的编码属于是相对位置编码。也就是说知道了一个位置信息,序列中其他的位置信息都能够用一个矩阵变换给推算出来。

(三)代码实现

import torch
from torch import nn 
from d2l import torch as d2l
#@save
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
上一篇下一篇

猜你喜欢

热点阅读