pytorchLinux学习|Gentoo/Arch/FreeBSD

torch.cat()与torch.stack()函数

2022-11-28  本文已影响0人  午字横

torch.cat()

import torch

x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape  # torch.Size([2, 3])

inputs = [x1, x2]
print(inputs)

x=torch.cat(inputs,dim=0)
print(x.shape)

outputs = torch.cat(inputs, dim=?)
dim代表在哪个维度上进行堆叠
inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列
dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。

torch.stack()

T1 = torch.tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

T2 = torch.tensor([[10, 20, 30],
                [40, 50, 60],
                [70, 80, 90]])
print(T1.shape)
print(T2.shape)
print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)

outputs = torch.stack(inputs, dim=?)
dim代表要生成的维度是哪个

inputs: 待连接的张量序列。
注:python的序列数据只有listtuple

dim : 新的维度, 必须在0len(outputs)之间。
注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。

上一篇下一篇

猜你喜欢

热点阅读