PyTorch Deep Learning in Practice 10 - ResNet18

2023-03-24  薛东弗斯

ResNet18 model architecture

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

When using a pretrained model, you generally change only its output layer. By default the model is trained on ImageNet and therefore outputs 1000 classes. The difference from VGG is that there is no stack of fully connected layers at the end: everything before the final fc layer is Conv/BN/ReLU, the feature map is then reduced by avgpool (global average pooling), and fc itself is an ordinary Linear layer.
During training the pretrained weights are frozen, i.e. no gradients are computed for them, so only the newly added layer is updated.
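A side note: torchvision 0.13 and later deprecate the pretrained=True flag in favor of an explicit weights enum; on a newer version the equivalent call is:

from torchvision.models import resnet18, ResNet18_Weights
model = resnet18(weights=ResNet18_Weights.DEFAULT)  # equivalent to the old pretrained=True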

Code to modify the model

model = torchvision.models.resnet18(pretrained=True)

for param in model.parameters():
    param.requires_grad = False  # freeze: no gradients for the pretrained weights

in_f = model.fc.in_features  # input features of the final Linear layer must stay unchanged: 512
model.fc = nn.Linear(in_f, 4)  # replace the 1000-class head with a new Linear(512, 4); a freshly created layer has requires_grad=True, so it is trainable
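To confirm that only the new head will be trained, you can count trainable parameters; a quick sanity check:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)  # trainable is just the new fc layer: 512*4 weights + 4 biases = 2052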

Model structure after replacing the output layer

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=4, bias=True)
)

Complete code

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
from torchvision import transforms
import os

base_dir = r'./data/4_weather'
train_dir = os.path.join(base_dir,'train')
test_dir = os.path.join(base_dir,'test')

transform = transforms.Compose([
    transforms.Resize((192,192)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
])
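# Note: torchvision's ImageNet-pretrained weights were trained with per-channel
# statistics mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]; normalizing
# with those values instead of 0.5 is the usual choice when using pretrained models.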

train_ds = torchvision.datasets.ImageFolder(train_dir,transform=transform)
test_ds = torchvision.datasets.ImageFolder(test_dir,transform=transform)

BATCH_SIZE=32

train_dl = torch.utils.data.DataLoader(train_ds,batch_size=BATCH_SIZE,shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=BATCH_SIZE)

model = torchvision.models.resnet18(pretrained=True)

for param in model.parameters():
    param.requires_grad = False  # freeze: no gradients for the pretrained weights

in_f = model.fc.in_features  # input features of the final Linear layer, 512
model.fc = nn.Linear(in_f, 4)  # new 4-class head; a freshly created Linear has requires_grad=True

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
model = model.to(device)  # move the model to the chosen device; fit() moves the data there

optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001)  # only the new fc head is optimized; every other parameter has requires_grad=False
loss_fn = nn.CrossEntropyLoss()
# use PyTorch's built-in learning-rate scheduler
from torch.optim import lr_scheduler
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)  # multiply the LR by gamma=0.9 every 5 steps
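# With step_size=5 and gamma=0.9, and one scheduler step per epoch (see fit below),
# the learning rate at epoch e is 0.0001 * 0.9 ** (e // 5), e.g. ~5.9e-5 at epoch 25.
# The current value can be read back with exp_lr_scheduler.get_last_lr().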

# fit must set the model's mode explicitly (training vs. evaluation): BatchNorm layers
# behave differently in each mode (batch statistics in training, running statistics in evaluation).
def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    model.train()
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim = 1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()

    exp_lr_scheduler.step()    # one scheduler step per epoch, so step_size counts epochs
    
    epoch_acc = correct / total
    epoch_loss = running_loss / len(trainloader.dataset)
    
    test_correct = 0
    test_total = 0
    test_running_loss = 0
    
    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim = 1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()
    epoch_test_acc = test_correct / test_total
    epoch_test_loss = test_running_loss / len(testloader.dataset)
    
    print('epoch: ', epoch, 
          'loss: ', round(epoch_loss, 3),
          'accuracy: ', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy: ', round(epoch_test_acc, 3))
    
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

epochs = 50

train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
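Once training has finished, the model can be used for single-image prediction. A minimal inference sketch, reusing the transform and device defined above (the image path below is a hypothetical example):

from PIL import Image

class_names = train_ds.classes  # class names are taken from the subfolder names
img = Image.open('./data/4_weather/test/cloudy/cloudy1.jpg').convert('RGB')  # hypothetical path
x = transform(img).unsqueeze(0).to(device)  # apply the transform and add a batch dimension

model.eval()
with torch.no_grad():
    pred = torch.argmax(model(x), dim=1).item()
print(class_names[pred])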

Training output

Using cpu device
epoch:  0 loss:  0.045 accuracy:  0.294 test_loss:  0.046 test_accuracy:  0.333
epoch:  1 loss:  0.041 accuracy:  0.399 test_loss:  0.043 test_accuracy:  0.449
epoch:  2 loss:  0.038 accuracy:  0.509 test_loss:  0.039 test_accuracy:  0.569
epoch:  3 loss:  0.035 accuracy:  0.58 test_loss:  0.036 test_accuracy:  0.658
epoch:  4 loss:  0.033 accuracy:  0.656 test_loss:  0.033 test_accuracy:  0.689
epoch:  5 loss:  0.03 accuracy:  0.732 test_loss:  0.03 test_accuracy:  0.747
epoch:  6 loss:  0.028 accuracy:  0.744 test_loss:  0.029 test_accuracy:  0.782
epoch:  7 loss:  0.026 accuracy:  0.786 test_loss:  0.027 test_accuracy:  0.813
epoch:  8 loss:  0.025 accuracy:  0.824 test_loss:  0.025 test_accuracy:  0.853
epoch:  9 loss:  0.024 accuracy:  0.828 test_loss:  0.024 test_accuracy:  0.862
epoch:  10 loss:  0.023 accuracy:  0.841 test_loss:  0.023 test_accuracy:  0.849
epoch:  11 loss:  0.021 accuracy:  0.858 test_loss:  0.021 test_accuracy:  0.893
epoch:  12 loss:  0.021 accuracy:  0.866 test_loss:  0.021 test_accuracy:  0.893
epoch:  13 loss:  0.02 accuracy:  0.869 test_loss:  0.02 test_accuracy:  0.898
epoch:  14 loss:  0.019 accuracy:  0.874 test_loss:  0.019 test_accuracy:  0.907
epoch:  15 loss:  0.018 accuracy:  0.901 test_loss:  0.018 test_accuracy:  0.898
epoch:  16 loss:  0.018 accuracy:  0.892 test_loss:  0.018 test_accuracy:  0.902
epoch:  17 loss:  0.018 accuracy:  0.903 test_loss:  0.018 test_accuracy:  0.911
epoch:  18 loss:  0.017 accuracy:  0.897 test_loss:  0.016 test_accuracy:  0.907
epoch:  19 loss:  0.017 accuracy:  0.917 test_loss:  0.016 test_accuracy:  0.911
epoch:  20 loss:  0.016 accuracy:  0.892 test_loss:  0.015 test_accuracy:  0.92
epoch:  21 loss:  0.016 accuracy:  0.908 test_loss:  0.015 test_accuracy:  0.911
epoch:  22 loss:  0.015 accuracy:  0.913 test_loss:  0.015 test_accuracy:  0.916
epoch:  23 loss:  0.015 accuracy:  0.914 test_loss:  0.015 test_accuracy:  0.916
epoch:  24 loss:  0.015 accuracy:  0.91 test_loss:  0.014 test_accuracy:  0.92
epoch:  25 loss:  0.014 accuracy:  0.919 test_loss:  0.014 test_accuracy:  0.92
epoch:  26 loss:  0.014 accuracy:  0.934 test_loss:  0.014 test_accuracy:  0.916
epoch:  27 loss:  0.014 accuracy:  0.912 test_loss:  0.013 test_accuracy:  0.924
epoch:  28 loss:  0.015 accuracy:  0.933 test_loss:  0.013 test_accuracy:  0.92
epoch:  29 loss:  0.013 accuracy:  0.93 test_loss:  0.013 test_accuracy:  0.929
epoch:  30 loss:  0.013 accuracy:  0.926 test_loss:  0.013 test_accuracy:  0.933
epoch:  31 loss:  0.013 accuracy:  0.927 test_loss:  0.013 test_accuracy:  0.916
epoch:  32 loss:  0.013 accuracy:  0.923 test_loss:  0.012 test_accuracy:  0.924
epoch:  33 loss:  0.012 accuracy:  0.931 test_loss:  0.012 test_accuracy:  0.924
epoch:  34 loss:  0.012 accuracy:  0.931 test_loss:  0.012 test_accuracy:  0.929
epoch:  35 loss:  0.013 accuracy:  0.931 test_loss:  0.012 test_accuracy:  0.933
epoch:  36 loss:  0.012 accuracy:  0.944 test_loss:  0.012 test_accuracy:  0.929
epoch:  37 loss:  0.012 accuracy:  0.943 test_loss:  0.011 test_accuracy:  0.924
epoch:  38 loss:  0.012 accuracy:  0.921 test_loss:  0.012 test_accuracy:  0.933
epoch:  39 loss:  0.011 accuracy:  0.936 test_loss:  0.012 test_accuracy:  0.933
epoch:  40 loss:  0.011 accuracy:  0.929 test_loss:  0.011 test_accuracy:  0.933
epoch:  41 loss:  0.012 accuracy:  0.943 test_loss:  0.011 test_accuracy:  0.929
epoch:  42 loss:  0.011 accuracy:  0.946 test_loss:  0.011 test_accuracy:  0.933
epoch:  43 loss:  0.011 accuracy:  0.937 test_loss:  0.011 test_accuracy:  0.938
epoch:  44 loss:  0.01 accuracy:  0.943 test_loss:  0.011 test_accuracy:  0.938
epoch:  45 loss:  0.011 accuracy:  0.942 test_loss:  0.011 test_accuracy:  0.933
epoch:  46 loss:  0.01 accuracy:  0.938 test_loss:  0.011 test_accuracy:  0.933
epoch:  47 loss:  0.011 accuracy:  0.936 test_loss:  0.011 test_accuracy:  0.933
epoch:  48 loss:  0.01 accuracy:  0.948 test_loss:  0.01 test_accuracy:  0.933
epoch:  49 loss:  0.011 accuracy:  0.946 test_loss:  0.011 test_accuracy:  0.938

Overfitting here is noticeably milder than with VGG: the train and test metrics stay close throughout training.

Finally, plot the loss and accuracy curves:

plt.plot(range(1, epochs+1), train_loss, label='train_loss')
plt.plot(range(1, epochs+1), test_loss, label='test_loss')
plt.legend()
(figure: train vs. test loss curves)
plt.plot(range(1, epochs+1), train_acc, label='train_acc')
plt.plot(range(1, epochs+1), test_acc, label='test_acc')
plt.legend()
(figure: train vs. test accuracy curves)