论文剖析:Unsupervised Diverse Colori

2018-09-17  本文已影响120人  HellyCla

论文地址

背景

向着convert sketch/gray to image的方向搜索论文与资料,发现当前sketch-to-image的研究中还原效果很差,轮廓模糊,种类少而简单,单个物品训练转换,至多训练了50多种,不符合我预先设想的复杂任意场景的有效还原,后又看到colorization相关的论文,为grayscale-to-RGB/YUV image转换,目测效果还不错,因此想复现一下论文尝试用自己搜集的复杂数据集做一下测试。

已阅论文

Colorization

Unsupervised Diverse Colorization via Generative Adversarial Networks是近期阅读论文中看起来效果还很不错的一篇,于是仔细研究一下论文和代码。

基本网络结构

下面是基本的利用GAN网络来解决此问题的框架结构:Generator 的两个输入是 灰度图+随机noise z,经过Generator网络产生UV分量,与Y分量(即输入的灰度图)合并得到Generated color image, Discriminator即用来鉴别该图与原始的彩色图像。


传统的基本的GAN结构(针对自上色问题)

论文中的创新点总结:

论文中采用的网络结构

Generator && Discriminator
细节

代码部分

在公布的一份简陋的代码中,仍存在许多问题,需要我们自己根据自己的库、版本以及需求来修正。
微修正的实现代码
存在的问题有:

  1. 代码中:visualize,multi-z,multi-conditiontransform等参数均未实现对应功能代码。multi-z的作用是concat 多层 noise,multi-condition即concat multi-layer condition.
  2. improve_GAN主要用于增加梯度惩罚项,系数默认为10,体现在d_loss中。可用可不用,可以在测试对比损失函数效果时用,使用此项时优化函数用Adam。
    3.noise z在代码中并未单独处理concat问题,将noise也与灰度图做统一处理,使用了生成器网络,若想实现论文效果,需另外重写输入noise时的网络结构。
  3. test_z_fixed.pkl这个序列化文件应该是对noise z做了调整,具体结构内容未知。

算法流程

算法流程

在实现中Kd=5, Kg=1.在迭代中不断更新参数值。

关键代码
多层联结的生成器卷积网络是这个GAN网络创新的核心部分,实现如下:

  def generator_colorization(self, z, image_Y, config=None):
        with tf.variable_scope("generator") as scope:
            # project z
            h0 = linear(z, config.image_size * config.image_size, 'g_h0_lin', with_w=False)
            # reshape 
            h0 = tf.reshape(h0, [-1, config.image_size, config.image_size, 1])
            h0 = tf.nn.relu(batch_norm(h0, name='g_bn0'))
            # concat with Y
            h1 = tf.concat([image_Y, h0], 3)
            # print 'h0 shape after concat:', h0.get_shape()
            h1 = conv2d(h1, 128, k_h=7, k_w=7, d_h=1, d_w=1, name='g_h1_conv')
            h1 = tf.nn.relu(batch_norm(h1, name='g_bn1'))

            h2 = tf.concat([image_Y, h1], 3)
            h2 = conv2d(h2, 64, k_h=5, k_w=5, d_h=1, d_w=1, name='g_h2_conv')
            h2 = tf.nn.relu(batch_norm(h2, name='g_bn2'))

            h3 = tf.concat([image_Y, h2], 3)
            h3 = conv2d(h3, 64, k_h=5, k_w=5, d_h=1, d_w=1, name='g_h3_conv')
            h3 = tf.nn.relu(batch_norm(h3, name='g_bn3'))

            h4 = tf.concat([image_Y, h3], 3)
            h4 = conv2d(h4, 64, k_h=5, k_w=5, d_h=1, d_w=1, name='g_h4_conv')
            h4 = tf.nn.relu(batch_norm(h4, name='g_bn4'))

            h5 = tf.concat([image_Y, h4], 3)
            h5 = conv2d(h5, 32, k_h=5, k_w=5, d_h=1, d_w=1, name='g_h5_conv')
            h5 = tf.nn.relu(batch_norm(h5, name='g_bn5'))

            h6 = tf.concat([image_Y, h5], 3)
            h6 = conv2d(h6, 2, k_h=5, k_w=5, d_h=1, d_w=1, name='g_h6_conv')
            out = tf.nn.tanh(h6)

            print('generator out shape:', out.get_shape())

            return out

测试结果:待更新。

笔记们

略多。列表见好了。

上一篇下一篇

猜你喜欢

热点阅读