torch.cat()与torch.stack()函数
2022-11-28 本文已影响0人
午字横
torch.cat()
import torch
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape # torch.Size([2, 3])
inputs = [x1, x2]
print(inputs)
x=torch.cat(inputs,dim=0)
print(x.shape)
outputs = torch.cat(inputs, dim=?)
dim
代表在哪个维度上进行堆叠
inputs
: 待连接的张量序列,可以是任意相同Tensor类型的python 序列
dim
: 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。
torch.stack()
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(T1.shape)
print(T2.shape)
print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)
outputs = torch.stack(inputs, dim=?)
dim
代表要生成的维度是哪个
inputs
: 待连接的张量序列。
注:python的序列数据只有list
和tuple
。
dim
: 新的维度, 必须在0
到len(outputs)
之间。
注:len(outputs)
是生成数据的维度大小,也就是outputs
的维度值。