两次sort排序求矩阵的升序或降序元素的位置(pytorch版)
之前看过ssd.pytorch里面的代码的Hard Negative Mining部分,通过两次sort排序来求取每张图片中正样本数的3倍的负样本,感觉还是比较巧妙的,以后可能会用到类似的思想,这里记录一下:
先来看一下官方文档的解释:
torch.sort(input, dim=-1, descending=False, out=None) -> (Tensor, LongTensor)
Sorts the elements of the input tensor along a given dimension in ascending order by value.
Ifdimis not given, the last dimension of the input is chosen.
IfdescendingisTruethen the elements are sorted in descending order by value.
A namedtuple of (values, indices) is returned, where the values are the sorted values and indices are the indices of the elements in the original input tensor.
Parameters:
- input (Tensor) – the input tensor.
- dim (int, optional) – the dimension to sort along
- descending (bool, optional) – controls the sorting order (ascending or descending)
- out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
这个函数还是比较好理解的,和python里面的sorted函数还是很像的。
下面开始进入正题:
b = torch.randint(low=1, high=10, size=(2,5))
b
Out[174]:
tensor([[4, 9, 7, 8, 5],
[3, 5, 9, 1, 4]])
# 现在进行第一次的sort,返回的是元素降序的对应索引
_, loss_idx = b.sort(dim=1, descending=True)
loss_idx
Out[176]:
tensor([[1, 3, 2, 4, 0],
[2, 1, 4, 0, 3]])
# 进行第二次的sort,得到原Tensor的元素按dim指定维度,排第几,索引变成了排名
_, idx_rank = loss_idx.sort(dim=1)
idx_rank
Out[178]:
tensor([[4, 0, 2, 1, 3],
[3, 1, 0, 4, 2]])
# 具体来说,可看原Tensor第一排的元素9,它是第一排(也就是按dim=1看)里面最大的,
# 所以它的排名是0,原Tensor第一排的元素4,它是第一排里面最小的,所以它的排名是4
# 当然这是以0-based的排名,且这里因为第一次sort是指定降序排列
可是这样计算到底有什么用呢?
可以看下面的这幅图,还是拿上面的例子:
图例
假设这里的b是两张图片负样本的loss,我通过难例挖掘,想要第一张图片中的loss前3的样本,想要第二张图片中的loss前4的样本,就可以通过上述最后的粉色命令得到一个掩码mask,如果我这时候运行:
num_neg = torch.tensor([[3],[4]])
num_neg
Out[180]:
tensor([[3],
[4]])
mask = idx_rank < num_neg.expand_as(idx_rank)
mask
Out[183]:
tensor([[0, 1, 1, 1, 0],
[1, 1, 1, 0, 1]], dtype=torch.uint8)
b[mask]
Out[185]: tensor([9, 7, 8, 3, 5, 9, 4])
# 9, 7, 8就是第一行中的loss的前3
# 3, 5, 9, 4就是第二行中的loss的前4
这样,我就得到我想要的了。为什么不直接用切片呢?因为每个图片我想要top_k的loss的数量不同,所以无法直接排序后用切片操作选取样本。当然如果要取倒数的k个,可以指定第一次sort为升序。
好了,就写到这里啦~~~