PyTorch学习笔记 · 深度学习 · 目标跟踪 && 目标检测

PyTorch 学习（十八）：预训练模型微调

2019-01-08  本文已影响2人  侠之大者_7d3f

训练结果

image.png image.png image.png image.png image.png image.png image.png

完整工程

import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


# ---------------------------------------------------------
# Load the ImageNet-pretrained AlexNet and fine-tune it for a 2-class task.
model = models.alexnet(pretrained=True)
# Replace the final classifier layer (4096 -> 1000 ImageNet classes) with a
# 4096 -> 2 layer; every other layer keeps its pretrained weights.
model.classifier[6] = nn.Linear(in_features=4096, out_features=2)


# ------------------------- Dataset -------------------------------------------------
# torchvision's pretrained weights were trained on inputs normalized with the
# ImageNet per-channel mean/std, so the same normalization is applied here.
transform = transforms.Compose([transforms.Resize((227, 227)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])

# ImageFolder derives class labels from the sub-directory names under each root.
train_dataset = ImageFolder(root='./data/train', transform=transform)
val_dataset = ImageFolder(root='./data/val', transform=transform)

# NOTE(review): num_workers > 0 starts worker processes; on platforms using the
# "spawn" start method (Windows, macOS) this top-level script would additionally
# need an `if __name__ == '__main__':` guard.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, num_workers=4, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)


# ------------------ Optimizer, loss function, LR schedule --------------------------
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss_fc = nn.CrossEntropyLoss()
# Multiply the learning rate by 0.1 every 20 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, 20, 0.1)


# -------------------- Use the GPU if one is available ------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)  # nn.Module.to moves the parameters in place (unlike Tensor.to)


def _evaluate(net, dataloader, dev):
    """Return the top-1 accuracy of *net* over every batch of *dataloader*.

    Runs in eval mode without building an autograd graph, then restores the
    module's previous train/eval mode so the caller can keep training.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(dev)
            targets = targets.to(dev)
            prediction = net(images).argmax(dim=1)
            correct += (prediction == targets).sum().item()
            total += targets.size(0)
    net.train(was_training)
    return correct / total


# ------------------- Training ------------------------------------------------------

epoch_nums = 50
# Deep copy: state_dict() returns references to the live parameter tensors,
# which would keep changing as training continues.
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0
for epoch in range(epoch_nums):
    model.train()
    running_loss = 0.0

    for i, (inputs, labels) in enumerate(train_dataloader):
        # Tensor.to() is not in-place: the moved tensor must be reassigned.
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fc(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Every 10 mini-batches: report the mean loss of those 10 batches and
        # evaluate on the full validation set, keeping the best weights seen.
        if i % 10 == 9:
            accuracy = _evaluate(model, val_dataloader, device)
            print('[{}, {}] running loss={:.5f}, accuracy={:.5f}'.format(epoch + 1, i + 1, running_loss/10, accuracy))
            running_loss = 0.0
            if accuracy > best_acc:
                best_acc = accuracy
                best_model_wts = copy.deepcopy(model.state_dict())

    # Since PyTorch 1.1 the LR scheduler must step after the epoch's optimizer
    # steps; stepping it first skips the initial learning rate.
    scheduler.step()


print('Train finish')
os.makedirs('./models', exist_ok=True)  # torch.save does not create directories
torch.save(best_model_wts, './models/model_50.pth')
上一篇下一篇

猜你喜欢

热点阅读