Torch.gather()、Torch.cat()

2020-06-30  本文已影响0人  何哀何欢

数组维度和方向(连接)

gather() 函数是按照索引选取数字: 一个二维数组,如果沿第0维选取元素,则按照将头方向依次选取数字。0,2,1就是如图: 如果沿第1维选取元素,则按照将头方向依次选取数字。0,2,1就是如图:

cat()是用来连接多个tensor的:

T = torch.tensor( [ [ 1 ] ] )
print("[[1]]:", torch.cat( [ T, T, T ] ) )

[[1]]: tensor( [ [1], [1], [1] ] )

T = torch.tensor( [ 1 ] )
print("[1]:", torch.cat( [ T, T, T ] ) )

[1]: tensor([1, 1, 1])

这样不行:

torch.cat( [ 1, 1, 1 ] )

TypeError: expected Tensor as element 0 in argument 0, but got int

上一篇 下一篇

猜你喜欢

热点阅读