Conditional Generation by RNN &

2020-11-26  本文已影响0人  没日没夜醉心科研的九天

Outline

        Generation

        Attention

        Tips for Generation

        Pointer Network

Generation

        generate a structured object component by component

        1. examples of generation

        逐次用RNN生成字符/单词

句子的生成

        逐个像素地生成图像

图像的生成 1 图像的生成 2

        2. conditional generation

        不仅限于生成随机的句子,而是根据条件生成相应的sentence。

        如图像标题生成、聊天机器人。

conditional generation

        图像标题生成(Conditional Caption generation)

        将图像通过CNN转换成一个vector,再将vector输入到RNN中生成标题。

Caption generation

        有条件的句子生成(Conditional Sentence Generation)

        如机器翻译/聊天机器人。

        将中文输入到一个RNN(1)中,得到一个包含所有信息的vector,再将vector输入到RNN(2)中。前者是encoder,后者是decoder。RNN(1)和RNN(2)的参数既可以相同也可以不同,具体分析。

Conditional Sentence Generation

Attention

        Dynamic conditional generation

        1.Dynamic conditional generation

        有时候data过大,encoder生成的vector不足以包含所有的信息;或者只有特定的信息对某个输出有用,其他都没有必要,因此需要设计一个dynamic的生成方案。

        2. examples

        Machine Translation

                首先有一个初始化参数Z0(decoder产生的key),然后每个component依次通过一个RNN得到对应的h,然后通过一个match(Z0和hi为输入,a为输出)得到ai,ai即代表着目前的重点关注范围。

                match的具体算法由designer自己决定,可以是small network等。

Machine Translation 1

                对每一个hi做同样的运算,通过softmax就得到对应的概率。将其加权和作为输入,输入到decoder的RNN中,得到对应的transcript,而Z1是隐藏层的输出,进行下一步运算,直至输出结束。

Machine Translation 2 Machine Translation 3

                通过component与key的match运算,最后得出每一个component的attention weight,代表着当前对某些具有较大weight值的component的关注程度。

                每次decoder的input为attention weight与component的加权和(概率分布)。

        Speech Recognition

Speech Recognition

        Image Caption Generation

                与machine translation相似。

Image Caption Generation 1 Image Caption Generation 2 Image Caption Generation 3

        Memory Network

                在memory上做attention。

                首先回顾传统的attention-based RNN模型,与上述介绍的一样。

传统的RNN模型

Tips for generation

      1.Scheduled Sampling

                训练和测试不匹配(mismatch between train and test)的问题可以用scheduled sampling来解决。

                Training:在训练时,预测的结果总是会与reference(label,ground truth)来计算损失,并且下一个component的input会是上一个component对应的reference

training

                Generation(test):在生成时,下一个component的input只能是上一个component对应的output;此时没有reference来参考。

genertaion

                这样就会导致一个mismatch的问题。因为generation没有reference,如果一个output出现了错误的预测,那么就会将接下来的结果带入到没有经过训练的方向上,会导致很多问题,如下图所示。

mismatch

                如果我们考虑改变train的方法呢?

                 为了保持train和test一致,我们应当保持前一个component的output始终是后一个component的input,即使它与reference不一致。这种训练方式看起来很合理,但是很难train,注意:第一个component训练的目标是输出A、第二个要输出B,因此第一个component的预测输出最终会变成A,此时第二个component的输入也会变成A,但是他已经按照B训练好了。因此到最后网络的训练结果不一定好,也有可能更差。

modifying

                因此可以采用scheduled sampling来改善这个问题。

                Scheduled sampling其实就是以一定的概率函数选择下一个input的来源是上一个output还是reference。如图,三种decay衰减函数。实验证明,采用scheduled sampling确实在效果上有一定的改善。

scheduled sampling

      2.beam search

                绿色的路分数最高,但是我们并不能提前知道最终的结果是怎样。因此我们可以设置beam size的大小,最次都挑选最优的路径。

beam search 1

                        beam search的思想就是每一步都挑选最好的路径。在每一次只有两种结果的前提下,设置beam size=2,也就是说每次考虑2个component,从4条路径中挑选最好的路径作为最终训练结果。

beam search 2

      3.better idea?

                前一个output应当是把distribution还是选择的结果(如:非黑即白)作为input送给下一个component呢?显然把选择的结果送给下一个component比较好。如下图,我们想要最终输出“高兴想笑”或“难过想哭”,但是如果传送的是distrubution,可能会得到“高兴想哭”等不好的结果。

better idea?

      4.Object level

                如下图所示,采用component的loss训练到中间结果“The dog is is fast”后就train不动了,结果改善得不明显。但是如果换成object level的loss,则还能继续train下去直到得到目标结果。

Object level

      5.Reinforcement learning

Reinforcement learning

Pointer network

        Pointer Network可以应用在给一堆点找边界上:

 Pointer Network 1

        相对于上述的attention-based网络,pointer network在计算出每一个component的attention weight后,不是计算概率和,而是直接将对应的component输出。如下图,(x4,y4)对应的attention weight是0.7最高,因此直接输出(x4,y4)。

 Pointer Network 2

        其他方面的应用,如machine translation和chat-bot,pointer network可以直接输出相应的word。

application
上一篇下一篇

猜你喜欢

热点阅读