从零单排fastai脚本(2)

2019-05-29  本文已影响0人  深度学习模型优化

这次看下wgan脚本,这里使用fastai来完成wgan的训练和使用。

老三样,我就不加标题了

%reload_ext autoreload
%autoreload 2
%matplotlib inline

1 重要的包

from fastai.vision import *
from fastai.vision.gan import *

其中gan包是在../fastai/vision下的文件夹。大家从fastai的官网上下载源代码可以看见。

2 数据

本教程使用的脚本是LSun bedroom数据集,该数据集是卧室的图片,我们目的是使用fastai使用wgan生成卧室的图片。
这里使用了该数据集的一小部分,因为原始数据集实在是太大了。

使用data.show_batch()显示部分数据。

图1 部分数据示例

3 模型

生成对抗网络GAN有很多种。

这里使用WGAN(Wassertein GAN)。

模型训练过程如下:

WGAN有两部分:

  1. freeze generator,然后训练critic。

2 freeze critic,然后训练generator。

4 模型训练

训练使用fit_one_cyclemax_lr=2e-4
这里对比了下Google和kaggle提供的免费模型训练平台。
因此我选择了kaggle的kernel,速度是真的快!但是不能超过9个小时,这个要注意。

这里的fastai需要事先定义:

generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (in_size=64, n_channels=3, n_extra_layers=1)

learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func = partial(optim.Adam, betas = (0.,0.99)), wd=0.)
# 模型训练
learn.fit(30,2e-4)

5 结果

learn.gan_trainer.switch(gen_mode=True)
learn.show_results(ds_type=DatasetType.Train, rows=16, figsize=(8,8))
图2 生成结果
上一篇 下一篇

猜你喜欢

热点阅读