LLM推理优化——KV Cache篇(百倍提速)
注意:KV Cache本质上是空间换时间的技术。与计算机组成原理中的cache不同,它不涉及访存优化。
不知道大家在用LLM的时候,有没有注意到一个问题:我们在输入我们的问题之后,需要等待一段漫长的时间才能看到第一个字符的响应,而等待第一个之后,后续的响应却非常之快,这就是使用了KV Cache加速后的带来的优势。
LLM推理过程与自注意力机制
LLM的推理过程和训练过程是有所区别的,LLM在训练过程中使用因果掩码并行化训练,无需像RNN一样等待之前的结果运算结束。但是他的推理过程却类似RNN,并且有大量的重复计算。详细解释可以参照上一篇Blog:LLM的训练与推断。我们已经知道LLM推断时实际上是一种自回归模型 f θ f_\theta fθ,假设在 t t t步及其之前的输入和推断表示为 x 1 : t x_{1:t} x1:t,我们可以用公式表示它:
x 1 : t + 1 = f θ ( x 1 : t ) (1) % \begin{align} x_{1:t+1} = f_\theta(x_{1:t}) \tag{1} % \end{align} x1:t+1=fθ(x1:t)(1)
从公式1我们可以看到,即使 x 1 : t x_{1:t} x1:t已经是计算过的,但是在计算 x t + 1 x_{t+1} xt+1时却还要重复计算 x 1 : t x_{1:t} x1:t,当序列愈来愈长,这个计算代价也越来越不可接受,尤其绝大多数计算都是在做无用功(重复计算)。
基于此,我们就有一个朴素的优化思考:我们能否去除掉这些重复计算?那答案是当然可以,不然就没有这篇博客了🤣。
那么如何去除呢?如果有动态规划、搜索基础的小伙伴可能会比较眼熟,类似于常用的剪枝操作。将后续仍然需要使用的计算结果保存起来,后续使用时不必再次计算。
那么来看一下具体是如何实施的,从LLM的训练与推断中我们可以知道,需要重复计算的矛盾点来自于自注意力机制,首先回顾一下它:
a t t n ( Q , K , V ) = s o f t m a x ( Q T K d k ) V attn(Q,K,V) = softmax\left(\frac{Q^TK}{\sqrt{d_k}}\right)V attn(Q,K,V)=softmax(dkQTK)V
我们可以用行列向量的形式表示第t步计算 x t x_t xt时的 ( Q T K ) t (Q^TK)_t (QTK)t
( Q T K ) t = ( q 1 T k 1 q 1 T k 2 … q 1 T k t q 2 T k 1 q 2 T k 2 … q 2 T k t ⋮ ⋮ ⋱ ⋮ q t T k 1 q t T k 2 … q t T k t ) ( Q T K ) t + 1 = ( q 1 T k 1 q 1 T k 2 … q 1 T k t q 1 T k t + 1 q 2 T k 1 q 2 T k 2 … q 2 T k t q 2 T k t + 1 ⋮ ⋮ ⋮ ⋮ q t T k 1 q t T k 2 … q t T k t q t T k t + 1 q t + 1 T k 1 q t + 1 T k 2 … q t + 1 T k t q t + 1 T k t + 1 ) \begin{align} (Q^TK)_t&= \left( \begin{array}{cccc} q^T_1k_1 & q^T_1k_2&\ldots&q^T_1k_t\\ q^T_2k_1 & q^T_2k_2&\ldots&q^T_2k_t\\ \vdots& \vdots&\ddots&\vdots\\ q^T_tk_1 & q^T_tk_2&\ldots&q^T_tk_t\\ \end{array}\right)\notag \\ (Q^TK)_{t+1}&= \left( \begin{array}{ccccc} q^T_1k_1 & q^T_1k_2&\ldots&q^T_1k_t&q^T_1k_{t+1}\\ q^T_2k_1 & q^T_2k_2&\ldots&q^T_2k_t&q^T_2k_{t+1}\\ \vdots& \vdots&\ &\vdots&\vdots\\ q^T_tk_1 & q^T_tk_2&\ldots&q^T_tk_t&q^T_tk_{t+1}\\ q^T_{t+1}k_1 & q^T_{t+1}k_2&\ldots&q^T_{t+1}k_t&q^T_{t+1}k_{t+1}\\ \end{array}\right)\notag \\ \end{align} (QTK)t(QTK)t+1= q1Tk1q2Tk1⋮qtTk1q1Tk2q2Tk2⋮qtTk2……⋱…q1Tktq2Tkt⋮qtTkt = q1Tk1q2Tk1⋮qtTk1qt+1Tk1q1Tk2q2Tk2⋮qtTk2qt+1Tk2…… ……q1Tktq2Tkt⋮qtTktqt+1Tktq1Tkt+1q2Tkt+1⋮qtTkt+1qt+1Tkt+1
从式子中可以看到,在计算 ( Q T K ) t + 1 (Q^TK)_{t+1} (QTK)t+1时,前t行t列已经被计算过,我们只需要计算最后一行并且保存即可。由 q t + 1 T k 1 , q t + 1 T k 2 , … , q t + 1 T k t , q t + 1 T k t + 1 q^T_{t+1}k_1, q^T_{t+1}k_2,\ldots,q^T_{t+1}k_t,q^T_{t+1}k_{t+1} qt+1Tk1,qt+1Tk2,…,qt+1Tkt,qt+1Tkt+1可知,我们还需要知道之前每次计算出的 k k k值,所以在保存每行值之外,也应当保存 k 1 , k 2 , … , k t + 1 k_1, k_2,\ldots,k_{t+1} k1,k2,…,kt+1的值。
这里可能会有疑惑,最后一列用了之前的 q 1 , q 2 , … , q t q_1,q_2, \ldots,q_t q1,q2,…,qt,为什么不保存 q q q的值。实际上, Q T K Q^TK QTK矩阵会经过一个掩码操作,他最后得到是一个下三角矩阵,我们不必再次计算,直接补0即可。
保存 V V V值亦是同理。
保存 K , V K,V K,V是不是KV Cache名称的由来呢?