KV cache

·

平时在看大模型推理相关的技术博客时,常常看到 KV cache 这个技术。这项技术的思想是利用显存存放下 K、V 两个矩阵,从而加快大模型推理的速度,是一种“空间换时间”的优化。

1. 单层 Attention

:下面用 X=[x1,x2,x3]X=[x_1, x_2, x_3] 作为 self-attention 的输入表示,Y=[y1,y2,y3]Y=[y_1, y_2, y_3] 作为 self-attention 的输出表示。

忽略 batch 和 multi-head 维度后,self-attention 的计算可以写成:

  • 计算 QQKKVV 矩阵:Q=XWQQ = XW_QK=XWKK = XW_KV=XWVV = XW_V
  • 计算注意力权重矩阵:A=softmax(QKdk)A = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)
  • 得到输出:Y=AVY = AV

XX 中新增加一个 token x4x_4 时,最直接的做法是将长度为 4 的整个序列重新计算一遍;不过,如果我们只关心新位置对应的输出 y4y_4,就会发现其中有不少计算其实可以复用。

具体来说,在单层 self-attention 中,前 3 个 token 的输入表示并没有发生变化,因此它们对应的 KKVV 也不会变化。于是,在加入 x4x_4 后,只需要计算它对应的 q4=x4WQq_4 = x_4 W_Qk4=x4WKk_4 = x_4 W_Kv4=x4WVv_4 = x_4 W_V,再将 k4k_4v4v_4 追加到原有的 KKVV 后面,得到扩展的 K~\tilde KV~\tilde V。接下来,只需要用 q4q_4K~\tilde KV~\tilde V 计算第 4 个位置的输出:y4=softmax(q4K~dk)V~y_4 = \operatorname{softmax}\left(\frac{q_4 \tilde K^\top}{\sqrt{d_k}}\right)\tilde V

也就是说,在这个过程中,历史 token 对应的 KKVV 都可以直接复用,没有必要重复计算。因此,可以将它们缓存起来,供后续生成继续使用,这就是 KV cache 最直观的来源。

:这里强调的是“自回归解码时对新位置输出的计算”。KV cache 主要用于推理阶段,是因为训练通常对整个序列并行计算,而推理是逐 token 进行的,历史 KKVV 存在明确的复用价值。

2. 多层 Attention

对于多层 Transformer,KV cache 要稳定成立,还需要一个额外条件:Attention 必须是 causal self-attention。

原因在于,多层结构中,后一层的输入来自前一层的输出。假设输入为 XX,经过第一层 Attention 后得到输出 YY。当序列中新增一个 token x4x_4 时,如果 Attention 不是 causal 的,那么原来的位置 x1x_1x2x_2x3x_3 在计算注意力时也可能关注到这个新加入的 x4x_4。这样一来,第一层中原本对应的输出 y1y_1y2y_2y3y_3 就会发生变化,而不仅仅是在原有 YY 的末尾多出一个新的 y4y_4

一旦第一层的历史输出发生变化,第二层的输入也会随之改变。进一步地,第二层中历史位置对应的 KKVV 也需要重新计算,后续各层同理。也就是说,在没有 causal mask 的情况下,新 token 的加入会沿着网络层层向前传播,导致先前缓存的历史 KKVV 失效,因此无法直接复用。

而在 causal self-attention 中,由于第 ii 个位置只能关注自己以及之前的 token,历史位置不会受到未来新 token 的影响。因此,当新增 x4x_4 时,前面各层中对应历史位置的输出都保持不变,模型只需要为新 token 计算各层新的 qqkkvv,历史缓存即可继续复用。这也是为什么 KV cache 通常与自回归的 causal attention 一起出现。

3. 参考资料