Flash Attention

Flash Attention

FlashAttention 是一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法。为了缓解LLM输入输出序列s扩展时计算复杂度和空间复杂度都是\(O(s^2)\)的问题。

1 Transformer复杂度

transformer模型中self-attention的计算量和储存复杂度随着序列长度 s​ 呈二次方增长,这限制了大语言模型的最大序列长度 s​ 的大小。

1.1 计算复杂度

模块 数量 单次FLOPs 总FLOPs
Embedding 1 0 0
LM Head(logits) 1 \(2bsVd\) \(2bsVd\)
Self-Attention l \(8bsd^2+4bs^2d\) \(8blsd^2+4bls^2d\)
MLP l \(16bsd^2\) \(16blsd^2\)

总FLOPs:\(2bsVd + 24blsd^2 + 4bls^2d\)

1.2 空间复杂度

模块 数量 单个占用 总占用
Self-Attention l \(11bsd+5bs^2n\) \(11blsd+5bls^2n\)
MLP l \(19bsd\) \(19blsd\)
LayerNorm 2l \(2bsd\) \(4blsd\)

中间激活值总显存占用:\(34blsd+5bls^2n\) bytes

2 Standard Attention

首先回顾Attention的计算过程:

\[ \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d}})V \]

其中,\(Q, K, V in \mathbb{R}^{s \times d}\)\(d\)是embedding维度,\(s\)是序列长度。