arXiv'20-(检索增强语言模型)REALM: Retrie
2023-03-26 本文已影响0人
Caucher
标题:REALM:检索增强语言模型预训练
编者的总结
- 作者在语言模型中嵌入了一个知识库文档的检索部分,即输入将首先找到最相关的一批文档,然后共同进入encoder预测语言模型以提升精度。
- 由于训练知识库文档的embedding成本很大,作者选择每次只选择其中的top-K个最相关的文档来进行训练/推理,top-K使用MIPS索引来完成。
- 由于MIPS索引需要提前构建,但是embedding在训练过程中会随时间改变,因此作者选择定期重建索引以解决这个问题。
1. Abstract & Introduction
- 语言模型的隐式存储在网络参数中,要求模型足够大才能覆盖到更多的知识。
- 为了能让知识获取更加模块化和可解释,本文提出检索增强的语言模型预训练REALM。
- REALM从一个大的语料库中(比如维基百科)检索以辅助预训练、精调和推理。并且把检索这一步的影响放在一个独立的模块当中以选择什么样的知识会用来做推理。
- 这其中的一个挑战是:retriver要考虑的知识库太大了,训练负担很大。
2. Background
2.1 Language model pre-training
- 语言模型一般会从一个无标签的语料库中做预训练,先学习知识;然后针对下游特定任务做精调。这样做的适用性通常比端到端的从头训练要好。
3. Approach
3.1. REALM’s generative process
REALM以一些句子为输入,输出是一个分布,即各种可能的预测及其概率。
- 预训练阶段从语料库的句子中挖出去几个位置,交给模型去预测。
- 精调阶段则直接用QA的问答来训练模型。
- 预测任务可以下式表达。p(z|x)表示在一个输入x的基础上,检索到相关文档z的概率;p(y|x,z)则表示已知输入x和相关文档z,推理出y的概率;对所有的z求和则表示:通过所有相关文档的帮助,推理出y的总概率。
-
前一项由retriever来完成,后一项由encoder来完成。
image.png
3.2. Model architecture
3.2.1 Knowledge Retriever
- 如下式,输入和知识库文档都会被Transformer转换成同等维数的Embedding。
- 相关性用向量内积来表示,内积越大,相关性越高。
- 最终的相关性概率分布则是一个简单的softmax。
-
注意Transformer中的参数仍然需要训练以保证embedding的质量。
image.png
3.2.2 Knowledge-Augmented Encoder
这一部分和BERT类似,用的是MLM loss,表示为token的embedding,和通过以输入和相关文档为输入的Encoder的输出embedding向量做内积。
3.3. Training
- 上述框架的一个问题是每次推理都需要去算p(z|x),这要求输入要和所有的知识库文档算一次向量内积距离,这个时间代价是不可接受的。
- 一个替代的方法是用MIPS索引,只找出top-K个相关文档(考虑到绝大部分文档的相关性几乎为0)用以训练。
- 然而MIPS索引要求文档的embedding向量提前做好并建立索引,但是训练过程中embedding会改变,那已经建立好的索引就没用了。
- 作者提供的一个替代方案就是周期性的去更新索引,如下图,每几百步就重新embed一次知识库文档然后建立索引,因为考虑到每一步带来的embedding更新是很少的,所以对top-k影响可能不大。
3.3.1 What does the retriever learn?
目前的训练思路是端到端的训练,包括一个retriever和一个encoder,但是这种训练目标是否意味着retriever可以找到相关性强的文档呢?作者给出了分析。
- 通过下面的梯度分析可知,最终目标梯度提升,只会发生在提供的文档z发挥了正面作用的情况下。