本来想写一篇KV Cache压缩的综述性博客,结果写到MLA部分的时候发现越写越多,完全值得单独拿出来写篇博客,遂从KV Cache压缩博客中单独揪出MLA进行介绍。
MLA(Multi-query Latent Attention)是国内创业公司deepseek在24年5月份发布的DeepSeek-V2大模型中用到的KV Cache压缩技术,正是在该技术的加持下DeepSeek-V2可以大幅压缩KV Cache的大小,进而大幅提升吞吐量,也正是从该模型开始,大模型推理的价格一下降低到一个很低的水平。MLA是少有的由国内公司做出的硬核创新,感谢deepseek,感谢MLA!我觉得在出现新的KV Cache压缩技术之前后续的大模型可能都会采用MLA,它的压缩效果接近MQA,但是生成效果却还比MHA更好,值得大家跟进。
MLA相比MQA和MHA相比做到了既要又要,着实牛逼,代价就是太难懂了。花了一天的时间仔细研究了一下苏剑林苏神的博客《缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA》,对MLA有了一个大概的理解,下面我按照自己的思路尝试解释一下MLA,在解释的过中我抛弃了MLA论文中复杂的符号定义,重新按照自己的理解去定义相关矩阵。
常规的Attention计算会用到Q、K、V,这也是我们需要保存KV的原因,而MLA则不保留KV Cache,另外引入了一个C Cache来代替KV Cache。在执行Attention计算的时候,通过一系列等价变换,将公式中出现的K和V均用C来代替。不仅可以用C来代替KV,而且Q所有的head都共享同一个C Cache,从这个层面来说MLA和MQA很类似,只需要保留一个head的缓存即可获得非常好的结果。另外一点,现在的大模型一般会在计算Attention之前将Q和K进行RoPE(旋转位置编码),如图1所示。这就会导致单纯的C丢失了位置信息,为了弥补这个缺陷,MLA中给Q和K额外增加了 d r d_r dr个维度用来添加RoPE,其中K新增的维度也是每个Head共享。大家看到这里可能比较懵逼,且听我一一道来。
1. KV Cache变C Cache
在MHA的单head计算过程中,输入X(维度为N x d)会先送入三个projection矩阵 W Q , W K , W V W_Q,W_K,W_V WQ,WK,WV进行线性变换: