生成对抗网络解读

2021-02-09  本文已影响0人  小生很忙

摘要:生成对抗网络( Generative Adversarial Networks, GAN)是通过对抗训练的方式来使得生成网络产生的样本服从真实数据分布。在生成对抗网络中,有两个网络进行对抗训练。一个是判别器,目标是尽量准确地判断一个样本是来自于真实数据还是由生成器产生;另一个是生成器,目标是尽量生成判别网络无法区分来源的样本 。两者交替训练,当判别器无法判断一个样本是真实数据还是生成数据时,生成器即达到收敛状态。以上是对生成对抗网络的简单描述,本文将对生成对抗网络的内在原理以及相应的优化机制进行介绍。

文章概览

概率生成模型

  概率生成模型,简称生成模型,是指一系列用于随机生成可观测数据的模型。假设在一个连续或离散的空间\chi中,存在一个随机向量X服从一个未知的数据分布P_{data}(x)x \in \chi。生成模型是根据一些可观测的样本x^1, x^2, ..., x^m来学习m一个参数化模型P_G(\theta;x)来近似未知分布,并可以用这P_{data}(x)个模型来生成一些样本,使得生成的样本和真实的样本尽可能的相似。对于一个低维空间中的简单分布而言,我们可以采用最大似然估计的方法来对p_\theta(x)进行求解。假设我们要统计全国人民的年均收入的分布情况,如果我们对每个P_{data}(x)样本都进行统计,这将消耗大量的人力物力。为了得到近似准确的收入分布情况,我们可以先假设其服从高斯分布,我们比如选取某个P_G(x;\theta)城市的人口收入x^1, x^2, ..., x^m,作为我们的观察样本结果,然后通过最大似然估计来计算上述假设中的高斯分布的参数。
L=\prod_{i=1}^{m} P_{G}\left(x^{i} ; \theta\right)

{\theta}^*=\arg \max _{\theta} \sum_{i=1}^{m} \log P_{G}\left(x^{i} ; \theta\right)

由于P_G(x;\theta)服从高斯分布,我们将其带入即可求得最终的近似的分布情况。下面我们对上述过程进行一些拓展,我们从P_{data}(x)尽可能采样更多的数据,此时可以得到
{\theta}^*=\arg \max _{\theta} \sum_{i=1}^{m} \log P_{G}\left(x^{i} ; \theta\right)\approx \arg \max _{\theta} E_{x \sim P_{\text {data }}}\left[\log P_{G}(x ; \theta)\right]
对该式进行一些变换,可以得到
{\theta}^*=\arg \min _{\theta} K L\left(P_{\text {data }} \| P_{G}\right)
  由此可以看出,最大似然估计的过程其实就是最小化P_{data}(x)分布和P_G分布之间KL散度的过程。从本质上讲,所有的生成模型的问题都可以转换成最小化P_{data}(x)分布和P_G分布之间距离的问题,KL散度只是其中一种度量方式。

  如上所述,对于低维空间的简单分布而言,我们可以显式的假设样本服从某种类型的分布,然后通过极大似然估计来进行求解。但是对于高维空间的复杂分布而言,我们无法假设样本的分布类型,因此无法采用极大似然估计来进行求解,生成对抗网络即属于这样一类生成模型。

生成对抗网络

生成对抗网络的理论解释

  在生成对抗网络中,我们假设低维空间中样本z服从标准类型分布,利用神经网络可以构造一个映射函数G(即生成器)将z映射到真实样本空间。我们希望映射函数G能够使得P_G(x)分布尽可能接近P_{data}(x)分布,即P_GP_{data}之间的距离越小越好:
G^{*}=\arg \min _{G} {\operatorname{Div}}\left(P_{G}, P_{\text {data }}\right)
由于P_GP_{data}的分布都是未知的,所以无法直接求解P_GP_{data}之间的距离。生成对抗网络借助判别器来解决这一问题。首先我们分别从P_GP_{data}中取样,利用取出的样本训练一个判别器:我们希望当输入样本为P_{data}时,判别器会给出一个较高的分数;当输入样本为P_G时,判别器会给出一个较低的分数。例如,我们可以将判别器的目标函数定义成以下形式(与二分类的目标函数一致,即交叉熵):
V(G, D)=E_{x \sim P_{\text {data }}}[\log D(x)]+E_{x \sim P_{G}}[\log (1-D(x))]
我们希望得到这样一个判别器(G固定):
D^{*}=\arg \max _{D} V(D, G)
从本质上来看,\max _{D} V(D, G)即表示P_GP_{data}之间的JS散度(具体推导参见李宏毅老师的课程),即:
\max _{D} V(G, D)=V\left(G, D^{*}\right)=-2 \log 2+2 J S D\left(P_{\text {data }} \| P_{G}\right)

D^{*}(x)=\frac{P_{\text {data }}(x)}{P_{\text {data }}(x)+P_{G}(x)}

因此通过构建判别器可以度量P_GP_{data}之间的距离,所以G^*可以表示为:
G^{*}=\arg \min _{G} \max _{D} V(G, D)

生成对抗网络的求解过程

G^*的求解过程大致如下:

gan

对上述算法过程进行几点说明:

生成对抗网络的优化

fGAN

  通过上面的分析我们可以知道,构建生成模型需要解决的关键问题是最小化P_GP_{data}之间的距离,这就涉及到如何对P_GP_{data}之间的距离进行度量。在上述GAN的分析中,我们通过构建一个判别器来对P_GP_{data}之间的距离进行度量,其中采用的目标函数为:
V(G, D)=E_{x \sim P_{\text {data }}}[\log D(x)]+E_{x \sim P_{G}}[\log (1-D(x))]\
通过证明可知,V(G, D)其实度量的是P_GP_{data}之间的JS散度。如果我们希望采用其他方式来衡量两个分布之间的距离,则需要对判别器的目标函数进行修改。根据论文fGAN,可以将判别器的目标函数定义成如下形式:
D_{f^*}\left(P_{\text {data }} \| P_{G}\right)=\max _{\mathrm{D}}\left\{E_{x \sim P_{\text {data }}}[D(x)]-E_{x \sim P_{G}}\left[f^{*}(D(x))\right]\right\}
G^*可以表示为:
G^{*}=\arg \min _{G} D_{f^*}\left(P_{\text {data }} \| P_{G}\right)
f^*取不同表达式时,即表示不同的距离度量方式。

fgan

f^*(t)=-log(1-exp(t))D(x)log,代入D_{f^*}\left(P_{\text {data }} \| P_{G}\right)即可得到V(G,D)

WGAN

  自2014年Goodfellow提出以来,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。针对这些问题,Martin Arjovsky进行了严密的理论分析,并提出了解决方案,即WGAN(WGAN的详细解读可参考这篇博客)。

wgan

  由以上算法可以看出,WGAN与原始的GAN在算法实现方面只有四处不同:(1)判别器最后一层去掉sigmoid;(2)生成器和判别器的loss不取log;(3)每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c;(4)不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行。

生成对抗网络的实现

  本文实现了几种常见的生成对抗网络模型,包括原始GAN、CGAN、WGAN、DCGAN。开发环境为jupyter lab,所使用的深度学习框架为pytorch,并结合tensorboard动态观测生成器的训练效果,具体代码请参考我的github。

GAN

real_label = torch.ones(batch_size, 1)
fake_label = torch.zeros(batch_size, 1)

# 训练判别器
d_real = D(real_img)
d_real_loss = criterion(d_real, real_label)

z = torch.normal(0, 1, (batch_size, latent))
fake_img = G(z)
d_fake = D(fake_img)
d_fake_loss = criterion(d_fake, fake_label)

optimizer_D.zero_grad()
d_loss = d_real_loss + d_fake_loss
d_loss.backward()
optimizer_D.step()

# 训练生成器
fake_img = G(z)
d_fake = D(fake_img)
g_loss = criterion(d_fake, real_label)

optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()

CGAN

real_label = torch.ones(batch_size, 1)
fake_label = torch.zeros(batch_size, 1)

z = torch.normal(0, 1, (batch_size, latent))

# 训练判别器
d_real = D(real_img, label)
d_real_loss = criterion(d_real, real_label)

fake_img = G(z, label)
d_fake = D(fake_img, label)
d_fake_loss = criterion(d_fake, fake_label)

optimizer_D.zero_grad()
d_loss = (d_real_loss + d_fake_loss)
d_loss.backward()
optimizer_D.step()

# 训练生成器
fake_img = G(z, label)
d_fake = D(fake_img, label)
g_loss = criterion(d_fake, real_label)

optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()

WGAN

# 训练判别器
d_real = D(real_img)
#d_real_loss = criterion(d_real, real_label)
d_real_loss = d_real

z = torch.normal(0, 1, (batch_size, latent))
fake_img = G(z)
d_fake = D(fake_img)
#d_fake_loss = criterion(d_fake, fake_label)
d_fake_loss = d_fake

optimizer_D.zero_grad()
#d_loss = d_real_loss + d_fake_loss
d_loss = torch.mean(d_fake_loss) - torch.mean(d_real_loss)
d_loss.backward()
optimizer_D.step()

for p in D.parameters():
    p.data.clamp_(-clip_value, clip_value)
# 训练生成器
fake_img = G(z)
d_fake = D(fake_img)
#g_loss = criterion(d_fake, real_label)
g_loss = - torch.mean(d_fake)

optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
上一篇 下一篇

猜你喜欢

热点阅读