Pytorch学习笔记(8) Pytorch 常用的优化算法

2020-05-14  本文已影响0人  银色尘埃010

Pytorch中如何使用优化方法。
torch.optim是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

一、如何使用

首先看看最简单的使用方式

(1)定义优化算法和损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss_fn =  torch.nn.MSELoss()

(2) 三步走
optimizer.zero_grad()
loss = loss_fn(model(input), target)
loss..backward()
optimizer.step()

经过以上的基本步骤,实现了模型参数的更新。

仔细来看:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)
optim.SGD([
                {'params': model.base.parameters()},
                {'params': model.classifier.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

这意味着model.base的参数将会使用1e-2的学习率,model.classifier的参数将会使用1e-3的学习率,并且0.9的momentum将会被用于所 有的参数。

所有的optimizer都实现了step()方法,这个方法会更新所有的参数。它能按两种方式来使用:
optimizer.step()
这是大多数optimizer所支持的简化版本。一旦梯度被如backward()之类的函数计算好后,我们就可以调用这个函数。

optimizer.step(closure)
一些优化算法例如Conjugate Gradient和LBFGS需要重复多次计算函数,因此你需要传入一个闭包去允许它们重新计算你的模型。这个闭包应当清空梯度, 计算损失,然后返回。

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

二、常见的优化算法

最后是所有优化函数的基类:

上一篇下一篇

猜你喜欢

热点阅读