PyTorch 训练 LeNet 网络识别 MNIST 手写数字

2024-01-05  本文已影响0人  一路向后

1. 源码实现（保存为 train.py）

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

class LeNet(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images and 10 output classes.

    Two conv/ReLU/max-pool stages are followed by three fully connected
    layers. ``forward`` returns raw logits (no softmax), suitable for
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the keys of state_dict() (and thus of any saved
        # checkpoint such as mnist.pth); keep them stable.
        self.conv1 = nn.Conv2d(1, 6, 5)    # 1x28x28 -> 6x24x24
        self.pool1 = nn.MaxPool2d(2, 2)    # -> 6x12x12
        self.conv2 = nn.Conv2d(6, 16, 5)   # -> 16x8x8
        self.pool2 = nn.MaxPool2d(2, 2)    # -> 16x4x4
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: image tensor of shape ``(N, 1, 28, 28)``, or a single
                unbatched image of shape ``(1, 28, 28)``.

        Returns:
            Logits of shape ``(N, 10)`` (``(1, 10)`` for an unbatched image).

        Raises:
            RuntimeError: if the spatial size does not produce 16x4x4 feature
                maps (i.e. the input is not 28x28).
        """
        if x.dim() == 3:
            # nn.Conv2d also accepts a single (C, H, W) image; add the batch
            # dimension so the result is (1, 10) and flattening below is
            # always per sample.
            x = x.unsqueeze(0)
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        # Flatten each sample while keeping the batch dimension. The previous
        # x.view(-1, 256) could silently reinterpret wrongly sized feature
        # maps as a different batch size; now fc1 raises a shape error.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Image preprocessing: ToTensor turns each image into a float tensor, then
# Normalize standardises it with a fixed mean/std (0.1307 / 0.3081, the
# commonly quoted MNIST pixel statistics).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training split first, then test split. download=False means the MNIST files
# are expected to already exist under mnist/.
tran_dataset, test_dataset = (
    datasets.MNIST('mnist/', download=False, train=use_train_split, transform=transform)
    for use_train_split in (True, False)
)

# Only the training data is shuffled; both loaders use batches of 256.
train_dataloader = DataLoader(dataset=tran_dataset, batch_size=256, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

# All computation in this script happens on the CPU.
device = "cpu"

lenet = LeNet().to(device)

epochs = 1000
lr = 1e-4
# Early-stopping criterion: stop once an epoch's mean training loss falls
# below this value. (Previously a misleadingly named `loss` reset each epoch.)
loss_threshold = 1e-4

optimizer = torch.optim.Adam(lenet.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

train_acc_list = []
test_acc_list = []
train_loss_list = []


def _train_one_epoch(model, loader, opt, criterion, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses, and the
        fraction of samples in ``loader.dataset`` classified correctly
        (measured on the fly, with the weights as they were for each batch).
    """
    # LeNet has no dropout/batch-norm, so train()/eval() do not change its
    # outputs today; switching modes keeps the loop correct if such layers
    # are ever added, since _evaluate() puts the model in eval mode.
    model.train()
    batch_losses = []
    correct = 0
    for data, labels in tqdm(loader):
        data = data.to(device)
        labels = labels.to(device)

        logits = model(data)
        batch_loss = criterion(logits, labels)

        opt.zero_grad()
        batch_loss.backward()
        opt.step()

        # .item() yields a plain Python number, detached from the graph.
        batch_losses.append(batch_loss.item())
        correct += (torch.argmax(logits, 1) == labels).sum().item()
    return sum(batch_losses) / len(batch_losses), correct / len(loader.dataset)


def _evaluate(model, loader, device):
    """Return the fraction of samples in ``loader.dataset`` classified correctly."""
    model.eval()
    correct = 0
    # Gradients are not needed for evaluation; skip building the graph.
    with torch.no_grad():
        for data, labels in loader:
            data = data.to(device)
            labels = labels.to(device)
            correct += (torch.argmax(model(data), 1) == labels).sum().item()
    return correct / len(loader.dataset)


for epoch in range(epochs):
    real_loss, acc = _train_one_epoch(lenet, train_dataloader, optimizer, loss_fn, device)
    train_loss_list.append(real_loss)
    train_acc_list.append(acc)

    # The test split was loaded but never used before; report it every epoch.
    test_acc = _evaluate(lenet, test_dataloader, device)
    test_acc_list.append(test_acc)

    print(f'epoch:{epoch}, train_loss:{real_loss}, train_acc:{acc:.4f}, test_acc:{test_acc:.4f}')

    if real_loss < loss_threshold:
        break

# Save only the learned parameters (state_dict), not the pickled module.
torch.save(lenet.state_dict(), "mnist.pth")

2. 运行程序

注意：脚本以 download=False 加载数据集，运行前需确保 mnist/ 目录下已有 MNIST 数据文件（或将其改为 download=True 以自动下载）。

$ python train.py

3.结果

训练过程中每个 epoch 会打印训练损失；运行结束后，模型参数（state_dict）将保存到当前目录下的 mnist.pth 文件中。

上一篇下一篇

猜你喜欢

热点阅读