pytorch 计算2维矩阵每行最小N个值的索引

2022-01-26  本文已影响0人  Ailien

以二维numpy矩阵为例

import torch
import numpy as np
K=3 #取每行最小3个值的索引
data=np.random.rand(4,7)
print(data)
data=torch.from_numpy(data)
a, idx = torch.sort(data, descending=False)
lists=idx[:,:K]
print(lists)

运行结果如下:

results.jpg
上一篇下一篇

猜你喜欢

热点阅读