Torch.gather()、Torch.cat()
2020-06-30 本文已影响0人
何哀何欢
gather() 函数是按照索引选取数字:
一个二维数组,如果沿第0维选取元素,则按照将头方向依次选取数字。0,2,1就是如图:
如果沿第1维选取元素,则按照将头方向依次选取数字。0,2,1就是如图:



cat()是用来连接多个tensor的:
T = torch.tensor( [ [ 1 ] ] ) print("[[1]]:", torch.cat( [ T, T, T ] ) )
[[1]]: tensor( [ [1], [1], [1] ] )
T = torch.tensor( [ 1 ] ) print("[1]:", torch.cat( [ T, T, T ] ) )
[1]: tensor([1, 1, 1])
这样不行:
torch.cat( [ 1, 1, 1 ] )
TypeError: expected Tensor as element 0 in argument 0, but got int