Pythorch源码

torch.size(), torch.view(),torch

2018-03-12  本文已影响3570人  吐舌小狗

博客参考

1.torch.size()的实质上是tuple

from __future__ import print_function
import torch

x = torch.Tensor(5, 3)
print(x.size()[0])  # Print the row dimension
print(x.size()[1])  # Print the column dimension

输出: 5 3
注: 可以使用tuple的基本操作,获取torch.size()的值

2.使用torch.view对tensor进行变形

from __future__ import print_function
import torch

x = torch.randn(8, 8)
z = x.view(-1, 4)
print(x.size(), z.size())

输出: (8L, 8L) (16L, 4L)

注:这里如果单独对x.view(-1,4)之后,x的shape是不发生任何变化的,需要重新赋值给另外的变量。

3.使用add_()对原始的tensor的每一个数值进行加的操作

from __future__ import print_function
import torch
x=torch.rand(5,3)
print(x)
print(x.add_(1))

输出:

 0.3659  0.1633  0.2380
 0.1108  0.9994  0.9355
 0.3237  0.9887  0.1847
 0.7046  0.0212  0.7640
 0.0731  0.2785  0.1676
[torch.FloatTensor of size 5x3]

 1.3659  1.1633  1.2380
 1.1108  1.9994  1.9355
 1.3237  1.9887  1.1847
 1.7046  1.0212  1.7640
 1.0731  1.2785  1.1676
[torch.FloatTensor of size 5x3]

注:torch.add_()这里加了下划线,表示是内建(in-place)函数, 将要改变x的值

一般来说函数加了下划线的属于内建函数,将要改变原来的值,没有加下划线的并不会改变原来的数据,引用时需要另外赋值给其他变量

上一篇下一篇

猜你喜欢

热点阅读