使用CNN模型做预测:前向过程解释

2019-08-04  本文已影响0人  钢笔先生

Time: 2019-08-04
视频地址:https://youtu.be/6vweQjouLEE?list=PLZbbT5o_s2xrfNyHZsM6ufI0iZENK9xgG&t=23

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

torch.set_printoptions(linewidth=120)

# 训练集
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

# 构建网络
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
    
    def forward(self, t):
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        t = F.relu(self.conv2 (t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        t = F.relu(self.fc1(t.reshape(-1, 12*4*4)))
        t = F.relu(self.fc2(t))
        t = self.out(t)

        return t

torch.set_grad_enabled(False)

# 实例化网络
net = Network()
sample = next(iter(train_set))
image, label = sample

image.shape # torch.Size([1, 28, 28])
image.unsqueeze(0).shape # torch.Size([1, 1, 28, 28])

# 执行预测
# 预测时需要输入图片的形状为4维张量
pred = net(image.unsqueeze(0))

pred # tensor([[-0.0484, -0.0635, -0.0606, -0.1533,  0.0612,  0.0382,  0.0014, -0.0159, -0.0116, -0.1182]])
pred.argmax(dim=1) # tensor([4])
F.softmax(pred, dim=1)

这里有一些要点需要注意:网络接收的数据是4D张量,单张图片也需要处理成4D的格式,用的是tensor.unsqueeze()方法。

END.

上一篇下一篇

猜你喜欢

热点阅读