Pytorch starGAN
源代码:https://github.com/yunjey/StarGAN
StarGAN有两个主要的模块,一个discriminator记为D,一个generator记为G.
-
(a) D 学习区分真实图像与伪图像,将真实图像归类于正确的域
-
(b) G 输入图像和目标域标签,生成伪图像。目标域标签重复多次后与图像串联在一起。
-
(c) G 使用伪图像和原始的域标签重构原始图像.
-
(d) G 产生以假乱真的图像,并让D判断为目标域图像
训练多个数据库
-
(e) D 学习区分真实图像和伪图像,并且在已知类别的图像上最小化分类损失
-
(g) 假设掩模向量是[0, 1](紫色),G学习CelebA的标签(黄色)而非RaFD的标签(绿色)[1,0]的话正好反过来。
n表示数据库个数,m表示掩码,这里n==2 -
(h) G生成图像,让D分不出是真实图像还是伪图像,但能分出目标域图像的标签。
多个数据库
以下来自新智元
然而,现有的模型在多域图像转换任务中效率低下。这些模型的低效率是因为在学习K域的时候,需要训练K(K−1)个生成器。图2说明了如何在四个不同的域之间转换图像的时候,训练十二个不同的生成器的网络。即使它们可以从所有域图像学习全局特征,如形状特征学习,这种模型也是无效的,因为每个生成器不能充分利用整个训练数据,只能从K学习的两个领域。未能充分利用训练数据很可能会限制生成图像的质量。此外,它们不能联合训练来自不同域的数据集,因为每个数据集只有部分标记。
我们的StarGAN模型与其他跨域模型的比较。(a)为处理多个域,应该在每一对域都建立跨域模型。(b)StarGAN用单个发生器学习多域之间的映射。该图表示连接多个域的拓扑图。
为解决这些问题我们提出了StarGAN,它是生成对抗网络,能够学习多个域之间的映射。如图2(b)所示,文章中提出的模型接受多个域的训练数据,并且只使用一个生成器学习所有可用域之间的映射。这个想法是非常简单的。其模型不是学习固定的图像转化(例如,从黑发到金发),而是输入图像和域信息,学习如何灵活地将输入图像转换到相应的域中。文章中使用一个标签(二进制或one hot向量)代表域信息。在训练过程中,随机生成目标域标签并训练模型,以便灵活地将输入图像转换到目标域。通过这样做,可以控制域标签并在测试阶段将图像转换成任何所需的域。
本文还引入了一种简单而有效的方法,通过将掩码向量添加到域标签,使不同数据集的域之间进行联合训练。文章中所提出的方法使模型可以忽略未知的标签,并专注于有标签的特定数据集。在这种方式下,此模型对任务能获得良好的效果,如利用从RaFD数据集学到的特征来在CelebA图像中合成表情,如图1的最右边的列。据本文中提及,这篇工作是第一个成功地完成跨不同数据集的多域图像转化。
总的来说,本文的贡献如下:
-
提出了StarGAN,生成一个新的对抗网络,只使用一个单一的发生器和辨别器实现多个域之间的映射,有效地从所有域的图像进行训练;
-
展示了如何在多个数据集之间学习多域图像转化,并利用掩码向量的方法使StarGAN控制所有可用的域标签。
-
提供定性和定量的结果,对面部表情合成任务和面部属性传递任务使用StarGAN,相比baseline模型显示出它的优越性。