pytorch

2. pytorch-基本数据类型

2018-06-29  本文已影响786人  FantDing

1. Tensor

1.1. 如何生成Tensor

1.2. 数据类型转换

We DO not recommend double for performance, especially on the GPU. GPUs have bad double precision perf and are optimized for float32 performance.参考 即不推荐使用DoubleTensor

1.3. 属性

1.4. 方法

1.4.1. 数学操作

2. Variable

2.1. Variable生成

仅仅可以传入Tensor

2.2. 属性

2.3. 方法

3. 示例代码

import numpy as np
import torch
from torch.autograd import Variable

np_data=np.arange(4,dtype=np.float).reshape(2,2)+1
print(np_data)
# [[1. 2.]
#  [3. 4.]]
tensor=torch.from_numpy(np_data)
variable=Variable(tensor,requires_grad=True)
# variable=Variable(np_data,requires_grad=True) # 只能传入Tensor
t_out=variable*variable
out=torch.mean(t_out)
# 反响传播
out.backward()
v_grad=variable.grad
print(v_grad)
# tensor([[ 0.5000,  1.0000],
#         [ 1.5000,  2.0000]], dtype=torch.float64)
上一篇 下一篇

猜你喜欢

热点阅读