Embedding的理解
2024-04-23 本文已影响0人
sretik
Embedding :一个简单的查找表,存储固定字典和大小的嵌入。
>>> import torch
>>> import torch.nn as nn
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
# a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[
[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]
],
[
[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]
]])
如上面的例子所示,nn.Embedding生成了一个shape=(10,3)的向量,分别表示0-9十个数字。
可以看到input向量中两个2的向量表示是一样的,4的向量表示也是一样的。