SeqGAN学习笔记(三)
上回书说到在GAN中判别器只给生成器生成的句子一个真假的判断,不能像MLE一样每一个词都计算loss,在离散序列生成领域判别器给回的信息是不足够的,SeqGAN给出的解决方案是引入强化学习中的reward概念,给判别器传回的信息“乘以”reward倍,通过这种方式“放大”判别器传回的信息。
在这种预设下,神经网络生成器可以等价于强化学习中的策略(policy)网络。强化学习中策略网络的更新依据于reward,reward大,策略网络就倾向于在产生这个大reward的action的梯度方向上多更新几次。这样就解决了判别器传回信息太微弱的问题。
下面是对SeqGAN代码学习的记录:
程序进入sequence_gan.py中的main()函数,首先实例化生成器(generator),判别器(discriminator)和Oracle(target_lstm):
![](https://img.haomeiwen.com/i4858318/d189902ef40cf923.png)
1、预训练
接下来是训练GAN网络的trick,从Oracle中sample一些数据使用MLE预训练生成器:
![](https://img.haomeiwen.com/i4858318/38fe1f40d869c9a3.png)
(在这里有一个疑问,SeqGAN中的真实数据由Oracle合成,,然而在代码里也没看见预训练Oracle,如果有,那么应该是我没有理解下图公式的意义):
![](https://img.haomeiwen.com/i4858318/22cbe27b4ec60c14.png)
pre_train_epoch完成生成器的预训练:
![](https://img.haomeiwen.com/i4858318/809cead2a99fcd72.png)
下面这几行代码是对上面公式(13)的计算:
![](https://img.haomeiwen.com/i4858318/686f6d8a2a22a484.png)
2、生成器判别器开始对抗
在discriminator也预训练完成并且实例化rollout(对rollout网络的理解:rollout实现了论文中的蒙特卡洛搜索,rollout与generator本质上一样,在论文中因为一些原因设计了generator与rollout两个网络,rollout的参数和generator一样但有些延迟更新)网络后,分别对生成器和判别器的参数进行更新。
-
生成器生成samples
image.png
-
使用rollout生成每句sample中每个词的reward,注意每个词都有reward。reward的的tensor shape是:
image.png
具体地步骤是:
![](https://img.haomeiwen.com/i4858318/d097950056625828.png)
论文中对应的公式是:
![](https://img.haomeiwen.com/i4858318/8c37e224a2d52d21.png)
rollout_num对应的是公式中的,也就是说蒙特卡洛search了N个句子。given_num代表使用sample中token的长度,假设现在given_num是5,也就是说我们现在要计算第五个生成词的reward。当小于5时,生成器网络直接使用已产生的token:
![](https://img.haomeiwen.com/i4858318/3ab4f6d12cf15e99.png)
当大于5时,生成器网络生成剩下的部分句子:
![](https://img.haomeiwen.com/i4858318/ec142c8781726c12.png)
在计算完reward后,我们返回主函数,将samples和其对应的rewards feed进generator,更新生成器的参数:
![](https://img.haomeiwen.com/i4858318/e20a4a22f4ea3935.png)
可以在Generator class中看到,GAN网络更新生成器的方法是给每个词的loss乘上reward。
![](https://img.haomeiwen.com/i4858318/7a5d1096e7f0f5fe.png)
完。