pytorch/numpy/heapq计算二维矩阵每行最小N个值

2022-01-27  本文已影响0人  Ailien

之前写了一个用pytorch计算二维矩阵每行最小N个值的索引的代码,示例用了一个[4,7]维的数据,在示范的时候也没感觉出来运行时间的概念。当我开始用在大规模数据上的时候发现耗时很大。因此我又尝试了其他的方法来对比一下时耗。目前方法包括numpy,pytorch,heapq,代码如下:

import heapq
import time
import torch
import numpy as np
import time
K=3 #取每行最小3个值的索引
data=np.random.rand(4,7)
print(data)

得到数据如下:


data

下面是三种算法及时间对比

start=time.time()
x=data.argsort()
id1=x[:,:K]
print(id1)
end1=time.time()
print(end1-start)
data01=torch.from_numpy(data)
a, idx = torch.sort(data01, descending=False)
# print(idx)
lists=idx[:,:K]
print(lists)
end2=time.time()
print(end2-end1)

lists = [[] for i in range(len(data))]
for i in range(len(data)):
    lists[i].append(heapq.nsmallest(K, range(len(data[i])), data[i].take))
print(lists)
end3=time.time()
print(end3-end2)

结果如下:

results
可以看出三种算法的结果是一样的,而且也是对的,但是运行时间却差的很多,目前来看最快的是基于heapq的算法,pytorch竟然是最慢的,而且我安装的pytorch是cuda版本的,这让我百思不得其解。
可是当我将data矩阵的维度大幅加大以后,三个算法的时间对比又有了新的排序
result1
上图为增大至[1044,7065]以后的结果,可见numpy还是最快的,但heapg算法的时间却增加了很多。
当矩阵变大了以后,耗时还是很大,如果大家有其他可以利用gpu加速的更快速的算法,欢迎交流
上一篇下一篇

猜你喜欢

热点阅读