CNN训练循环

2019-08-04  本文已影响0人  钢笔先生

Time: 2019-08-04

循环一个epoch

# 循环一个batch
network = Network()

train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

total_loss = 0.0
total_correct = 0

for batch in train_loader:
  images, labels = batch

  preds = network(images) # 传入一个batch
  loss = F.cross_entropy(preds, labels) # 计算损失函数
  
  # 计算前需要先使得梯度为0
  optimizer.zero_grad()

  loss.backward() # 计算梯度
  optimizer.step() # 一个批次更新参数

  total_loss += loss.item()
  total_correct += get_num_correct(preds, labels)
  
print("epoch: ", 0, "total_correct: ", total_correct,  "loss: ", total_loss)

训练多个epoch

for epoch in range(6):
  # 循环一个batch
  network = Network()

  train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
  optimizer = optim.Adam(network.parameters(), lr=0.01)

  total_loss = 0.0
  total_correct = 0

  for batch in train_loader:
    images, labels = batch

    preds = network(images) # 传入一个batch
    loss = F.cross_entropy(preds, labels) # 计算损失函数

    # 计算前需要先使得梯度为0
    optimizer.zero_grad()

    loss.backward() # 计算梯度
    optimizer.step() # 一个批次更新参数

    total_loss += loss.item()
    total_correct += get_num_correct(preds, labels)

  print("epoch: ", epoch, "total_correct: ", total_correct,  "loss: ", total_loss)

END.

上一篇 下一篇

猜你喜欢

热点阅读