torch.gather

2022-07-02  本文已影响0人  菌子甚毒

https://pytorch.org/docs/stable/generated/torch.gather.html

一个简单的例子:

t = torch.rand(2,3)
"""
tensor([[0.8133, 0.5586, 0.7917],
        [0.0551, 0.2322, 0.9087]])
"""
t.gather(dim=0,index=torch.tensor([[0,1,0],[1,0,1]]))
"""
tensor([[0.8133, 0.2322, 0.7917],
        [0.0551, 0.5586, 0.9087]])
"""
# 常用于以下需求:
# celoss = torch.tensor([i_s[i_t] for i_s,i_t in zip(softmax,target)])

input = torch.randn(3, 5, requires_grad=True) # (3,5)

n_samples = input.shape[0] # 注意dim=1时,input.shape[0]=index.shape[0], 同理, 可推dim=0时,input.shape[1]=index.shape[1]
channel = 6

idx = torch.randint(low=0,high=5,size=(n_samples*channel,)).reshape(n_samples,channel)
"""
tensor([[0, 0, 4, 2, 3, 1], 第一行取第0个,第0个,第4个...
        [3, 3, 1, 0, 2, 2], 第二行取第3个,第3个,第1个...
        [4, 4, 4, 2, 1, 3]]) ...
"""
input.gather(dim=1,index=idx) # torch.Size([3, 6])
上一篇下一篇

猜你喜欢

热点阅读