Pytorch中的Variable变量---自动求导

2019-01-03  本文已影响0人  spectre_hola

Variable和Tensor的区别是,Variable可以自动求导,它具有三个重要属性:data,grad,grad_fn。将一个Tensor a转化成Variable只需要Variable(a)就可以了。

grad_fn的意思是得到这个变量的操作,举个例子

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
print(y.grad_fn)
#运行
<AddBackward object at 0x000001578C171C50>

意思是y是通过加法得到的

下面来说用变量求导的具体过程

#Create variable
x = Variable(torch.Tensor([1]), requires_grad = True)
w = Variable(torch.Tensor([2]), requires_grad = True)
b = Variable(torch.Tensor([3]), requires_grad = True)

#Build a comptation graph
y = w * x + b

#Compute gradients
y.backward()  #same as y.backward(torch.FloatTensor([1]))

#Print out the gradients
print(x.grad)    #x.grad = 2
print(w.grad)    #x.grad = 1
print(b.grad)    #x.grad = 1

注意其中一行y.backward(),这里它是等价于y.backward(torch.FloatTensor([1])),只是对于标量求导里面参数可以不写了,如果x是一个三维向量x = torch.randn(3),则该行需要换成y.backward(torch.FloatTensor([1, 1, 1]))才可以对应求导输出,或者y.backward(torch.FloatTensor([1, 0.1, 0.01]))是对应输出导数的对应1,0.1,0.01倍。

上一篇下一篇

猜你喜欢

热点阅读