SeqGAN学习笔记(三)

2020-02-10  本文已影响0人  p_w

上回书说到在GAN中判别器只给生成器生成的句子一个真假的判断,不能像MLE一样每一个词都计算loss,在离散序列生成领域判别器给回的信息是不足够的,SeqGAN给出的解决方案是引入强化学习中的reward概念,给判别器传回的信息“乘以”reward倍,通过这种方式“放大”判别器传回的信息。
在这种预设下,神经网络生成器可以等价于强化学习中的策略(policy)网络。强化学习中策略网络的更新依据于reward,reward大,策略网络就倾向于在产生这个大reward的action的梯度方向上多更新几次。这样就解决了判别器传回信息太微弱的问题。
下面是对SeqGAN代码学习的记录:
程序进入sequence_gan.py中的main()函数,首先实例化生成器(generator),判别器(discriminator)和Oracle(target_lstm):


image.png
1、预训练

接下来是训练GAN网络的trick,从Oracle中sample一些数据使用MLE预训练生成器:


image.png

(在这里有一个疑问,SeqGAN中的真实数据由Oracle合成,,然而在代码里也没看见预训练Oracle,如果有,那么应该是我没有理解下图公式的意义):


image.png
pre_train_epoch完成生成器的预训练:
image.png

下面这几行代码是对上面公式(13)的计算:


image.png
2、生成器判别器开始对抗

在discriminator也预训练完成并且实例化rollout(对rollout网络的理解:rollout实现了论文中的蒙特卡洛搜索,rollout与generator本质上一样,在论文中因为一些原因设计了generator与rollout两个网络,rollout的参数和generator一样但有些延迟更新)网络后,分别对生成器和判别器的参数进行更新。

具体地步骤是:

image.png
论文中对应的公式是:
image.png
rollout_num对应的是公式中的,也就是说蒙特卡洛search了N个句子。given_num代表使用sample中token的长度,假设现在given_num是5,也就是说我们现在要计算第五个生成词的reward。当小于5时,生成器网络直接使用已产生的token:
image.png
当大于5时,生成器网络生成剩下的部分句子:
image.png
在计算完reward后,我们返回主函数,将samples和其对应的rewards feed进generator,更新生成器的参数:
image.png
可以在Generator class中看到,GAN网络更新生成器的方法是给每个词的loss乘上reward。
image.png
完。
上一篇 下一篇

猜你喜欢

热点阅读