BERT MLM LOSS2024-05-30
2024-05-29 本文已影响0人
9_SooHyun
BERT(Bidirectional Encoder Representations from Transformers)的MLM(Masked Language Model)损失是这样设计的:在训练过程中,BERT随机地将输入文本中的一些单词替换为一个特殊的[MASK]标记,然后模型的任务是预测这些被掩盖的单词。具体来说,它会预测整个词汇表中每个单词作为掩盖位置的概率。
MLM损失的计算方式是使用交叉熵损失函数。对于每个被掩盖的单词,模型会输出一个概率分布,表示每个可能的单词是正确单词的概率。交叉熵损失函数会计算模型输出的概率分布与真实单词的分布(实际上是一个one-hot编码,其中正确单词的位置是1,其余位置是0)之间的差异。
具体来说,如果你有一个词汇表大小为V,对于一个被掩盖的单词,模型会输出一个V维的向量,表示词汇表中每个单词的概率。如果y是一个one-hot编码的真实分布,而p是模型预测的分布,则交叉熵损失可以表示为(用于衡量模型预测概率分布与真实标签概率分布之间的差异):
其中:
-
表示损失函数的值
-
表示类别的数量
-
是第
个类别的真实标签,通常为0或1
-
是模型预测第
个类别的概率
-
表示自然对数
-
表示对所有类别求和
在这个公式中,是真实分布中的第i个元素,而
是模型预测的分布中的第i个元素。由于y是one-hot编码的,所以除了正确单词对应的位置为1,其余位置都是0,这意味着上面的求和实际上只在正确单词的位置计算。
在实际操作中,为了提高效率,通常不会对整个词汇表进行预测,而是使用采样技术,如负采样(negative sampling)或者层次softmax(hierarchical softmax),来减少每个训练步骤中需要计算的输出数量。