pytorch入门_网络架构/训练/保存提取

2018-08-14  本文已影响0人  conson_wm

1. 用nn.Module搭建网络架构

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5) # 1 input image channel, 6 output channels, 5x5 square convolution kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120) # an affine operation: y = Wx + b
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv2(x)), 2) # If the size is a square you can only specify a single number
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def num_flat_features(self, x):
        size = x.size()[1:] # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = Net()
net

'''神经网络的输出结果是这样的
Net (
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear (400 -> 120)
  (fc2): Linear (120 -> 84)
  (fc3): Linear (84 -> 10)
)
'''

  上面这段是LeNet的架构, 需要说明的是

1.1 重要的函数有两个, 一个是_init_(), 一个是forward()

  其他的可以根据需要再加, _init_()是用来定义网络中需要的模块, 比如基本的conv, fc的模块的大小都在这里面定义好, forward()是用来定义前向的架构, 就是顺序地堆叠模块就可以了, 这个和keras没有什么区别, 我觉得Pytorch的这套架构比较有整体性, 就是把一个模型封装在一个类里面, 包含module和architecture都用这个类体现出来了


1.2 在_init_()里开头的那一句

super(Net, self).__init__()

  这句的意思是_init_()函数继承自父类nn.Module
  super()函数是用于调用父类的一个方法, super(Net, self)首先找到Net的父类(就是nn.Module), 然后把类Net的对象转换为类nn.Module的对象, 具体的可以看一下
http://www.runoob.com/python/python-func-super.html


1.3 conv和fc模块的入参

self.conv1 = nn.Conv2d(1, 6, 5) # 1 input image channel, 6 output channels, 5x5 square convolution kernel
self.fc1   = nn.Linear(16*5*5, 120) # an affine operation: y = Wx + b

  conv的入参有三个, 分别是输入channel, 输出channel, 以及kernel的大小(5就代表5x5)
  fc的入参只有两个, 分别就是输入的结点数, 输出的结点数


1.4 torch.nn.functional as F

  这个是通常写法, 我们用到的激活函数, pool函数都在这个F里面(损失函数还有conv/fc等层在nn里面), 这个要注意一下, 那BN/dropout是在nn还是在F里面?


2. 如何开始训练

1中是把网络定义好了, 那离能够训练这个网络还有几步呢?我们需要想一下训练一个神经网络需要一些什么基本的步骤

2.1 输入值格式

  我们的应用都是图片输入, 那pytorch有没有keras那种直接一个文件夹输入作为训练集或测试集(ImageDataGenerator())的那种方法呢?是有的, 就是Dataloader()

import torchvision
import torchvision.transforms as transforms


# torchvision数据集的输出是在[0, 1]范围内的PILImage图片。
# 我们此处使用归一化的方法将其转化为Tensor,数据范围为[-1, 1]

transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, 
                                          shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
'''注:这一部分需要下载部分数据集 因此速度可能会有一些慢 同时你会看到这样的输出

Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting tar file
Done!
Files already downloaded and verified
'''

  这段代码就是用DataLoader()来生成patch的方法, 这个和keras的ImageDataGenerator()就很像了, 生成的trainloader是一个包含tensor的迭代器, 在训练的过程中就可以直接从迭代器里去除tensor, 然后装入Variable中作为神经网络的输入, 这就是输入值格式
  当然, 这只是一种简单的从数据集中得到batch的方法, 要是我们的数据是从某个文件夹来还得经过点pre-process的话(比如SR)会更复杂一些, 我们会在SR benchmark笔记里面再详细说

2.2 Optimizer

import torch.optim as optim
# create your optimizer
optimizer = optim.SGD(net.parameters(), lr = 0.01)

# in your training loop:
optimizer.zero_grad() # zero the gradient buffers
optimizer.step()

  Optimizer的定义和Keras差不多, 可以用参数定义lr, momentum, weight_decay等, 这个去查文档就好, optimizer.zero_grad()的意思是初始化所有的梯度buffer为0, 这个是在每一次计算梯度更新之前做的, optimizer.step()就是根据你定义的optimizer执行梯度更新


2.3 loss function

import torch.optim as optim
# make your loss function
criterion = nn.MSELoss()

# in your training loop:
output = net(input)
loss = criterion(output, target)
loss.backward()

  criterion定义了loss function, 然后在训练的loop中loss就根据criterion不停去算, loss是一个Variable, 它具有backward属性, 就是在训练过程中可以直接用.backward来计算loss对每个weight的梯度值


2.4 训练过程

  有了输入, loss function和Optimizer, 我们就可以开始进行训练了, 过程就是从迭代器trainloader把input tensor送入网络, 然后算output, 根据loss function算loss, 从loss再反过去算梯度, 根据Optimizer去更新权重

for epoch in range(2): # loop over the dataset multiple times
    
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        
        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)
        
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()        
        optimizer.step()
        
        # print statistics
        running_loss += loss.data[0]
        if i % 2000 == 1999: # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')

2.5 神经网络的快速搭建方法

  除了上面提到的搭建神经网络的方法之外, pytorch还提供了另一种更快速的搭建方法, 有点类似于Keras的Sequentiao模型
http://keras-cn.readthedocs.io/en/latest/getting_started/sequential_model/

net = torch.nn.Sequential(
  torch.nn.Linear(2,10),
  torch.nn.ReLU,
  torch.nn.Linear(10,2),
)
# fast implementation of FC 2->10->2

3. 保存和提取训练结果

  保存和加载网络模型有两种方法, 一种是把模型和参数一起save(相当于keras的save()), 还有一种就是只save参数(相当于keras的save_weight())

# 保存和加载整个模型  
torch.save(model_object, 'model.pth')  
model = torch.load('model.pth')  
 
# 仅保存和加载模型参数  
torch.save(model_object.state_dict(), 'params.pth')  
model_object.load_state_dict(torch.load('params.pth')) 

  当网络比较大的时候, 用第一种方法会花比较多的时间, 同时所占的存储空间也比较大
  还有一个问题就是存储模型的时候, 有的时候会存成.pkl格式, 应该是没有本质区别, 都是Pickle格式, 后续在用C来读取网络的时候, 也从Pickle的C实现来考虑直接解析Pytorch的模型
https://www.zhihu.com/question/274533811

上一篇下一篇

猜你喜欢

热点阅读