Embedding的理解

2024-04-23  本文已影响0人  sretik

Embedding :一个简单的查找表,存储固定字典和大小的嵌入。

>>> import torch
>>> import torch.nn as nn
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
# a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
    tensor([[
                [-0.0251, -1.6902,  0.7172],
                [-0.6431,  0.0748,  0.6969],
                [ 1.4970,  1.3448, -0.9685],         
                [-0.3677, -2.7265, -0.1685]
            ],        
            [
                [ 1.4970,  1.3448, -0.9685],         
                [ 0.4362, -0.4004,  0.9400],         
                [-0.6431,  0.0748,  0.6969],         
                [ 0.9124, -2.3616,  1.1151]
            ]])

如上面的例子所示,nn.Embedding生成了一个shape=(10,3)的向量,分别表示0-9十个数字。
可以看到input向量中两个2的向量表示是一样的,4的向量表示也是一样的。

上一篇下一篇

猜你喜欢

热点阅读