PyTorch 学习（十八）—— 预训练模型微调
2019-01-08 本文已影响2人
侠之大者_7d3f
训练结果
image.png image.png image.png image.png image.png image.png image.png
完整工程
-
工程目录结构
image.png -
代码
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import copy
# ---------------------------------------------------------
# Fine-tune an ImageNet-pretrained AlexNet for 2-class classification.
# Expected layout (ImageFolder): ./data/train/<class>/*, ./data/val/<class>/*
# ---------------------------------------------------------
import os

DATA_ROOT = './data'
MODEL_PATH = './models/model_50.pth'
NUM_CLASSES = 2
BATCH_SIZE = 4
NUM_WORKERS = 4
EPOCH_NUMS = 50
# Evaluate on the validation set (and log) every LOG_INTERVAL training batches.
LOG_INTERVAL = 10
# Channel statistics used by torchvision's ImageNet-pretrained weights; inputs
# must be normalized the same way for the pretrained features to be meaningful.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def build_model(num_classes=NUM_CLASSES, pretrained=True):
    """Load AlexNet (ImageNet-pretrained by default) and replace its output layer.

    Args:
        num_classes: number of output classes of the new final Linear layer.
        pretrained: whether to load the ImageNet-pretrained weights.

    Returns:
        The AlexNet module with ``classifier[6]`` replaced.
    """
    model = models.alexnet(pretrained=pretrained)
    # classifier[6] is the output layer; keep its input width, change the outputs.
    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(in_features=in_features, out_features=num_classes)
    return model


def build_dataloaders(data_root=DATA_ROOT, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Create the (train, val) DataLoaders over ImageFolder datasets in data_root."""
    transform = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    train_dataset = ImageFolder(root=os.path.join(data_root, 'train'), transform=transform)
    val_dataset = ImageFolder(root=os.path.join(data_root, 'val'), transform=transform)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                                  num_workers=num_workers, shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                                num_workers=num_workers, shuffle=False)
    return train_dataloader, val_dataloader


def evaluate(model, dataloader, device):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    Runs in eval mode without building autograd graphs, then restores the
    model's previous train/eval mode. Returns 0.0 for an empty dataloader.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, prediction = torch.max(outputs, 1)
                correct += (prediction == labels).sum().item()
                total += labels.size(0)
    finally:
        model.train(was_training)
    return correct / total if total else 0.0


def train(model, train_dataloader, val_dataloader, device,
          epoch_nums=EPOCH_NUMS, log_interval=LOG_INTERVAL):
    """Fine-tune ``model`` with SGD and return the best-validation state_dict.

    Every ``log_interval`` batches the running loss is printed together with the
    validation accuracy; the weights with the highest accuracy seen are kept.
    ``model`` is expected to already live on ``device``.
    """
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    loss_fc = nn.CrossEntropyLoss()
    # Multiply the learning rate by 0.1 every 20 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, 20, 0.1)

    # Snapshot (not a live reference) so later updates don't overwrite it.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    model.train()
    for epoch in range(epoch_nums):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_dataloader):
            # Tensor.to() is not in-place: it returns the moved tensor.
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fc(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_interval == log_interval - 1:
                accuracy = evaluate(model, val_dataloader, device)
                print('[{}, {}] running loss={:.5f}, accuracy={:.5f}'.format(
                    epoch + 1, i + 1, running_loss / log_interval, accuracy))
                running_loss = 0.0
                if accuracy > best_acc:
                    best_acc = accuracy
                    best_model_wts = copy.deepcopy(model.state_dict())
        # Step the scheduler once per epoch, after the optimizer steps
        # (required ordering since PyTorch 1.1).
        scheduler.step()
    return best_model_wts


def main():
    """Build the model and data, fine-tune, and save the best weights to MODEL_PATH."""
    model = build_model()
    train_dataloader, val_dataloader = build_dataloaders()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Move the model before training so the optimizer sees on-device parameters.
    model.to(device)
    best_model_wts = train(model, train_dataloader, val_dataloader, device)
    print('Train finish')
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    torch.save(best_model_wts, MODEL_PATH)


# DataLoader workers (num_workers > 0) re-import this module on spawn-based
# platforms (Windows, macOS); the guard keeps them from re-running training.
if __name__ == '__main__':
    main()