pytorch学习笔记PytorchPyTorch

pytorch学习(十三)—学习率调整策略

2019-01-02  本文已影响0人  侠之大者_7d3f

学习率

学习速率(learning rate)是指导我们该如何通过损失函数的梯度调整网络权重的超参数。学习率越低,损失函数的变化速度就越慢。虽然使用低学习率可以确保我们不会错过任何局部极小值,但也意味着我们将花费更长的时间来进行收敛,特别是在被困在高原区域的情况下。

new_weight = existing_weight — learning_rate * gradient


image.png

图1采用较小的学习率,梯度下降的速度慢;
图2采用较大的学习率,梯度下降太快越过了最小值点,导致不收敛,甚至震荡。

image.png

目的


测试环境


实验/测试

pytorch中相关的API

关于学习率调整,pytorch提供了torch.optim.lr_scheduler

image.png

主要提供了几个类:

1. torch.optim.lr_scheduler.StepLR

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import AlexNet
import matplotlib.pyplot as plt


model = AlexNet(num_classes=2)
optimizer = optim.SGD(params=model.parameters(), lr=0.05)

# lr_scheduler.StepLR()
# Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.05     if epoch < 30
# lr = 0.005    if 30 <= epoch < 60
# lr = 0.0005   if 60 <= epoch < 90

scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
plt.figure()
x = list(range(100))
y = []
for epoch in range(100):
    scheduler.step()
    lr = scheduler.get_lr()
    print(epoch, scheduler.get_lr()[0])
    y.append(scheduler.get_lr()[0])

plt.plot(x, y)

image.png

0<epoch<30, lr = 0.05
30<=epoch<60, lr = 0.005
60<=epoch<90, lr = 0.0005

torch.optim.lr_scheduler.MultiStepLR

StepLR相比,MultiStepLR可以设置指定的区间

# ---------------------------------------------------------------
# 可以指定区间
# lr_scheduler.MultiStepLR()
#  Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.05     if epoch < 30
# lr = 0.005    if 30 <= epoch < 80
#  lr = 0.0005   if epoch >= 80
print()
plt.figure()
y.clear()
scheduler = lr_scheduler.MultiStepLR(optimizer, [30, 80], 0.1)
for epoch in range(100):
    scheduler.step()
    print(epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
    y.append(scheduler.get_lr()[0])

plt.plot(x, y)
plt.show()
image.png

torch.optim.lr_scheduler.ExponentialLR

指数衰减

scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
print()
plt.figure()
y.clear()
for epoch in range(100):
    scheduler.step()
    print(epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
    y.append(scheduler.get_lr()[0])

plt.plot(x, y)
plt.show()
image.png

End

参考:
https://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate

上一篇下一篇

猜你喜欢

热点阅读