CGAN-监督式GAN

2019-05-16  本文已影响0人  baiyang白杨

1.CGAN的简介

为了解决带标签的数据生成问题,研究者们提出了条件生成对抗网络(CGAN)的概念。

CGAN的结构如上图所示,与GAN的主要区别是生成器和判别器的输入数据中都加入类别标签向量(C_vector),生成器的优化目标函数基本上没有变化。

总的来说CGAN在GAN上的改动并不大,但是普通的GAN所生成的内容是随机的,CGAN实现了根据输入标签生成指定类别的内容。

2.CGAN的实现

目前CGAN的实现由多种形式,主要的区别是C_vector的形式,目前主要有以下三种形式:

第一中形式:

将输入Generator的C_vector进行One-hot编码,然后与noise进行拼接,此时C_vector为(batch_size, class_num) ,noise为(batch_size, latent_dim),将拼接之后大小为(batch_size, latent_dim+class_num)作为生成器的输入。

将输入Discrimintor的C_vector首先进行One-hot编码然后通过expand()方法进行维度扩展,此时的C_vector为 (batch_size, class_num, cols, rows) , Real_data 和 Fake_data为(batch_size, channel, cols, rows),最后将转换后的C_vector和Real_data或者Fake_data进行拼接,将拼接之后大小为(batch_size, channel+class_num, cols, rows)的张量作为判别器的输入。

第二种形式:

将输入Generator的C_vector通过Embedding方法进词嵌入,并进行Flatten操作,从而将C_vector转换成为与noise大小相同的张量(batch_size, latent_dim), 然后将noise 和 C_vector 进行mulitiply()操作(即对应位置上的元素相乘,该运输不改变张量的大小),将最终得到的(batch_size, latent_dim)的张量作为生成器的输入。

将输入Discriminator的C_vector通过Embedding方法进行词嵌入,并进行Flatten操作,从而将C_vector转换为(batch_size, channel*rows*cols),接着对Real_data和Fake_data进行Flatten操作,将其转换为(batch_size, channel*rows*cols),然后将转换后的C_vector和Real_data或者Fake_data进行multiply()操作,将最终得到的(batch_size, channel*rows*cols)张量作为判别器的输入。

第三种形式:

将输入Generator的C_vector进行One-hot编码,然后与noise进行拼接,此时C_vector为(batch_size, class_num) ,noise为(batch_size, latent_dim),最后将拼接后大小为(batch_size, latent_dim+class_num)作为生成器的输入。

将输入Discriminator的C_vector进行One-hot编码,然后与经过Flatten()处理之后的Real_data或者Fake_data进行拼接,此时Real_data和Fake_data为(batch_size, channel*rows*cols),C_vector为(batch_size, num_class),最后将拼接之后大小为(batch_size, channel*rows*cols + num_class)的张量作为判别器的输入。

损失函数:

在具体实现上,CGAN的损失函数和GAN基本相同。

上一篇下一篇

猜你喜欢

热点阅读