transformer 中位置编码的理解

2023-06-22  本文已影响0人  michael_0x

参考https://www.youtube.com/watch?v=dichIcUZfOw
验证如下:
对应论文里的编码公式:

image.png
import math

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


def plot_sin():
    X = np.linspace(-np.pi, +np.pi, 256)
    Y = np.sin(X / 2)
    # Y = np.concatenate([Y, Y, Y], 0)
    plt.plot(X, Y)
    plt.show()

def plot_pos_embedding():
    # plot_sin()
    d_model = 32
    features = torch.arange(0, d_model)
    positions = torch.arange(0, 100)
    print("positions:", positions)
    # pos_fn = lambda p: torch.sin(p / torch.pow(10000, features / d_model))
    pos_sin = lambda d: torch.sin(positions / math.pow(10000, 2 * d / d_model))
    pos_cos = lambda d: torch.cos(positions / math.pow(10000, 2 * d / d_model))
    pos = pos_sin(0)
    print("pos encodings:", pos)
    plt.xlabel(positions)
    plt.plot(pos_sin(0))
    plt.plot(pos_sin(2) + 2)
    plt.plot(pos_sin(4) + 4)
    plt.plot(pos_cos(1) + 0)
    plt.plot(pos_cos(3) + 2)
    plt.plot(pos_cos(5) + 4)
    plt.show()
plot_pos_embedding()

只有Sin:


image.png

加上Cos:


image.png

首先,验证了对于词向量里不同index,位置编码三角函数的频率是不一样的。
粗略直观的来理解一下:
位置编码三角函数的频率不同,说明相同的Y值(即编码的值)重复出现所需的步长不一样,因此如果要确认任意两个位置之间的相对距离:

  1. 首先把频率从低到高排列,实际就是词向量index的逆序排列。
  2. 先从频率最慢的(词向量index最大的)看,因为重复值所需的步长可能超过了最大的位置间隔,因此在这个index/频率产生编码数据里,没有重复的值。
  3. 再看下一条,重复值所需的步长会小一点,看看有没有重复的值。一值往下一条搜索,重复值所需的步长不断缩小,最终会找到在某个频率/步长,首次找到两个位置编码出来的值是相等的,也就确定了这两个位置的相对距离了。
  4. 那么如果这个相对位置很重要,有助于降低训练的loss,模型训练的时候,就会把相关单词这个维度相对位置对应的模型参数的重要性给放大。

下一步,多个词之间的相对位置:

  1. 因为已经能确认任意两个词之间的相对位置,那么多个词之间的相对位置其实就是多个“两个词之间相对位置”的组合,如果多个词之间的相对位置比较重要,有助于降低训练的loss,模型训练的时候,就会把相关单词这个维度相对位置“组合”对应的模型参数的重要性给放大。

这个编码的本质,就是定义可以被AI学习的词汇位置关系。

后面再联系attention机制来进一步研究ai是如何学习到词汇之间的位置关系所代表的意义的。

参考资料:
https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
https://timodenk.com/blog/linear-relationships-in-the-transformers-positional-encoding/

上一篇下一篇

猜你喜欢

热点阅读