How do view() and reshape() differ in PyTorch?

2023-01-02  LabVIEW_Python

In PyTorch, both view() and reshape() change the shape of a tensor:

import torch
x = torch.arange(10)

x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)
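view() never copies data. The returned tensor shares storage with the original, and only its shape and stride metadata differ. The quick check below uses the contiguous x from above. reshape() on such a tensor will usually also return a view, but the PyTorch docs advise against relying on whether reshape() copies.

```python
import torch

x = torch.arange(10)
x_2x5 = x.view(2, 5)

# view() reuses x's storage: same starting address, new shape/stride metadata
print(x_2x5.data_ptr() == x.data_ptr())  # True

# writing through the view is therefore visible in the original tensor
x_2x5[0, 0] = 100
print(x[0].item())  # 100
```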

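The difference is that view() never copies data. It only works when the requested shape is compatible with the tensor's current size and stride, so it can fail on a tensor whose memory is not contiguous, such as a transposed one: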

import torch
x = torch.arange(10)
# contiguous memory
x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)

# noncontiguous memory
y = x_2x5.t()         # transpose only swaps the strides; y is now non-contiguous
y_1x10 = y.view(10)   # raises RuntimeError

Error message:

y_1x10 = y.view(10)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
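Here is why it fails. .t() moves no data and only swaps the strides, so reading y row by row visits memory offsets 0, 5, 1, 6, and so on. A flat view(10) would need one constant stride over that sequence, and no such stride exists. Printing the strides and contiguity flags makes this visible:

```python
import torch

x = torch.arange(10)
x_2x5 = x.view(2, 5)
y = x_2x5.t()  # transpose: swaps the strides, moves no data

# stride = number of elements to skip in memory to step once along each dimension
print(x_2x5.shape, x_2x5.stride(), x_2x5.is_contiguous())  # torch.Size([2, 5]) (5, 1) True
print(y.shape, y.stride(), y.is_contiguous())              # torch.Size([5, 2]) (1, 5) False
print(y.data_ptr() == x.data_ptr())                        # True: still the same storage
```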

Solution: use Tensor.contiguous(memory_format=torch.contiguous_format) → Tensor to copy the non-contiguous data into contiguous memory, then call view() on the result:

import torch
x = torch.arange(10)
# contiguous memory
x_2x5 = x.view(2, 5)
print(x_2x5)
x_5x2 = x.reshape(5, 2)
print(x_5x2)

# noncontiguous memory
y = x_2x5.t()                     # non-contiguous 5x2 tensor
y_1x10 = y.contiguous().view(10)  # copy into contiguous memory first, then view
print(y_1x10.shape)

Output:

torch.Size([10])
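The error message itself suggests reshape() as the alternative. Per the PyTorch docs, reshape() returns a view when the strides allow it and a copy otherwise. The minimal sketch below rebuilds the transposed y from the example above and compares the two routes. On this tensor, both of them end up copying.

```python
import torch

x = torch.arange(10)
y = x.view(2, 5).t()  # non-contiguous 5x2: [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]

# Route 1 (this article): copy into contiguous memory, then view()
a = y.contiguous().view(10)
# Route 2 (what the error message suggests): reshape() copies by itself when a view is impossible
b = y.reshape(10)

print(torch.equal(a, b))  # True
# elements follow y's logical row-major order, not x's memory order
print(a.tolist())  # [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]

# both results are copies, so writing to them leaves x untouched
a[0] = -1
b[1] = -1
print(x[:2].tolist())  # [0, 1]
```

reshape() is the shorter choice when you do not care whether the result shares memory with the input. Use view() when you want an error instead of a silent copy.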
