torch.cat

2021-02-04  本文已影响0人  三方斜阳

torch.cat 的作用是把两个 tensor 合并为一个 tensor
第一个参数是需要连接的tensor list , 第二个参数指定按照哪个维度进行拼接

import torch
A=torch.zeros(2,5) #2x5的张量(矩阵)                                     
print(A)
B=torch.ones(3,5)
print(B)
list=[]
list.append(A)
list.append(B)
C=torch.cat(list,dim=0)#按照行进行拼接,此时所有tensor的列数需要相同
print(C,C.shape)

>>
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]) torch.Size([5, 5])
>>

按照列进行拼接

import torch
A=torch.zeros(2,5) #2x5的张量(矩阵)                                     
print(A)
B=torch.ones(2,5)
print(B)
list=[]
list.append(A)
list.append(B)

C=torch.cat(list,dim=1)#按照列进行拼接,此时的tensor 行数必须一致
#C=torch.cat((A,B),dim=1)
print(C,C.shape)
>>
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]]) torch.Size([2, 10])
>>

这个函数的应用是基础的,很多任务,例如要处理的数据是一行一行的,分别都转换为 tensor 之后,需要将全部的句子都拼接起来,然后再分成 batch 批量输入模型,所以需要用到 cat 的操作;

上一篇 下一篇

猜你喜欢

热点阅读