PyTorch深度学习简明实战10 - ResNet18
2023-03-24 本文已影响0人
薛东弗斯
ResNet18 模型架构
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
使用预训练模型时,一般只需修改输出层。默认模型在ImageNet数据集上训练,共1000个类别(对应fc层的out_features=1000)。与VGG不同的是,ResNet18在最后的fc层之前由卷积层、BN层、ReLU、池化层以及带残差连接的BasicBlock组成,卷积特征经过avgpool(AdaptiveAvgPool2d,全局平均池化)后直接送入fc层。fc实际上就是一个Linear层(输入为512维特征)。
训练中冻结预训练模型的参数(requires_grad=False:不计算其梯度,也不更新其权重)
修改模型代码
model = torchvision.models.resnet18(pretrained=True)
# Freeze the pretrained backbone: none of its parameters will receive gradients.
for weight in model.parameters():
    weight.requires_grad = False
# Replace the 1000-way ImageNet head with a fresh 4-way classifier. The new head
# keeps the backbone's feature width (512 for resnet18) as its input size; a newly
# constructed nn.Linear defaults to requires_grad=True, so it is the only
# trainable part of the network.
in_f = model.fc.in_features
model.fc = nn.Linear(in_features=in_f, out_features=4)
改变输出层后,模型结构
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=4, bias=True)
)
完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
from torchvision import transforms
import os
base_dir = r'./data/4_weather'
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')

# The pretrained torchvision weights were trained on inputs normalised with the
# ImageNet per-channel statistics, so the same statistics are used here.
transform = transforms.Compose([
    transforms.Resize((192, 192)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
train_ds = torchvision.datasets.ImageFolder(train_dir, transform=transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=transform)

BATCH_SIZE = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of
# the `weights=` argument.
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False  # freeze the pretrained backbone
in_f = model.fc.in_features  # 512: the new head must accept the backbone's feature width
model.fc = nn.Linear(in_f, 4)  # new 4-class head; a fresh Linear defaults to requires_grad=True

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
# fit() moves every batch to `device`, so the model must live there as well;
# otherwise a CUDA run fails with a device-mismatch error. Move it before the
# optimizer is constructed.
model = model.to(device)

# Only the new head is trainable (every other parameter has
# requires_grad=False), so only its parameters are handed to the optimizer.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001)
loss_fn = nn.CrossEntropyLoss()

# PyTorch's built-in step decay: multiply the learning rate by 0.9 every
# 5 scheduler steps (fit() steps the scheduler once per epoch).
from torch.optim import lr_scheduler
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)
def fit(epoch, model, trainloader, testloader):
    """Train `model` for one epoch on `trainloader`, then evaluate it on `testloader`.

    The train/eval mode must be set explicitly because the model contains
    BatchNorm layers, which behave differently in the two modes.

    Relies on the module-level `optimizer`, `loss_fn`, `exp_lr_scheduler` and
    `device` defined earlier in the script.

    Args:
        epoch: epoch index, used only in the progress printout.
        model: network being trained; assumed to already live on `device`.
        trainloader: iterable of (inputs, labels) training batches.
        testloader: iterable of (inputs, labels) evaluation batches.

    Returns:
        (epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc): the losses
        are mean per-sample losses, the accuracies are the fraction of
        correctly classified samples.
    """
    # ---- training pass ----
    correct = 0
    total = 0
    running_loss = 0.0
    model.train()
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            preds = torch.argmax(y_pred, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
            # loss_fn (CrossEntropyLoss with default reduction) returns the batch
            # *mean*; weight it by the batch size so that dividing by the sample
            # count below gives a true per-sample mean, even when the last batch
            # is smaller than the others.
            running_loss += loss.item() * y.size(0)
    exp_lr_scheduler.step()  # one scheduler step per epoch
    epoch_acc = correct / total
    epoch_loss = running_loss / total

    # ---- evaluation pass ----
    test_correct = 0
    test_total = 0
    test_running_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            preds = torch.argmax(y_pred, dim=1)
            test_correct += (preds == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item() * y.size(0)
    epoch_test_acc = test_correct / test_total
    epoch_test_loss = test_running_loss / test_total

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy: ', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy: ', round(epoch_test_acc, 3))
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
epochs = 50
# One history list per metric, appended to after every epoch so partial
# results are kept even if the run is interrupted.
train_loss, train_acc, test_loss, test_acc = [], [], [], []
history = (train_loss, train_acc, test_loss, test_acc)
for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
    for series, value in zip(history, (epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)):
        series.append(value)
训练结果
Using cpu device
epoch: 0 loss: 0.045 accuracy: 0.294 test_loss: 0.046 test_accuracy: 0.333
epoch: 1 loss: 0.041 accuracy: 0.399 test_loss: 0.043 test_accuracy: 0.449
epoch: 2 loss: 0.038 accuracy: 0.509 test_loss: 0.039 test_accuracy: 0.569
epoch: 3 loss: 0.035 accuracy: 0.58 test_loss: 0.036 test_accuracy: 0.658
epoch: 4 loss: 0.033 accuracy: 0.656 test_loss: 0.033 test_accuracy: 0.689
epoch: 5 loss: 0.03 accuracy: 0.732 test_loss: 0.03 test_accuracy: 0.747
epoch: 6 loss: 0.028 accuracy: 0.744 test_loss: 0.029 test_accuracy: 0.782
epoch: 7 loss: 0.026 accuracy: 0.786 test_loss: 0.027 test_accuracy: 0.813
epoch: 8 loss: 0.025 accuracy: 0.824 test_loss: 0.025 test_accuracy: 0.853
epoch: 9 loss: 0.024 accuracy: 0.828 test_loss: 0.024 test_accuracy: 0.862
epoch: 10 loss: 0.023 accuracy: 0.841 test_loss: 0.023 test_accuracy: 0.849
epoch: 11 loss: 0.021 accuracy: 0.858 test_loss: 0.021 test_accuracy: 0.893
epoch: 12 loss: 0.021 accuracy: 0.866 test_loss: 0.021 test_accuracy: 0.893
epoch: 13 loss: 0.02 accuracy: 0.869 test_loss: 0.02 test_accuracy: 0.898
epoch: 14 loss: 0.019 accuracy: 0.874 test_loss: 0.019 test_accuracy: 0.907
epoch: 15 loss: 0.018 accuracy: 0.901 test_loss: 0.018 test_accuracy: 0.898
epoch: 16 loss: 0.018 accuracy: 0.892 test_loss: 0.018 test_accuracy: 0.902
epoch: 17 loss: 0.018 accuracy: 0.903 test_loss: 0.018 test_accuracy: 0.911
epoch: 18 loss: 0.017 accuracy: 0.897 test_loss: 0.016 test_accuracy: 0.907
epoch: 19 loss: 0.017 accuracy: 0.917 test_loss: 0.016 test_accuracy: 0.911
epoch: 20 loss: 0.016 accuracy: 0.892 test_loss: 0.015 test_accuracy: 0.92
epoch: 21 loss: 0.016 accuracy: 0.908 test_loss: 0.015 test_accuracy: 0.911
epoch: 22 loss: 0.015 accuracy: 0.913 test_loss: 0.015 test_accuracy: 0.916
epoch: 23 loss: 0.015 accuracy: 0.914 test_loss: 0.015 test_accuracy: 0.916
epoch: 24 loss: 0.015 accuracy: 0.91 test_loss: 0.014 test_accuracy: 0.92
epoch: 25 loss: 0.014 accuracy: 0.919 test_loss: 0.014 test_accuracy: 0.92
epoch: 26 loss: 0.014 accuracy: 0.934 test_loss: 0.014 test_accuracy: 0.916
epoch: 27 loss: 0.014 accuracy: 0.912 test_loss: 0.013 test_accuracy: 0.924
epoch: 28 loss: 0.015 accuracy: 0.933 test_loss: 0.013 test_accuracy: 0.92
epoch: 29 loss: 0.013 accuracy: 0.93 test_loss: 0.013 test_accuracy: 0.929
epoch: 30 loss: 0.013 accuracy: 0.926 test_loss: 0.013 test_accuracy: 0.933
epoch: 31 loss: 0.013 accuracy: 0.927 test_loss: 0.013 test_accuracy: 0.916
epoch: 32 loss: 0.013 accuracy: 0.923 test_loss: 0.012 test_accuracy: 0.924
epoch: 33 loss: 0.012 accuracy: 0.931 test_loss: 0.012 test_accuracy: 0.924
epoch: 34 loss: 0.012 accuracy: 0.931 test_loss: 0.012 test_accuracy: 0.929
epoch: 35 loss: 0.013 accuracy: 0.931 test_loss: 0.012 test_accuracy: 0.933
epoch: 36 loss: 0.012 accuracy: 0.944 test_loss: 0.012 test_accuracy: 0.929
epoch: 37 loss: 0.012 accuracy: 0.943 test_loss: 0.011 test_accuracy: 0.924
epoch: 38 loss: 0.012 accuracy: 0.921 test_loss: 0.012 test_accuracy: 0.933
epoch: 39 loss: 0.011 accuracy: 0.936 test_loss: 0.012 test_accuracy: 0.933
epoch: 40 loss: 0.011 accuracy: 0.929 test_loss: 0.011 test_accuracy: 0.933
epoch: 41 loss: 0.012 accuracy: 0.943 test_loss: 0.011 test_accuracy: 0.929
epoch: 42 loss: 0.011 accuracy: 0.946 test_loss: 0.011 test_accuracy: 0.933
epoch: 43 loss: 0.011 accuracy: 0.937 test_loss: 0.011 test_accuracy: 0.938
epoch: 44 loss: 0.01 accuracy: 0.943 test_loss: 0.011 test_accuracy: 0.938
epoch: 45 loss: 0.011 accuracy: 0.942 test_loss: 0.011 test_accuracy: 0.933
epoch: 46 loss: 0.01 accuracy: 0.938 test_loss: 0.011 test_accuracy: 0.933
epoch: 47 loss: 0.011 accuracy: 0.936 test_loss: 0.011 test_accuracy: 0.933
epoch: 48 loss: 0.01 accuracy: 0.948 test_loss: 0.01 test_accuracy: 0.933
epoch: 49 loss: 0.011 accuracy: 0.946 test_loss: 0.011 test_accuracy: 0.938
从训练结果可以看出,训练集与测试集的准确率和损失非常接近,过拟合程度比VGG更轻。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
from torchvision import transforms
import os
# NOTE(review): the first script reads './data/4_weather' (with an underscore);
# confirm which directory name actually exists on disk.
base_dir = r'./data/4weather'
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')

# The pretrained torchvision weights expect inputs normalised with the
# ImageNet per-channel statistics.
transform = transforms.Compose([
    transforms.Resize((192, 192)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_ds = torchvision.datasets.ImageFolder(
    train_dir,
    transform=transform,
)
test_ds = torchvision.datasets.ImageFolder(
    test_dir,
    transform=transform,
)

BATCH_SIZE = 32
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
)

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False  # freeze the pretrained backbone
in_f = model.fc.in_features
model.fc = nn.Linear(in_f, 4)  # new 4-class head; the only trainable layer
if torch.cuda.is_available():
    model.to('cuda')

loss_fn = nn.CrossEntropyLoss()
# Only the new head is trainable, so only its parameters are optimised.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001)
# Optional LR decay, disabled in this run. To enable it, uncomment the two
# lines below (multiply the LR by 0.1 every 7 epochs) and call
# exp_lr_scheduler.step() once per epoch inside fit().
# from torch.optim import lr_scheduler
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
def fit(epoch, model, trainloader, testloader):
    """Train `model` for one epoch on `trainloader`, then evaluate it on `testloader`.

    Relies on the module-level `optimizer` and `loss_fn`. Batches are moved to
    CUDA when it is available, matching where the model is placed above.

    Args:
        epoch: epoch index, used only in the progress printout.
        model: network being trained.
        trainloader: iterable of (inputs, labels) training batches.
        testloader: iterable of (inputs, labels) evaluation batches.

    Returns:
        (epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc): the losses
        are mean per-sample losses, the accuracies are the fraction of
        correctly classified samples.
    """
    # Decide the target device once instead of re-checking CUDA for every batch.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ---- training pass ----
    correct = 0
    total = 0
    running_loss = 0.0
    model.train()  # BatchNorm layers behave differently in train and eval mode
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            preds = torch.argmax(y_pred, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
            # loss_fn returns the batch *mean*; weight it by the batch size so
            # dividing by the sample count yields a true per-sample mean.
            running_loss += loss.item() * y.size(0)
    # LR decay is disabled in this run; if exp_lr_scheduler is enabled, step it
    # here once per epoch.
    epoch_loss = running_loss / total
    epoch_acc = correct / total

    # ---- evaluation pass ----
    test_correct = 0
    test_total = 0
    test_running_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            preds = torch.argmax(y_pred, dim=1)
            test_correct += (preds == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item() * y.size(0)
    epoch_test_loss = test_running_loss / test_total
    epoch_test_acc = test_correct / test_total

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
          )
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
epochs = 50
# One list per metric, extended after every epoch.
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    metrics = fit(epoch, model, train_dl, test_dl)
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = metrics
    for series, value in zip((train_loss, train_acc, test_loss, test_acc), metrics):
        series.append(value)

# Loss curves; the x axis runs from 1 to `epochs`.
epoch_axis = range(1, epochs + 1)
for series, name in ((train_loss, 'train_loss'), (test_loss, 'test_loss')):
    plt.plot(epoch_axis, series, label=name)
plt.legend()
image.png
# Accuracy curves; the x axis runs from 1 to `epochs`.
for series, name in ((train_acc, 'train_acc'), (test_acc, 'test_acc')):
    plt.plot(range(1, epochs + 1), series, label=name)
plt.legend()
image.png