PyTorch: learning rate
2021-01-21
wzNote
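A minimal example of tracking the learning rate under torch.optim.lr_scheduler.OneCycleLR: build a throwaway model, attach the scheduler to an Adam optimizer, and print the current lr from optimizer.param_groups after each epoch. Note that OneCycleLR schedules per batch, so scheduler.step() belongs inside the inner loop.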
import torch
import torch.nn as nn

initial_lr = 0.1

class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x):
        pass  # forward is never called here; this demo only tracks the learning rate

net = model()
n_batch = 20

optimizer = torch.optim.Adam(net.parameters(), lr=initial_lr)
# OneCycleLR overrides the optimizer's lr: it starts at max_lr / div_factor
# (div_factor defaults to 25) and needs the total step count up front.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, steps_per_epoch=n_batch, epochs=10)

print("Initial learning rate:", optimizer.defaults['lr'])

for epoch in range(1, 11):  # 10 epochs, matching epochs=10 above
    # train
    for batch in range(n_batch):
        optimizer.zero_grad()
        optimizer.step()
        scheduler.step()  # OneCycleLR is stepped after every batch, not every epoch
    print("Learning rate at epoch %d: %f" % (epoch, optimizer.param_groups[0]['lr']))
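The original code imports LambdaLR but never uses it. For comparison, here is a minimal sketch of the same lr-tracking loop with LambdaLR, which multiplies initial_lr by a user-supplied function of the epoch index; the 1/(epoch+1) decay rule below is an illustrative assumption, not something from the original post.

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

initial_lr = 0.1
net = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)  # any module works for this demo

optimizer = torch.optim.Adam(net.parameters(), lr=initial_lr)
# lr at epoch e is initial_lr * lr_lambda(e); 1/(e+1) is an assumed decay rule
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0 / (epoch + 1))

for epoch in range(1, 11):
    optimizer.zero_grad()
    optimizer.step()
    print("Learning rate at epoch %d: %f" % (epoch, optimizer.param_groups[0]['lr']))
    scheduler.step()  # unlike OneCycleLR, LambdaLR is typically stepped once per epoch

Here the printed lr starts at 0.1 and decays to 0.1/2, 0.1/3, and so on, since LambdaLR rescales the initial lr rather than the current one.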