Pytorch笔记7-训练、验证、测试模型

2024-07-18  本文已影响0人  江湾青年

训练

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()  # 切换模型到训练模式
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # 后向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 计算batch内损失
        running_loss += loss.item() * inputs.size(0)
    # 计算epoch内损失
    epoch_loss = running_loss / len(train_loader.dataset)
    return epoch_loss

enumerate()

for batch_idx, batch_data in enumerate(train_loader):
    # 将数据移动到GPU
    inputs, labels = batch_data
    inputs, labels = inputs.to(device), labels.to(device)
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    # 后向传播
    optimizer.zero_grad()  # 清零所有参数的梯度
    loss.backward()        # 计算梯度
    optimizer.step()       # 更新参数
    # 使用batch_idx
    if batch_idx % 10 == 0:  # 每10个批次打印一次损失
        print(f'Batch [{batch_idx}], Loss: {loss.item():.4f}')

验证

# 定义验证函数
def validate_one_epoch(model, valid_loader, criterion, device):
    model.eval()  # 切换到评估模式
    running_loss = 0.0
    # 在验证过程中不需要计算梯度
    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)    # 计算平均损失
    epoch_loss = running_loss / len(valid_loader.dataset)
    return epoch_loss

在每个epoch中进行训练+验证

num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    valid_loss = validate_one_epoch(model, valid_loader, criterion, device)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}')

测试(推理)

使用训练好的模型进行推理,其实validation部分就是推理,因此代码和validate_one_epoch比较类似

# 设置模型为评估模式
model.eval()
# 进行推理
with torch.no_grad():  # 在推理过程中不需要计算梯度
    outputs = model(new_inputs)
# 输出结果
print(outputs)
上一篇下一篇

猜你喜欢

热点阅读