Pytorch中的Variable变量---自动求导
2019-01-03 本文已影响0人
spectre_hola
Variable和Tensor的区别是,Variable可以自动求导,它具有三个重要属性:data,grad,grad_fn。将一个Tensor a转化成Variable只需要Variable(a)就可以了。
grad_fn的意思是得到这个变量的操作,举个例子
x = torch.ones(2, 2, requires_grad=True)
y = x + 2
print(y.grad_fn)
#运行
<AddBackward object at 0x000001578C171C50>
意思是y是通过加法得到的
下面来说用变量求导的具体过程
#Create variable
x = Variable(torch.Tensor([1]), requires_grad = True)
w = Variable(torch.Tensor([2]), requires_grad = True)
b = Variable(torch.Tensor([3]), requires_grad = True)
#Build a comptation graph
y = w * x + b
#Compute gradients
y.backward() #same as y.backward(torch.FloatTensor([1]))
#Print out the gradients
print(x.grad) #x.grad = 2
print(w.grad) #x.grad = 1
print(b.grad) #x.grad = 1
注意其中一行y.backward(),这里它是等价于y.backward(torch.FloatTensor([1])),只是对于标量求导里面参数可以不写了,如果x是一个三维向量x = torch.randn(3),则该行需要换成y.backward(torch.FloatTensor([1, 1, 1]))才可以对应求导输出,或者y.backward(torch.FloatTensor([1, 0.1, 0.01]))是对应输出导数的对应1,0.1,0.01倍。