NLP学习

Transformer系列:Beam Search束搜索原理图文

2023-11-27  本文已影响0人  xiaogp

关键词:TransformerBeam Search

前言

在前文...中介绍了Transformer在预测阶段逐位进行单词翻译的过程,采用了Greedy Search贪婪搜索这种简单的解码策略,贪婪搜索只关心局部最优,因此通常效果较差,本篇中将介绍另一种更优的解码策略Beam Search集束搜索。


内容摘要


解码目标:序列联合概率最大

在预测阶段,Transformer的输出为词库中在下一个单词位置出现的概率分布,将概率分布转化为最终的翻译文本序列还需要额外的解码策略,比如Greedy Search和Beam Search。

Transformer和解码策略的工作方式

不论那种策略,最终解码的结果应该满足在给定输入文本的条件下,输出的文本在所有候选文本中得分最高,表现为输出文本每个位置上的单词的联合概率最大。
令x为输入的待翻译文本,y1,y2,...,yn为翻译结果每个词位置上的单词,则联合概率公式如下

联合概率

解码的目标是获得联合概率P最大的单词序列。该公式是一个概率累乘,序列越长联合概率越小越接近0,因此为了方便计算避免出错通常转化为对数相加的形式,公式如下

image.png

其中S代表路径的得分,由于概率P在0~1之间,因此S为一个负数,且随着路径长度越长,S越来越小,绝对值越来越大。


Beam Search原理和可视化

Greedy Search问题在于在每一步它只选择得分最高的top 1单词,假设被它忽略的top 2单词带来的后面一系列单词使得整个序列的得分反而更高,则Greedy Search就不会得到最合理的解码结果。
Beam Search集束搜索是Greedy Search的改进版,它拓展了Greedy Search在每一步的搜索空间,每一步保留当前最优的K个候选,一定程度上缓解了Greedy Search的问题,令K为Beam Size代表束宽,Beam Size是一个超参数,它决定搜索空间的大小,越大搜索结果越接近最优,但是搜索的复杂度也越高,当Beam Size等于1的时候,Beam Search退化为Greedy Search。
以一个解码步长为3,词表中候选词数为5,Beam Size为2场景为例,Beam Search的搜索流程如下

Beam Search可视化

以此类推,该步骤一直持续到解码到end或者超过最大解码步长为止,最终会形成2条完整候选序列,取其中的联合概率得分最大值即可得到最佳解码序列。


Beam Search对文本长度的惩罚项

前文提到采用log将概率的累乘改为负数的累加,随着解码文本长度的增加序列的得分也在不断变得越来越负,因此解码结果注定给短文本结果更高的得分,导致一个更合理的翻译结果因为文本较长被一个不合理的短文本结果淘汰。
Beam Search采用基于文本长度的惩罚项来解决这个问题,公式如下

带有文本长度惩罚的路径得分

该公式本质上将原始得分S,除以文本长度,由于原始S是负数,因此文本长度越长相当于在缩小S的绝对值,惩罚之后的得分S*会变大,该公式的目的是使得长文本的得分不要那么得负。
length_penalty是一个自定义参数,默认是1等同于直接除以文本长度,若length_penalty大于1惩罚越大,对长文本的拔高作用越大,解码结果更倾向于长文本,length_penalty小于1且大于0则对长文本的拔高作用越小,越接近于不加任何惩罚。


Beam Search的停止条件

Beam Search单条候选序列停止条件细分有两种情况,分别是

候选序列解码到<end>停止

一条待翻译文本经过Beam Search之后在每一步下会得到Beam Size个候选,前文提到直到解码到<end>或者超过最大解码步长则停止,具体而言是对于单条翻译文本的单个候选序列解码到<end>停止,此时若其他候选序列还没有解码到<end>,则不影响其他候选序列继续寻找。
通过可视化来理解下Beam Search的停止条件,令Beam Search为3,从<start>位置开始搜索过程如下

Beam Search搜索过程1

其中[ ]代表最终的3个候选序列列表,随着解码步长的进行,列表逐渐添加新的解码单词,列表中的数值代表单词的token id,箭头表示血缘关系,有箭头左边的序列增加了结尾元素成为了箭头右边的序列。在每一步解码都只保留了得分最高的3条序列,后一步的解码结果必定继承了前一步至少其中一个,举例第三步都是继承了第二步中的[93, 288],而第二步的[36, 288],[1100, 288]由于在第三步的联合概率计算中没有排进前三,因此被淘汰。
一直持续这个步骤直到形成以下解码结果

Beam Search搜索过程2

其中第n-1步的第二个高分候选[94, ..., 4, <end>]已经率先解码到<end>,至此该候选路径停止后续寻找,将此路径加入解码结果集合。
由于第n-1步已经有一个候选停止,因此第n步只能以剩下2个作为前句继续搜索,第n步top 1和top 2高的序列都是以<end>结尾,同理停止搜索并且加入解码结果集合。

候选序列得分已经低于已解码完的当前最优序列,早停

上面的条件仅限制了解码到<end>停止,如果一直解码不到<end>则直到解码到最长步长位置,比如我们设置最大步长为70则一共需要运行70次Transformer的推理过程。
由于解码长度越大,序列得分越小,如果都等到所有候选都解码到<end>则最终的解码结果集合会很大,很多得分很小的结果是没有必要的,应该有这么一种情况当解码到某个单词的时候已经可以断定不需要在继续以此为基础继续搜索了,Beam Search引入早停机制来实现这个效果。

Beam Search搜索过程3

早停的机制是比较当前得分和已经全部解码完的序列的得分,如果当前得分远远小于最优路径得分,则执行早停,通过给最大得分乘以α倍来控制这个远远小于的程度。
例如在第n步还剩下第3条序列没有解码到<end>,对该序列继续搜索,在第n+1步输出了新的top 3,令α=3,其中:

至于为什么要加上α倍数,答案就是到目前位置的得分还不是最终得分,需要结合上一节所讲的长度惩罚,如果早停不加上α倍数的宽放,可能会把长序列给早停了。
至此讨论的都是某条候选序列停止条件,什么时候整个待翻译文本结束搜索呢?答案是当前一个步下可用的候选为0的时候,该样本的Beam Search结束。如果是一个批次下好多文本输入给Transformer,则所有样本并行构造自身的候选,待所有样本的都已经没有可用的候选时候,整个批次文本的Beam Search解码停止。


Beam Search的时间复杂度分析

令解码步长为T,词表长度为N,束宽为K,则Beam Search的时间复杂度是O(K * N * T),因为在每一个步长T下都需要运算K次推理全部词表长度N,且对N进行排序。
Greedy Search的时间复杂度是O(N * T),每一步只需要推理一次对词表N进行一次排序即可。
由于Beam Search只是相比于Greedy Search更逼近最优解,实际仅有暴力求解Exhausitive Search才能保证结果一定是最优解,Exhausitive Search的时间复杂度是O(N * N * T),由于每一步都需要对前后进行全部词表的笛卡尔积操作,无法像Beam Search那样对左项进行排除。

解码策略 时间复杂度
Greedy Search贪婪搜索 O(N * T)
Beam Search集束搜索 O(K * N * T)
Exhausitive Search暴力搜索 O(N * N * T)

Beam Search源码解读

源码来自TensorFlow Keras实现的批量Beam Search(github:attention-is-all-you-need-keras),其中在模型验证和测试阶段作者使用beam_search函数来进行解码

if 'eval' in sys.argv:
    while True:
        quest = input('> ')
        rets = s2s.beam_search(quest.split(), topk=3, delimiter=' ')
        for x, y in rets:
            print(x, y)

在该函数内部作者先对输入数据进行padding预处理,然后直接输入给decode_batch_beam_search该函数

    def beam_search(self, input_seqs, topk=5, batch_size=8, length_penalty=1, delimiter='', verbose=0):
        src_seq = self.make_src_seq_matrix(input_seqs)
        ...
        decode_batch = lambda x: decode_batch_beam_search(x, topk, encode_model, decode_model,
                                                          start_mark, end_mark, max_len)

该函数实现了批量Beam Search,完整代码配合注释解读如下

def decode_batch_beam_search(src_seq, topk, encode_model, decode_model, start_mark, end_mark, max_len=128,
                             early_stop_mult=5):
    # TODO 真实的该批次下的样本数
    N = src_seq.shape[0]
    # TODO 原始序列id输入 沿着0方向复制5份, 样本数*5
    # TODO 设置了beam_size=5,最终会输出5条候选路径,从一开始就初始化5条然后一齐并行计算
    src_seq = src_seq.repeat(topk, 0)
    # TODO [topk, 该批次下的最大seq_len, embedding_size]
    enc_ret = encode_model(src_seq).numpy()
    # TODO 复制topK之后的batch_size
    bs = src_seq.shape[0]

    # TODO 初始化当前的单词信息,用于预测下一个词
    target_one = np.zeros((bs, 1), dtype='int32')
    # TODO 以<start>的token_id作为初始化
    target_one[:, 0] = start_mark
    d_model = decode_model.inputs[-1].shape[-1]
    n_dlayers = len(decode_model.inputs) - 3
    # TODO 自注意层的K,V, 注意batch_size是复制后的,每一条复制后的序列单独维护一个K,V
    dec_outputs = [np.zeros((bs, 1, d_model)) for _ in range(n_dlayers)]

    # TODO [(批次index, 序列token id, score),()...]
    final_results = []
    # TODO 记录batch_size*topK的所有路径序列
    decoded_indexes = [[] for x in range(bs)]
    # TODO 记录batch_size*topK的所有路径序列的得分的中间结果,约定下一步根据索引拿到上一步的分
    decoded_logps = [0] * bs
    # TODO last_k_size 每条真实样本的最新beam size,最大为topk=5,最小为0,为0的时候停止beam search
    # TODO 在一开始的时候是<start>,没有5个候选,只有一个统一候选<start>,因此lastks=[1,1...]
    lastks = [1 for x in range(N)]
    # TODO best_score 用于记录每条真实样本的最好得分,用于早停
    bests = {}
    for i in range(max_len - 1):
        # TODO transformer一次完整计算
        outputs = [x.numpy() for x in decode_model([target_one, src_seq, enc_ret] + dec_outputs)]
        # TODO new_dec_outputs => [5, 1, embedding_size] 最新的topK*batch_size个K,V向量相同,在<start>位置进入模型之后,所有样本K,V暂时相同
        # TODO output => [5, 1, 3665] 下一个词在所有池子中的概率分布
        new_dec_outputs, output = outputs[:-1], outputs[-1]
        for dec_output, new_out in zip(dec_outputs, new_dec_outputs):
            # TODO 拼接更新K,V
            dec_output[:, -1, :] = new_out[:, 0, :]
        # TODO 开始为K,V的拼接做准备
        dec_outputs = [np.concatenate([x, np.zeros_like(new_out)], axis=1) for x in dec_outputs]

        # TODO 将softmax得分转化为logsoftmax得分,加上log是为了后面可以将概率相加,防止概率相称数越来越小变得极小
        # TODO output [5, 3665] 手动softmax并且将softmax结果转化为log负数形式
        output = np.exp(output[:, 0, :])
        output = np.log(output / np.sum(output, -1, keepdims=True) + 1e-8)

        # TODO 原封不动的复制一下K,V
        next_dec_outputs = [x.copy() for x in dec_outputs]
        # TODO 以UKN开始 初始化每条复制后样本的序列路径
        # TODO 每一步重新初始化为全1向量,这个1没有任何意义,改成其他数也是一样的
        next_decoded_indexes = [1 for x in range(bs)]  # UKN

        # TODO 对transformer的计算结果进行beam search
        for ii in range(N):
            # TODO 复制topK后,每条样本的第一行在整个批次下的索引位置
            base = ii * topk  # 0
            # TODO 每条样本创建一个候选集合
            # TODO transformer每推断一步,重新初始化一个cands
            # TODO cands 记录了transformer每一步,每个样本的候选
            cands = []

            # TODO 由于range(lastks[ii])的范围在range(0) -> range(5)之间, 不存在跨样本情况
            # TODO output[base:, :] => [5, 3665] 每条样本的每个复制的位置
            # TODO 在首次由于初始了5个<start>,相当于初始只有一个,在初始层需要1*5次,cands长度为5
            # TODO 从第二次开始,每次初始最大有5种可能,在每一层需要5*5次,cands最大长度为25
            # TODO lastks控制了已经到end的不会再做beam search操作了
            for k, wprobs in zip(range(lastks[ii]), output[base:, :]):
                # TODO 拿到之前那个状态序列的具体索引位置,基础位置+候选内部的位置
                prev = base + k  # 0
                # TODO 这个用于判断上一步刚刚达到end的序列,这个刚刚end的不会再进入本轮的cands了
                if len(decoded_indexes[prev]) > 0 and decoded_indexes[prev][-1] == end_mark:
                    continue
                # TODO wprobs 是一维 => [3665, ],输出最大的5个值 的索引
                ind = np.argpartition(wprobs, -topk)[-topk:]
                # TODO wprobs[ind] 输出最大的5个logsoftmax值
                # TODO [(1129, -10.39108), (21, -0.00032543472), (5, -9.267304), (14, -10.058955), (533, -10.155)]
                wsorted = [(k, x) for k, x in zip(ind, wprobs[ind])]
                # wsorted = sorted(list(enumerate(wprobs)), key=lambda x:x[-1], reverse=True)   # slow
                # TODO 循环每个候选词,计算每个候选词之后最新的序列得分,同时根据序列得分做早停
                for wid, wp in wsorted[:topk]:
                    # TODO 截止到现在的路径得分, 采用logsoftmax得分相加等效于概率相乘, 每个复制位置初始都是0分
                    # TODO 得到除<start>位置外的第一个单词带来的序列得分
                    # TODO 这个地方只是计算,后续会回写到decoded_logps
                    wprob = decoded_logps[prev] + wp
                    # TODO 如果这个候选单词导致的最新得分,已经小于历史已经预测结束(到<end>)的某个完整预测序列得分,则放弃这个候选单词,不再加入候选
                    # TODO 该索引位置从conds剔除,相当于在此结束了后续对该复制序列的寻找
                    if wprob < bests.get(ii, -1e5) * early_stop_mult:
                        continue
                    # TODO [(0, 1129, -10.39108), (0, 21, -0.00032543472)...]
                    # TODO 经过早停逻辑之后的不超过5个
                    cands.append((prev, wid, wprob))
                # TODO cands => [(0, 1129, -10.391079902648926), (0, 21, -0.0003254347248002887), (0, 5, -9.267304420471191), (0, 14, -10.058955192565918), (0, 533, -10.154999732971191)]
            # TODO [(0, 21, -0.0003254347248002887), (0, 5, -9.267304420471191), (0, 14, -10.058955192565918), (0, 533, -10.154999732971191), (0, 1129, -10.391079902648926)]
            # TODO cands的容量从5开始,一直处于25,最后慢慢归0
            cands.sort(key=lambda x: x[-1], reverse=True)
            # TODO 从最大25个中挑选top5,且从大到小排序
            cands = cands[:topk]
            # TODO 更新最新的beam size
            lastks[ii] = len(cands)  # 5
            # TODO 开始对筛选之后的新top5做结果整理,将最新的得分高的排在最前面,然后按照这个顺序整理出最新的(给下一步预测使用的)Q,K,V
            for kk, zz in enumerate(cands):
                # TODO prev是前词在5个候选中的索引位置,从0到4,首次prev <start>全是0
                prev, wid, wprob = zz
                # TODO 拿到新top5在候选中的索引位置,npos是最新的序列得分排名
                npos = base + kk
                # TODO 将从25个筛选出的top5对齐到beam size最终输出的5个上
                # TODO 2层
                for k in range(len(next_dec_outputs)):
                    # TODO 总序列分从高到低,更新dec_outputs,形成下一次的K,V
                    next_dec_outputs[k][npos, :, :] = dec_outputs[k][prev]
                # TODO 总序列分从高到低,更新target_one,形成下一次的Q
                target_one[npos, 0] = wid
                # TODO 更新截止到现在的路径得分
                decoded_logps[npos] = wprob
                # TODO 不断添加记录下买一个预测的单词
                # print(next_decoded_indexes)
                # TODO 如果prev没有出现,则从该步开始,他的历史序列不会被记录,被next_decoded_indexes的初始化1替代
                next_decoded_indexes[npos] = decoded_indexes[prev].copy()
                next_decoded_indexes[npos].append(wid)
                if wid == end_mark:
                    # TODO (该批次下的位置号, 到目前位置的预测单词序列, 该序列的截止得分)
                    final_results.append((ii, decoded_indexes[prev].copy(), wprob))
                    # TODO 记录下该条样本的最佳得分
                    if ii not in bests or wprob > bests[ii]:
                        bests[ii] = wprob
        # TODO 该批次下所有样本候选数量为0
        if sum(lastks) == 0:
            break
        dec_outputs = next_dec_outputs
        decoded_indexes = next_decoded_indexes

    return final_results

其中在decode_model在做Transformer的推理,输出的dec_output为当下最新自注意信息K,V,output为预测词表的概率分布。作者进一步将得分转化为logsoftmax的形式

output = np.exp(output[:, 0, :])
output = np.log(output / np.sum(output, -1, keepdims=True) + 1e-8)

进入解码阶段,一开始作者先判断上一步的解码是否刚刚好将某候选解码到<end>,如果是则略过,这个略过操作会对后续的剩余有效候选数造成影响

if len(decoded_indexes[prev]) > 0 and decoded_indexes[prev][-1] == end_mark:
    continue

然后取output的topK和已有的候选序列组合,形成最大topK的平方条最新的候选路径,和候选序列的已有得分进行相加得到最新的topK的平方条得分wprob

ind = np.argpartition(wprobs, -topk)[-topk:]
wsorted = [(k, x) for k, x in zip(ind, wprobs[ind])]
# TODO 循环每个候选词,计算每个候选词之后最新的序列得分,同时根据序列得分做早停
for wid, wp in wsorted[:topk]:
    wprob = decoded_logps[prev] + wp

紧接着判断早停,如果远远小于已解码完成的得分则略过,在代码中作者给early_stop_mult设置为5

if wprob < bests.get(ii, -1e5) * early_stop_mult:
    continue

后续作者在候选的topK平方集合里面根据分数从高到底取了topK来完成本步的解码

cands.sort(key=lambda x: x[-1], reverse=True)
cands = cands[:topk]

同时由于有早停条件和判断<end>条件,注定候选cands会越来越小,作者用lastks(last k size)来记录每条样本还有多少个可以继续搜索的候选

lastks[ii] = len(cands)  # 5

早停和<end>导致cands减少,cands决定了lastks,而lastks控制上面的循环解码的次数,实现了候选逐渐退出的效果。
接下来作者开始更新每次排名最新的索引的顺序来更新下一步的Transformer的Q,K,V,然后对于结果进行收集,对刚刚达到<end>的进行收集,写入到最终的结果结合,并且记录最大得分bests用于早停

if wid == end_mark:
    final_results.append((ii, decoded_indexes[prev].copy(), wprob))
    if ii not in bests or wprob > bests[ii]:
        bests[ii] = wprob

待所有样本的可搜索候选都为0的时候,整个批次的Beam Search结束

if sum(lastks) == 0:
    break

最终的final_results的输出为由样本序号,序列token id,序列得分组成的三元组集合。

    # TODO final_results
    #  [(0, [21, 32, 88, 57, 1346, 9, 418, 177, 30, 4], -1.4877963637708262),
    #  (0, [21, 32, 88, 22, 1346, 9, 418, 177, 30, 4], -3.38223916420327),
    #  (0, [21, 32, 88, 11, 27, 1346, 9, 418, 177, 30, 4], -2.746393658571037),
    #  (0, [21, 32, 88, 23, 27, 1346, 9, 418, 177, 30, 4], -3.023495332688185),
    #  (0, [21, 32, 88, 23, 35, 1346, 9, 418, 177, 30, 4], -4.09097491984835)]

在得到候选序列和得分后,作者在外面一层函数beam_search函数实现了长度惩罚

for i, x, y in decode_batch(src_seq[iter:iter + batch_size]):
    rets.setdefault(iter + i, []).append((x, y / np.power(len(x) + 1, length_penalty)))

至此批量Beam Search的实现代码完毕,代码中还有一些细节,读者可以结合完整代码注视仔细跟读了解。

上一篇下一篇

猜你喜欢

热点阅读