深度学习

快速上手Pytorch

2018-11-19  本文已影响7人  逆风g

这篇文章需要大家对深度学习里的神经网络训练有一定的基础,我以前训练网络一直都是用的TensorFlow,后面需要把模型和数据迁移到Pytorch平台上去,发现很多里面有很多知识点需要注意,写这篇文章一方面是给自己做个笔记,总结下自己的经验,另一方面是为了方便想要快速上手Pytorch的同学。这篇文章主要内容有:

Tensorflow的PlayGround

PlayGround是一个在线演示、实验的神经网络平台,是一个入门神经网络非常直观的网站。这个图形化平台非常强大,将神经网络的训练过程直接可视化。假若有的同学刚刚想入门深度学习这一领域,可以去看看:
PlayGround地址:http://playground.tensorflow.org
这里也有一篇PlayGround介绍写的非常详细的文章:
参考地址:https://finthon.com/tensorflow-playground-nn/

Pytorch介绍和安装


2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。Pytorch和Torch底层实现都用的是C语言,但是Torch的调用需要掌握Lua语言,相比而言使用Python的人更多,根本不是一个数量级,所以Pytorch基于Torch做了些底层修改、优化并且支持Python语言调用。
它是一个基于Python的可续计算包,目标用户有两类:

  1. 使用GPU来运算numpy
  2. 一个深度学习平台,提供最大的灵活型和速度

如何安装Pytorch呢?

  1. Anaconda(可选)和Python
  2. 显卡驱动和CUDA
  3. 运行Pytorch的安装命令

Torch和Torchvision里的常用包

Torch

Torchvision

Variable、Tensor、Numpy之间的关系

>>> import numpy as np
>>> x=np.array([[1,2,3],[9,8,7],[6,5,4]])
  1. PyTorch 张量的简单封装
  2. 帮助建立计算图
  3. Autograd(自动微分库)的必要部分
  4. 将关于这些变量的梯度保存在 .grad 中
  1. 将Numpy矩阵转换为Tensor张量
    sub_ts = torch.from_numpy(sub_img)
  2. 将Tensor张量转化为Numpy矩阵
    sub_np1 = sub_ts.numpy()
  3. 将Tensor转换为Variable
    sub_va = Variable(sub_ts)
  4. 将Variable转换为Tensor
    sub_np2 = sub_va.data

CPU与GPU

Pytorch支持CPU运行,但是速度非常慢,一张好的NVIDIA显卡能够大大减少网络训练时间,以我自己经验来看,15年MacBook Pro 与戴尔工作站附加一张显存11GB的1080ti显卡相比,后者速度是前者速度的224倍,尤其训练复杂网络一定要在GPU上跑。Pytorch中把数据和模型从CPU迁移到GPU非常简单:


直接对变量、张量、模型使用.cuda()即可把他们迁移到GPU上,反过来迁移到CPU上,使用.cpu()
当有多行显卡时,想充分利用它们,则可使用model = nn.DataParallel(model)命令:

常见问题

示例--GAN生成MINIST数据

最后看个实例,如何使用GAN网络生成MINIST 数据,主要内容有:


MNIST数据集

MNIST数据集是一个手写体数据集,图片大小都是28x28,包含0-9共10个数字,各种风格:



下载好的数据集:


测试集t10k开头,训练集train开头,images是图片,labels是标签

GAN网络模型

输入100长度的噪声向量,经过一个全连接,两个卷积层,一个下采样之后生成成28x28大小的图片,这一部分是生成器
生成的假图片和MNIST里的真图片经过两个卷积层下采样之后,再次经历两个全连接层后输出一个1长度的单位向量,1代表输入图片为真,0代表输入图片为假
GAN训练和Loss


训练判别器D时,要使得V整体变大,训练生成器G时,要使得V整体变小。
这是一个博弈的过程,就像制造假钱的犯罪团伙和验钞机的关系,犯罪团伙需要努力提高技术,让验钞机无法识别出来其制造的假币,而验钞机要能够正确的分辨出真正的纸币还有假币。
理论上当判别器D只有一半的概率0.5能识别出假图片时,就已经收敛了,实际上达不到一半的概率,没关系,使得假图片概率尽量高就行了,最终看上去效果不错。
这是一张由生成器生成的假图片,你能区分出来吗?
可视化

可视化方式有两种,一种是利用torchvision里面的包 torchvision.utils,另外一种是利用visdom插件,下面是二者的对比:


上面那张生成的假图片就是利用torchvision.utils里的save_image函数来存储在本地的。
而以下这张图是利用visdom,在浏览器中查看到的效果:

visdom不光可以查看图片,还可以查看loss变化曲线图等各种功能。

具体的代码实现去工程里查看,这里给出分享地址:
https://github.com/gcfrun/GAN_MNIST_Pytorch
mnist_data.py:数据输入模块
mnist_net.py:网络模型模块
mnist_loss.py:Loss计算模块
mnist_train.py:迭代训练模块
mnist_visual.py:可视化模块

上一篇 下一篇

猜你喜欢

热点阅读