Embedding 原理与代码实战
embedding 的原理
embedding 层做了个什么呢?它把我们的稀疏矩阵,通过一些线性变换(在CNN中用全连接层进行转换,也称为查表操作),变成了一个密集矩阵,这从稀疏矩阵到密集矩阵的过程,叫做 embedding,很多人也把它叫做查表,因为它们之间也是一个一一映射的关系。
对 one-hot 向量的 embedding,相当于查表,embedding 直接用查表作为操作,而不是矩阵乘法运算,这大大降低了运算量,所以降低运算量不是因为id的embedding 向量的出现,而是因为把 one-hot 的 embedding 矩阵乘法运算简化为了查表操作。
如下图所示,embedding 过程就是将 one-hot 向量输入到全连接层输出2个3维的稠密向量,这个(6, 3)的全连接层参数,就是一个 id 向量表,对应 6 种 id 的 embedding 稠密向量。又例如,假设不同 id 的个数为 100(即 one-ho t向量长度为100),设定 embedding 稠密向量的维度为 10,则全连接层的参数矩阵为100*10(这个矩阵就是 id 向量表,每个 id 特征都有一个 10 维的稠密向量表示它)。
image.pngembedding 代码实现(Pytroch版本)
首先定义一个 embedding
import torch.nn as nn
# 5 输入类别数目, 即One-hot长度, 3 输出 embedding 稠密向量维度
my_embedding = nn.Embedding(5, 3)
查看一下embedding初始化的 weight
my_embedding.weight
image.png
从这可以看到 embedding 生成了一个5*3的矩阵,其实也就是 embedding 全连接层的参数。
这里以[0,1,2,3,4]为例, 假设有以下4条数据,具体特征值如下 (注意因为定义的 embedding 类别数目为5,所以输入值不能超过4)
test = [0, 1, 2, 4]
embed = my_embedding(torch.LongTensor(test))
embed
image.png
从计算结果可以看到,embedding 之后得到的是一个4*3的矩阵,即原始特征每一个值用一个3维稠密向量表示。看到这里可能会有朋友疑问,这个4*3的矩阵具体是怎么生成的,或者生成的依据是什么?
带着这个问题,我们不妨回到计算之前,如果没有 embedding 我们该如何对一个类别型特征 one-hot, 答案很显然,用0、1表示。现在我们使用 one-hot 对上面的数据处理,可以想到,one-hot 之后预期结果如下:
test = [0,1,2,4]
one_hot(test) #这里是伪代码,具体 one_hot 计算逻辑不再展示
image.png
这里one-hot之后生成了一个4*5的矩阵。很显然,这个结果很好理解并且符合我们预期。那么这个结果和上面embedding生成的4*3矩阵有什么关系呢?
embedding 可用性理解
其实,前文已经说明过,embedding 相当于查表。所以这里查的到底是什么表?细心的朋友可以发现,其实查的就是我们最初定义 embedding 层的时候生成的 weight 矩阵(5*3),现在再回顾一下embedding 对 input 数据的计算过程,“查表”结果显而易见。
image.png
最后为了加深我们对embedding查表逻辑的理解,我们可以尝试对这个全连接层的参数,使用矩阵乘法来计算一下,看一下最后的计算结果:
test2 = [[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 1],
]
torch.matmul(torch.FloatTensor(test2), my_embedding.weight.data)
image.png
结果和 embedding 计算结果一致!这里也是文章最开始提到的,embedding 直接用查表作为操作,而不是矩阵乘法运算,这大大降低了运算量。
image.png
以上就是 embedding 在对稀疏类别特征的计算过程,这里有一点要注意,最初 embedding 产生的 weight 可以理解为随机的,并且整个过程并没有进行训练,所以此时的 embedding 本质仅仅是一种低维的表示向量,不具有其他数据信息。
embedding 之所以强大,在于 weight 本身是一个可训练的张量,可以接入各种网络结构。所以往往 embeddin 作为网络结构的第一层,经过中间 n 层网络结构处理(n可以为0),最后到输出层。这样在网络的训练过程中,weight 会得到更新,此时 embedding 才具有数据信息,直接用这个全连接层的权重参数作为特征表达。代表某一个 id,或者作为 id 的特征表达(向量的夹角余弦能够在某种程度上表示不同id间的相似度)。