Transformer系列:注意力机制的优化,MQA和GQA原理
关键词:Transformer
,注意力机制
,MQA
,GQA
前言
多查询注意力(MQA)、分组查询注意力(GQA)是Transformer中多头注意力(MHA)的变种,它们大幅提高了解码器的推理效率,在LLaMA-2,ChatGLM2等大模型中有广泛使用,本篇介绍MQA、GQA的原理并分析其源码实现。
内容摘要
- 使用MQA,GQA的背景介绍
- MQA,GQA原理简述
- MQA,GQA推理加速分析
- ChatGLM2-6B中的MQA/GQA源码分析
使用MQA,GQA的背景介绍
多查询注意力(Multi Query Attention,MQA)提出于2019年的论文《Fast Transformer Decoding: One Write-Head is All
You Need》,旨在解决Transformer增量推理阶段效率低下的问题,在当时并没有引起关注,而随着近几年Transformer和GPT成为生成式大模型的基座,面临着产业落地的实际情况,导致GPT的推理加速备受关注,因此MQA又重新被提及起来。
分组查询注意力(Group Query Attention,GQA)提出于2023年,是MQA更一般的形式,它介于MQA和MHA之间,是模型预测表现和模型推理性能之间的一个折衷。
MQA,GQA原理简述
MQA的原理很简单,它将原生Transformer每一层多头注意力的Key线性映射矩阵、Value线性映射矩阵改为该层下所有头共享,也就是说K、V矩阵每层只有一个,而Q矩阵不受影响,其数量和注意力头数相等。以ChatGLM2-6B为例,一共28层,32个注意力头,输入从4096经过Q、K、V矩阵映射维度为128,若采用原生多头注意力机制,则Q、K、V矩阵各有28×32个,而采用MQA的方式则整个模型包含28×32个Q矩阵,28个K矩阵,28个V矩阵,示意图如下
MHA和MQA的差别
可想而知,MQA这种方式大幅减小了参数数量,带来推理加速的同时会造成模型性能损失,且在训练过程使得模型变得不稳定,因此在此基础上提出了GQA,它将Query进行分组,每个组内共享一组Key、Value,令组的数量为N,若N等于1此时等效于MQA,若N等于Query头的数量,此时退化为MHA。GQA是推理效率和模型性能的trade-off。
GQA(中)和MQA(右)对比MQA,GQA推理加速分析
MQA能够大幅加速采用MHA的Transformer的推理,但是会有明显的性能损失,而GQA通过设置合适的分组大小,可以和MQA的推理性能几乎相等,同时逼近MHA的模型性能。作者在GQA的论文中给到了实验结论来印证这一点。
作者采用T5模型作为研究对象,模型版本采用T5-Large和T5-XXL,它们都采用MHA注意力方式,其中Large参数量770M,XXL参数量11B,在此基础上作者通过up-training方法将T5-XXL改造为MQA和GQA,最终一共四个版本模型进行精度和推理效率的对比,结果如下
横轴代表平均每条样本的推理耗时,越大代表延迟越大,纵轴代表在众多数据集上的评价得分,越大代表得分越高。在MHA方式下,由于XXL的模型参数更大,因此MHA-XXL的推理延迟高于MHA-Large,同时MHA-XXL的模型评分在所有版本里面最高。
经过up-training方法将MHA改造为MQA之后,MQA-XXL获得了所有版本的最低延迟,甚至还低于小一个型号的Large模型,同时MQA使得其模型评分比MHA-XXL降低了,但还是超越了小一个信号的Large的模型,表明大参数量的MQA模型不论在精度还是效率上都超越了小参数量的MHA模型。
而经过up-training方法将MHA改造为GQA,GQA-XXL的推理延迟几乎和MQA-XXL相等,而其性能评分也和参数量最大的MHA-XXL十分接近,整体上GQA达到了最佳效果,推理性能和模型评分都十分优秀。
GQA的分组数是一个超参数,组数越大越接近MHA,推理延迟越大,同时模型精度也越高,作者给出了他的实验结论表明,当组数量从1逐渐上升到8时,模型推理的开销并没有明显的增长,在8以后推理开销显著变大,最终作者采用8个分组作为他的最佳选择。
以上从实验结果层给到结论,MQA略微损失了模型精度,但是确实能够大幅降低推理开销,而如果选择了合适的分组数,GQA能够两者皆得。在理论层,MQA和GQA对推理的帮助主要是以下两点
- 降低内存读取模型权重的时间开销:由于Key矩阵和Value矩阵数量变少了,因此权重参数量也减少了,需要读取到内存的数量量少了,因此减少了读取权重的等待时间
- KV-Cache空间占用降低:KV-Cache会将之前推理过的Key、Value向量存储在内存中,而随着步长和batch_size的增长,KV-Cache空间占用越来越高,使得KV-Cache不能被高效的读写,而MHA和GQA方式使得KV-Cache需要存储的参数量降低了head_num倍,从而提高KV-Cache的读写效率;另一方面,可以有空间来增大batch_size,从而提高模型推理的吞吐量
注意MQA和GQA并没有降低Attention的计算量(FLOPs),因为Key、Value映射矩阵会以广播变量的形式拓展到和MHA和一样,因此计算量不变,只是Key、Value参数共享。
ChatGLM2-6B中的MQA/GQA源码分析
本节采用ChatGLM2-6B的模型源码modeling_chatglm.py来说明MQA和GQA的实现,这两者在代码上没有区别,因为MQA是GQA的特例,当分组数等于1时就是MQA,而chatglm2-6B采用的是分组数为2的GQA,从它的配置文件config.json可以观察得到
{
"_name_or_path": "THUDM/chatglm2-6b",
"model_type": "chatglm",
"architectures": [
"ChatGLMModel"
],
"auto_map": {
"AutoConfig": "configuration_chatglm.ChatGLMConfig",
"AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
"AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration"
},
...
"multi_query_attention": true,
"multi_query_group_num": 2,
....
}
其中multi_query_attention代表是否开启多查询注意力,multi_query_group_num代表分组数。
MQA和GQA仅涉及到注意力层,因此直接定位到SelfAttention的代码块
class SelfAttention(torch.nn.Module):
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
super(SelfAttention, self).__init__()
self.layer_number = max(1, layer_number)
# TODO 128 * 32
self.projection_size = config.kv_channels * config.num_attention_heads
# Per attention head and per partition values.
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
self.num_attention_heads_per_partition = config.num_attention_heads
# TODO true 多查询注意力
self.multi_query_attention = config.multi_query_attention
# TODO qkv线性映射层到3*d
self.qkv_hidden_size = 3 * self.projection_size
if self.multi_query_attention:
self.num_multi_query_groups_per_partition = config.multi_query_group_num # 2
self.qkv_hidden_size = (
# TODO (128 * 32 + 2 * 128 * 2)
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
)
self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
bias=config.add_bias_linear or config.add_qkv_bias,
device=device, **_config_to_kwargs(config)
)
该SelfAttention代表某一层下,32个注意力头的运算。在初始化阶段,作者采用大矩阵方案用一个矩阵将所有头QKV创建出来,因此projection_size为128×3,然后以MHA的方式将映射维度乘以3得到qkv_hidden_size,当采用多查询注意力时对qkv_hidden_size重新修改,它等于所有头的Q矩阵,加上KV矩阵各一个,因此projection_size为128 * 32 + 2 * 128 * 2=4608,最后通过一个Linear层实现QKV矩阵的创建。
在推理阶段使用query_key_value进行同意QKV映射,然后通过split算子将QKV进行分解,分解之后所有头的Query为4096维,Key和Value为256维,因为分组数为2,所有存在两个Key和两个Value。
def forward(
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
):
# TODO [1(seq), 1(batch), 4608]
mixed_x_layer = self.query_key_value(hidden_states)
if self.multi_query_attention:
# TODO [17, 1, 4096], [17, 1, 256], [17, 1, 256]
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, # TODO 32 *128
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, # TODO 2 * 128
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, # TODO 2 * 128
],
dim=-1,
)
然后作者将QKV向量的维度进行reshape,将头维度拿出来,准备在下面的代码中对KV进行广播
# TODO [17, 1, 32, 128]
query_layer = query_layer.view(
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
)
# TODO [17, 1, 2, 128]
key_layer = key_layer.view(
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
)
# TODO TODO [17, 1, 2, 128]
value_layer = value_layer.view(
value_layer.size()[:-1]
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
)
在这之前,需要将原始的Query和Key携带旋转位置编码RoPE,使得注意力能够感知到相对位置信息,此步骤和Value无关
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
然后作者对Key和Value做广播,因为Query有32个头,而GQA有2组,因此要翻32/2=16倍数,通过torch的expand进行广播实现参数共享,广播之后QKV三者的shape变成一致
if self.multi_query_attention:
# TODO TODO [17, 1, 2, 128] => [17, 1, 2, 1, 128]
key_layer = key_layer.unsqueeze(-2)
# TODO [17, 1, 2, 16, 128]
key_layer = key_layer.expand(
# TODO expand 进行广播,k,v向量共享
# TODO 只能对维度值是1的进行拓展,如果某些维不需要拓展,写为-1, 32 // 2=16
# TODO 有32个头,KV组只有2组,要复制16份
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
)
# TODO [17, 1, 2, 16, 128] => [17, 1, 32, 128]
key_layer = key_layer.contiguous().view(
key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
)
value_layer = value_layer.unsqueeze(-2)
value_layer = value_layer.expand(
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
)
value_layer = value_layer.contiguous().view(
value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
)
最后所有处理好之后计算注意力权重,并且将权重和Value相乘得到注意力的输出,这里和传统的MHA没有任何却别
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
MQA和GQA的主要代码流程结束,核心是先创建分组数的Key和Value矩阵,注意力点乘之前将Key和Value广播到和Query一致即可,全文完毕。