LLM 参数量及 FLOPs 计算

LLM 参数量及 FLOPs 计算
可爱可倾LLM 参数量、FLOPs计算和中间激活值的存储
分析基于Decoder-only的LLM框架
| 参数 | 符号 | 说明 |
|---|---|---|
| Decoder层数 | l | |
| Token嵌入维度 | d | |
| Attention层嵌入维度 | d | |
| MLP隐藏层维度 | 4d | 通常设置为嵌入维度4倍 |
| Attention head数量 | n | 要求其整除d |
| 词表尺寸 | V | |
| batch_size | b | |
| 模型输入长度 | s |
1. 模型参数量
- Embedding层:\([V, d]\) —— \(Vd\)
- Self-Attention层(通常是没有带偏置项) —— \(4d^2 (+ 4d)\)
- Q K V 矩阵: \([d, d]\) —— \(3d^2\)
- O 矩阵:\([d, d]\) —— \(d^2\)
- 如果带偏置,每个都是\(d\) —— \(4d\)
- MLP层(通常带偏置项) —— \(8d^2 +
5d\)
- X->H:\([d, 4d]\) —— \(4d^2+4d\)
- H->O:\([4d, d]\) —— \(4d^2+d\)
- 两个LayerNorm层 —— \(2 * 2d\)
- 缩放参数 Scale:\([d]\) —— \(d\)
- 平移参数 Bias:\([d]\) —— \(d\)
| 模块 | 数量 | 单个参数量 | 总参数量 |
|---|---|---|---|
| Embedding | 1 | \(Vd\) | \(Vd\) |
| Self-Attention | l | \(4d^2 (+ 4d)\) | \(4ld^2 (+ 4ld)\) |
| MLP | l | \(8d^2 + 5d\) | \(8ld^2 + 5ld\) |
| LayerNorm | 2l | \(2d\) | \(4ld\) |
总参数量:\(Vd + 12ld^2 + 9ld (+ 4ld)\) 近似于 \(12ld^2\)
比如LLaMa-7B,\(V=128k, d=4096, l=32\)
- 总参数量为\(128k*4096 + 12*32*4096^2 + 9*32*4096 \approx 6.98B\)
- 粗略计算为\(12*32*4096^2 \approx 6.44B\)
2. 显存占用
2.1 训练过程
在训练神经网络的过程中,占用显存的大头主要分为四部分:
- 模型参数
- 前向计算过程中产生的中间激活(中间激活的显存占用后面会详细介绍)
- 反向传递计算得到的梯度
- 优化器状态
训练大模型时通常会采用AdamW优化器,并用混合精度训练来加速训练,基于这个前提分析显存占用。
在一次训练迭代中,每个可训练模型参数都会对应1个梯度,并对应2个优化器状态(Adam优化器梯度的一阶动量和二阶动量)。 float16数据类型的元素占2个bytes,float32数据类型的元素占4个bytes。 在混合精度训练中,会使用float16的模型参数进行前向传递和后向传递,计算得到float16的梯度; 在优化器更新模型参数时,会使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。
设模型参数为\(N\),则显存占用为\(20N\) bytes(不包括中间激活)。
- 模型参数:\(2N\) bytes (float16) + \(4N\) bytes (float32)
- 梯度:\(2N\) bytes (float16) + \(4N\) bytes (float32)
- 优化器状态:\(2 \times 4N\) bytes (float32)
2.2 推理过程
不需要存储梯度和优化器状态,也无需存储中间激活值,只需要存储模型参数,显存占用为\(2N\) bytes (采用float16推理)。 注:如果启用KV缓存,则需要额外存储KV缓存。
3. FLOPs
LLM 中的主要运算是矩阵乘法,故考察 LLM 计算量时,通常只关注矩阵乘法运算对应的浮点计算量 \(m \times n\) 矩阵乘以 \(n \times k\) 矩阵( n-1次加法和n次乘法 )的运算量为 \(mk(2n-1)\),但是 GPU 计算矩阵乘法时一般使用 FMA (fused multiply–add) 进行计算,一次 FMA 可以计算一个乘法和一个加法,因此实际的 FLOPs 计算量为 \(mk(2n)\)
假设输入的batch为\(b \times s \times d\),其中\(b\)为batch_size,\(s\)为输入长度,\(d\)为嵌入维度
- Embedding层:查表操作,不涉及矩阵乘法,故不计入FLOPs
- 预测多分类头(logits)将尺寸为 d 的隐藏向量映射为词表大小:\([b \times s \times d] \times [d \times V] = [b \times s \times V]\) —— \(2bsVd\)
- Self-Attention层:\(8bsd^2+4bs^2d\)
\[
Q=xW_{Q}, K=xW_{K}, V=xW_{V}\\
Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V\\
x_{out}= AW_{O}+x
\]
- 计算 QKV 矩阵:\([b \times s \times d] \times [d \times d] = [b \times s \times d]\) —— \(6bsd^2\)
- 计算注意力权重\(QK^T\):\([b \times n \times s \times d/n] \times [b \times n \times d/n \times s] = [b \times n \times s \times s]\) —— \(2bs^2d\)
- 汇聚价值信息\(AV\): \([b \times n \times s \times s] \times [b \times n \times s \times d/n] = [b \times n \times s \times d/n]\) —— \(2bs^2d\)
- 拼接多头注意力: 不涉及矩阵乘法,故不计入FLOPs
- 输出矩阵\(O\): \([b \times s \times d] \times [d \times d] = [b \times s \times d]\) —— \(2bsd^2\)
- MLP层:\(16bsd^2\) \[
h=Relu(x_{out}W_1+b_1)\\
h_{out}=hW_2+b2\\
x=h_{out}+x_{out}
\]
- 第一个线性层:\([b \times s \times d] \times [d \times 4d] = [b \times s \times 4d]\) —— \(8bsd^2\)
- 第二个线性层:\([b \times s \times 4d] \times [4d \times d] = [b \times s \times d]\) —— \(8bsd^2\)
- 反向传播:反向传播过程中每个非第一层都有两次矩阵乘法操作,而相应的前向过程中只有一次(第一层的后向-前向FLOPs比率是1:1,而其他层后向-前向FLOPs比率是2:1)。随着网络层数增加,反向传播计算量会越来越接近正向传播计算量的两倍。
| 模块 | 数量 | 单次FLOPs | 总FLOPs |
|---|---|---|---|
| Embedding | 1 | 0 | 0 |
| LM Head(logits) | 1 | \(2bsVd\) | \(2bsVd\) |
| Self-Attention | l | \(8bsd^2+4bs^2d\) | \(8blsd^2+4bls^2d\) |
| MLP | l | \(16bsd^2\) | \(16blsd^2\) |
前向传播总FLOPs:\(2bsVd + 24blsd^2 + 4bls^2d\),近似为\(24blsd^2\) 加上反向传播的FLOPs,一次训练迭代的总计算量为\(72blsd^2\)
3.1 计算量和参数量的关系
进一步考虑Token数据量为\(D=bs\),参数量为\(N=12ld^2\)
- 推理:\(2DN\)
- 训练:\(6DN\)
可以近似认为:在一次推理中,对于每个token,每个模型参数,需要进行2次浮点数运算,即一次乘法法运算和一次加法运算,而在一次训练中,需要进行6次浮点数运算。
3.2 估计训练时间
训练神经网络的一次迭代分为三步
- 前向传递计算损失函数;
- 后向传递计算梯度;
- 优化器更新模型参数。
后向传递的耗时几乎是前向传递的两倍,相比之下,优化器更新的耗时几乎可以忽略。进一步的
\[ 训练时间 \approx \frac{6DN}{FLOPS \times {GPU 数量} \times {GPU利用率}} \]
如果使用激活重计算技术来减少中间激活显存需要进行一次额外的前向传递,则
\[ 训练时间 \approx \frac{8DN}{FLOPS \times {GPU 数量} \times {GPU利用率}} \]
一般来讲,GPU利用率一般在\(0.3-0.55\)之间。
4 中间激活值的存储
前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量(不包含模型参数和优化器状态,但包含dropout操作需要用到的mask矩阵)
大模型在训练过程中通常采用混合精度训练,中间激活值一般是float16或者bfloat16数据类型的。在分析中间激活的显存占用时,假设中间激活值是以float16或bfloat16数据格式来保存的,每个元素占了2个bytes。唯一例外的是,dropout操作的mask矩阵,每个元素只占1个bytes。
- Self-Attention层:\(11bsd+5bs^2n\)
- 对于QKV,需要保存输入x:\([b \times s \times d]\) —— \(2bsd\)
- 对于\(QK^T\),需要保存Q和K矩阵:\([b \times s \times d]\) —— \(4bsd\)
- 对于Softmax,需要保存\(QK^T\):\([b \times n \times s \times s]\) —— \(2bs^2n\)
- 会有一个dropout操作,需要保存mask矩阵,与\(QK^T\)尺寸相同:\([b \times n \times s \times s]\) —— \(bs^2n\)
- 之后计算A,需要保存V矩阵和softmax:\([b \times s \times d]\)和\([b \times n \times s \times s]\) —— \(2bsd + 2bs^2n\)
- 计算输出矩阵O,需要保存A:\([b \times s \times d]\) —— \(2bsd\)
- 会有一个dropout操作,需要保存mask矩阵,与A尺寸相同:\([b \times s \times d]\) —— \(bsd\)
- MLP层:\(19bsd\)
- 对于第一个线性层,需要保存输入x:\([b \times s \times d]\) —— \(2bsd\)
- 激活函数需要保存输入:\([b \times s \times 4d]\) —— \(8bsd\)
- 对于第二个线性层,需要保存输入h:\([b \times s \times 4d]\) —— \(8bsd\)
- 会有一个dropout操作,需要保存mask矩阵,与h尺寸相同:\([b \times s \times d]\) —— \(bsd\)
- 两个LayerNorm层:\(2 \times 2bsd\) 保存输入:\([b \times s \times d]\) —— \(2bsd\)
| 模块 | 数量 | 单个占用 | 总占用 |
|---|---|---|---|
| Self-Attention | l | \(11bsd+5bs^2n\) | \(11blsd+5bls^2n\) |
| MLP | l | \(19bsd\) | \(19blsd\) |
| LayerNorm | 2l | \(2bsd\) | \(4blsd\) |
中间激活值总显存占用:\(34blsd+5bls^2n\) bytes
4.1 中间激活值和模型参数的关系
一次训练中,占用显存的四大部分中,模型参数、梯度和优化器状态的显存占用是与模型参数量成正比的,与输入数据量无关;而中间激活值的显存占用是与输入数据量(批次大小b和序列长度s)成正相关的。 所以当训练时出现显存不足的情况时,可以通过减少batch_size来减少中间激活值的显存占用。
随着批次大小 b 的增大,中间激活占用的显存远远超过了模型参数显存。通常会采用激活重计算技术来减少中间激活,理论上可以将中间激活显存从\(O(n)\)降低到\(O(\sqrt{n})\)(每\(\sqrt{n}\)层存储一个检查点)。
激活重计算本质是时间换空间,通过在反向传播时重新计算部分激活值,而不是在前向传播时存储所有激活值,从而显著减少内存使用。
- 前向传播阶段:在前向传播中,不存储所有层的激活值。仅存储一些关键的检查点(checkpoints),即在若干层之后保存一次激活值。
- 反向传播阶段:当需要计算梯度时,重新计算未存储的激活值。这意味着在反向传播阶段需要重新执行一部分前向传播计算。
- 内存与计算的权衡:通过减少激活值的存储来节省内存,但代价是增加了计算开销,因为需要重新计算激活值。








