Flash Attention

Flash Attention
可爱可倾Flash Attention
FlashAttention 是一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法。为了缓解LLM输入输出序列s扩展时计算复杂度和空间复杂度都是\(O(s^2)\)的问题。
1 Transformer复杂度
transformer模型中self-attention的计算量和储存复杂度随着序列长度 s 呈二次方增长,这限制了大语言模型的最大序列长度 s 的大小。
1.1 计算复杂度
| 模块 | 数量 | 单次FLOPs | 总FLOPs |
|---|---|---|---|
| Embedding | 1 | 0 | 0 |
| LM Head(logits) | 1 | \(2bsVd\) | \(2bsVd\) |
| Self-Attention | l | \(8bsd^2+4bs^2d\) | \(8blsd^2+4bls^2d\) |
| MLP | l | \(16bsd^2\) | \(16blsd^2\) |
总FLOPs:\(2bsVd + 24blsd^2 + 4bls^2d\)
1.2 空间复杂度
| 模块 | 数量 | 单个占用 | 总占用 |
|---|---|---|---|
| Self-Attention | l | \(11bsd+5bs^2n\) | \(11blsd+5bls^2n\) |
| MLP | l | \(19bsd\) | \(19blsd\) |
| LayerNorm | 2l | \(2bsd\) | \(4blsd\) |
中间激活值总显存占用:\(34blsd+5bls^2n\) bytes
2 Standard Attention
首先回顾Attention的计算过程:
\[ \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d}})V \]
其中,\(Q, K, V in \mathbb{R}^{s \times d}\),\(d\)是embedding维度,\(s\)是序列长度。
评论
匿名评论隐私政策
TwikooGiscus
✅ 若未加载出评论区,请刷新页面~






