PyTorch: learning rate (OneCycleLR)

2021-01-21  本文已影响0人  wzNote
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

initial_lr = 1e-1  # lr handed to the optimizer constructor; a scheduler may overwrite it

class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x):
        pass

net = model()
n_batch = 20
optimizer = torch.optim.Adam(net.parameters(), lr = initial_lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=n_batch, epochs=10)

print("初始化的学习率:", optimizer.defaults['lr'])

for epoch in range(1, 10):
    # train
    for iter in range(n_batch):
        optimizer.zero_grad()
        optimizer.step()
        print("第%d个epoch的学习率:%f" % (epoch, optimizer.param_groups[0]['lr']))
        scheduler.step()
上一篇 下一篇

猜你喜欢

热点阅读