KV Cache

KV Cache
可爱可倾KV Cache
KV-Cache是一种加速Transformer推理的策略,几乎所有自回归模型都内置了KV-Cache
1 为什么需要KV-Cache
Transformer每一层分为两个部分,一个是Self-Attention,另一个是Feed Forward Network(FFN)。
- Self-Attention: \(Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V\)
- FFN: \(FFN(x) = ReLU(xW_1 + b_1)W_2 + b_2\)
自回归模型采用每次推理都将前文整句输入模型,然后预测下一个token的方式,这种方式会存在相同结果的重复推理。 令前一次待推理的文本长度为S,下一次为S+1,由于网络中的各项参数已经固定,因此两次推理对于前S个token的计算结果是完全相同的,包括Embedding映射,每一层、每一个注意力头下的KQV映射,注意力权重,以及后续的FFN层都在重复计算。
那么既然下一个token是由当前最后一个token的网络输出所决定的,那能不能仅输入最后一个token来进行推理? 答案是否定的,虽然在结果层仅由最后一个token来决定,但是中间的注意力过程它的计算依赖于前文所提供的Key、Value向量(这就是前文信息),因此也不能抛弃前文不管。
S+1位置token的推理依赖于两个要素:
- 首先是当前第S个token在网络中完整forward一遍
- 其次是除最后一个token以外,之前所有的S-1位置的token在每一层、每个注意力头下的Key,Value信息。
所以可以将Key、Value信息缓存下来,下次推理时直接使用,这就是KV-Cache的思想。
2 KV-Cache的实现
从第二次推理开始,仅需要输入当前最后一个token,单独对该token做Q,K,V映射,将past_key_values中前文所有的K,V和该token的K,V进行拼接得到完成的Key、Value向量(只是简单的拼接矩阵),最终和该token的Query计算注意力,拼接后的Key、Value也同步更新到past_key_values。
3 存储结构
KV-Cache会将截止当前各个token在每一层、每个头的Key向量和Value向量存储在内存中,在HuggingFace的代码实现中使用past_key_values变量进行存储,past_key_values是一个矩阵,其维度为[n, 2, b, h, s, d],类似一个六维的矩阵,每个维度的含义如下
- 第一维 num_layers:在外层是以每一个堆叠的Block为单位,例如堆叠12层,则一共有12组Key、Value信息
- 第二维 2:代表Key和Value这两个信息对象,索引0取到Key向量,索引1取到Value向量
- 第三维 batch_size:代表batch_size,和输入需要推理的文本条数相等,如果输入是一条文本,则b=1
- 第四维 num_heads:代表注意力头的数量,例如每层有12个头,则h=12
- 第五维 seq_len:代表截止到当前token为止的文本长度,在每一个历史token位置上该token在每一层每个头下的Key,Value信息
- 第六维 d:代表Key、Value向量的映射维度,若token总的映射维度为768,注意力头数为12,则d=768/12=64
可以发现每一步推理后的差异仅仅产生在seq_len这个维度上(seq_len维度大小会加1),它是由新推理的那一个token所对应的Key,Value拼接到上一个past_key_values的seq_len维度中所导致的。 例如 [12, 2, 1, 12, 5, 64] -> [12, 2, 1, 12, 6, 64]
4 KV-Cache内存占用、FLOPs下降分析
4.1 内存占用
KV-Cache本质上是用空间换时间,存储的Key、Value矩阵会额外占用内存。
假设以float16精度(占用两个字节)来存储,每个token的存储占用公式如下: \[ 2 \times n_{layers} \times 2 \times n_{heads} \times d \]
以LLaMa-7B-FP16为例,模型加载占用显存14GB,向量维度4096,堆叠32层,最大推理步长4096。 若推理一个batch为2,长度为4096的句子,KV-Cache占用的存储空间为2×2×32×4096×2×4096=4294967296字节,约等于4GB,随着推理的batch增大,推理长度变长,KV-Cache占用的存储空间可能超过模型本身。
4.2 FLOPs下降
另一方面KV-Cache极大地降低了FLOPs(浮点计算量),表面上KV-Cache省去了之前每个token的Key、Value的计算量,每个token在所有层下计算Key、Value的FLOPs公式如下: \[ 2 \times 2 \times n_{layers} \times d^2 \]
其中d平方代表从token Embedding到Key或者Value向量的过程,乘以2是矩阵相乘中逐位相乘再相加导致有两个操作,再乘以2代表Key、Value各一个。
以LLaMa-7B为例,推理一个batch为2,长度为4096的句子,光计算KV一共节省了(2×2×32×4096×4096)×4096×2=17592B FLOPs的计算量
额外的,不仅省去了前文所有token的Key、Value的映射,由此导致后续这些token的注意力权重计算,注意力的MLP层,FFN前馈传播层也都不需要再计算了,相当于推理阶段的计算复杂度永远等于只对一个token进行完整的forward推理,因此计算量大幅降低。






