深度学习代码示例-图像分类

2021-03-09  本文已影响0人  邯山之郸

源代码:

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import torch.nn.functional as F
import copy
import torch.optim as optim
plt.ion()   # interactive mode: figures redraw without blocking the script
# Switch to the project folder that holds the data and the models directory.
# NOTE(review): hard-coded absolute Windows path; os.chdir raises on any
# machine where it does not exist, which also breaks importing this module.
os.chdir(r"D:\箕斗\神经网络")


# Per-split preprocessing. Training images are randomly cropped/resized to
# 224x224 and randomly flipped (augmentation); test images are resized to 256
# and center-cropped to 224. Both splits are converted to tensors and
# normalized with the same per-channel mean/std (these values are the commonly
# used ImageNet statistics).
data_transforms = dict(
    Train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    Test=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

data_dir = 'CNNDemoB'  # dataset root; one sub-folder per split ('Train', 'Test')
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('Train', 'Test')
}
# Shuffled mini-batches of 4 images, loaded in the main process.
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, num_workers=0
    )
    for split, dataset in image_datasets.items()
}
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}
class_names = image_datasets['Train'].classes  # class labels discovered by ImageFolder

# Show a few images
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406),
           std=(0.229, 0.224, 0.225), pause=10):
    """Display a normalized image tensor with matplotlib.

    Reverses ``transforms.Normalize(mean, std)`` (``x * std + mean``), clips
    to [0, 1] and draws the result.

    Args:
        inp: image tensor in (C, H, W) layout, e.g. the output of
            ``ToTensor``/``Normalize`` or ``torchvision.utils.make_grid``.
            C must match the length of ``mean``/``std``.
        title: optional plot title (any value accepted by ``plt.title``).
        mean: per-channel mean used for normalization. The default equals
            the values passed to the file's ``Normalize`` transforms.
        std: per-channel std used for normalization (same default source).
        pause: seconds passed to ``plt.pause``. This runs the GUI event loop
            so the figure is actually drawn, and blocks for that long. The
            default of 10 keeps the previous behaviour.
    """
    # detach()/cpu() so tensors that track gradients or live on a GPU can be
    # converted to NumPy as well; CHW -> HWC for matplotlib.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(pause)
# Build the network
class Net(nn.Module):
    """Small LeNet-like CNN for 3-channel square images.

    Architecture: conv(3->6, 5x5) + ReLU + 2x2 max-pool,
    conv(6->16, 5x5) + ReLU + 2x2 max-pool, then
    linear -> 120 -> 84 -> ``num_classes`` (ReLU between the linear layers).

    Args:
        num_classes: number of output logits (default 2, as before).
        input_size: height/width of the square input images. The width of
            the first linear layer is derived from it. The default 224
            gives 16 * 53 * 53 features, matching previously saved weights.

    Raises:
        ValueError: if ``input_size`` is too small to survive both
            conv/pool stages.
    """

    def __init__(self, num_classes=2, input_size=224):
        super().__init__()
        # Layers are created in the original order, so seeded initialisation
        # and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Spatial size after conv5 (-4) -> pool2 (//2) -> conv5 (-4) -> pool2 (//2).
        # For 224: 220 -> 110 -> 106 -> 53.
        side = ((input_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError("input_size %r is too small for Net" % (input_size,))
        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to logits (N, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, fixed), a wrong input size cannot
        # silently regroup values across the batch: fc1 raises on a mismatch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Evaluation
def Data_Test(net=None, loader=None, model_path=os.path.join("models", "net.pth")):
    """Evaluate a classifier on the test split and print its accuracy.

    For every batch the forward-pass wall time is printed. The overall
    accuracy is printed at the end.

    Args:
        net: model to evaluate. If None, a fresh ``Net`` is built and its
            weights are loaded from ``model_path``.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to
            ``dataloaders['Test']``.
        model_path: checkpoint to load when ``net`` is None. Built with
            ``os.path.join`` so it works outside Windows too.

    Returns:
        Accuracy in percent (float), or None if the loader yielded no samples.
    """
    if net is None:
        net = Net()
        # map_location lets a checkpoint saved on a GPU load on a CPU-only box.
        net.load_state_dict(torch.load(model_path, map_location="cpu"))
    if loader is None:
        loader = dataloaders['Test']

    net.eval()  # inference behaviour for layers such as dropout/batch-norm
    correct = 0
    total = 0
    with torch.no_grad():  # evaluation needs no autograd graph
        for inputs, labels in loader:
            t1 = time.time()
            outputs = net(inputs)
            _, predicted = torch.max(outputs, dim=1)  # index of the largest logit
            print("time:", time.time() - t1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Avoid ZeroDivisionError on an empty test set.
        print("Test set is empty; accuracy is undefined.")
        return None
    accuracy = 100 * correct / total
    print("Accuracy on test set:", round(accuracy, 2), "%")
    return accuracy
# Training
def Data_Train(net=None, loader=None, model_path=os.path.join("models", "net.pth"),
               epochs=20, lr=0.001, momentum=0, log_every=20, resume=True):
    """Train a classifier with SGD and cross-entropy, then save its weights.

    Args:
        net: model to train. If None, a fresh ``Net`` is created.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to
            ``dataloaders['Train']``.
        model_path: checkpoint file, read when resuming and written at the
            end. Its parent directory is created if missing.
        epochs: number of passes over ``loader`` (default 20).
        lr: SGD learning rate (default 0.001).
        momentum: SGD momentum (default 0).
        log_every: print the mean loss of every ``log_every`` consecutive
            batches within an epoch. Must be a positive integer.
        resume: if True and ``model_path`` exists, start from those weights.
            Otherwise start from the model's current (random) initialisation
            instead of failing on a missing file.

    Returns:
        The trained model (the same object as ``net`` when one is passed).
    """
    if net is None:
        net = Net()
    if loader is None:
        loader = dataloaders['Train']
    if resume and os.path.exists(model_path):
        net.load_state_dict(torch.load(model_path, map_location="cpu"))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    net.train()  # training behaviour for layers such as dropout/batch-norm
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:  # every `log_every` mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

    print('Finished Training')
    # Save the weights; make sure the target directory exists first.
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
    torch.save(obj=net.state_dict(), f=model_path)
    print("权重保存:", net.state_dict())
    return net

if __name__ == '__main__':
    # Preview: draw one shuffled training batch as a single image grid,
    # titled with the class name of each image in the batch.
    sample_images, sample_labels = next(iter(dataloaders['Train']))
    grid = torchvision.utils.make_grid(sample_images)
    titles = [class_names[label] for label in sample_labels]
    imshow(grid, title=titles)

    # Evaluate the saved weights on the test split.
    Data_Test()
    

    
上一篇 下一篇

猜你喜欢

热点阅读