KV cache
平时在看大模型推理相关的技术博客时,常常看到 KV cache 这个技术。这项技术的思想是利用显存存放下 K、V 两个矩阵,从而加快大模型推理的速度,是一种“空间换时间”的优化。
1. 单层 Attention
注:下面用 作为 self-attention 的输入表示, 作为 self-attention 的输出表示。
忽略 batch 和 multi-head 维度后,self-attention 的计算可以写成:
- 计算 、、 矩阵:、、
- 计算注意力权重矩阵:
- 得到输出:
当 中新增加一个 token 时,最直接的做法是将长度为 4 的整个序列重新计算一遍;不过,如果我们只关心新位置对应的输出 ,就会发现其中有不少计算其实可以复用。
具体来说,在单层 self-attention 中,前 3 个 token 的输入表示并没有发生变化,因此它们对应的 和 也不会变化。于是,在加入 后,只需要计算它对应的 、、,再将 和 追加到原有的 、 后面,得到扩展的 和 。接下来,只需要用 与 、 计算第 4 个位置的输出:
也就是说,在这个过程中,历史 token 对应的 和 都可以直接复用,没有必要重复计算。因此,可以将它们缓存起来,供后续生成继续使用,这就是 KV cache 最直观的来源。
注:这里强调的是“自回归解码时对新位置输出的计算”。KV cache 主要用于推理阶段,是因为训练通常对整个序列并行计算,而推理是逐 token 进行的,历史 、 存在明确的复用价值。
2. 多层 Attention
对于多层 Transformer,KV cache 要稳定成立,还需要一个额外条件:Attention 必须是 causal self-attention。
原因在于,多层结构中,后一层的输入来自前一层的输出。假设输入为 ,经过第一层 Attention 后得到输出 。当序列中新增一个 token 时,如果 Attention 不是 causal 的,那么原来的位置 、、 在计算注意力时也可能关注到这个新加入的 。这样一来,第一层中原本对应的输出 、、 就会发生变化,而不仅仅是在原有 的末尾多出一个新的 。
一旦第一层的历史输出发生变化,第二层的输入也会随之改变。进一步地,第二层中历史位置对应的 和 也需要重新计算,后续各层同理。也就是说,在没有 causal mask 的情况下,新 token 的加入会沿着网络层层向前传播,导致先前缓存的历史 、 失效,因此无法直接复用。
而在 causal self-attention 中,由于第 个位置只能关注自己以及之前的 token,历史位置不会受到未来新 token 的影响。因此,当新增 时,前面各层中对应历史位置的输出都保持不变,模型只需要为新 token 计算各层新的 、、,历史缓存即可继续复用。这也是为什么 KV cache 通常与自回归的 causal attention 一起出现。