LLM模型之MInference

MInference

通过动态稀疏注意力加速长上下文llm的预填充 原文链接:MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention

1. 简介

MInference是一种利用空间聚集模式的动态稀疏注意力来加速长序列预填充阶段的方法。

  1. 将注意力头分为三种可以用于gpu上的高效稀疏计算类型: A-shape, Vertical-Slash, Block-Sparse。
  2. 采用核感知的最优稀疏模式搜索方法,离线确定每个注意力头的最优模式。
  3. 在推理过程中利用快速逼近方法为不同的输入构建动态稀疏掩码,然后应用这些掩码进行稀疏注意力计算。

只计算注意权值中最重要的部分

局限性

  1. 上下文长度较小时,构建动态索引所需的时间会因为注意力计算时间的减小而增大。
  2. 当使用较高的稀疏率时,模型性能可能会明显下降。

2. 原理

2.1 稀疏注意力头类型

如图1和2所示,将注意力头分为三种类型

  1. A-shape: 该模式的注意权重集中在初始令牌和局部窗口,表现出相对较高的稳定性。
  2. Vertical-Slash: 该模式的注意权重集中在特定的标记(竖线)和固定间隔的标记(斜线)上。该图案中竖线和斜线的位置随上下文内容动态变化。
  3. Block-Sparse: 该模式是最动态的,表现出更分散的分布,保持了空间聚类的一些特征。
图1: 注意力权重三种类型
图2: 简化三种类型

2.2 稀疏注意力计算

当使用稀疏注意力计算加速长上下文llm的预填充阶段时,注意力矩阵可以表示为:

\[ \begin{equation} \bm{A(M)} = \text{Softmax}(\frac{1}{\sqrt{d}}\bm{Q}\bm{K}^\top - c(1-\bm{M})) \tag{1} \end{equation} \]

\(M_{i,j} \in \{0,1\}\) 表示注意矩阵第\({i,j}\)项的动态稀疏掩码,\(c\)是一个大的常数,确保softmax后\(M_{i,j} = 0\)那些不太重要的注意权重接近0。

所以目标函数为:

\[ \begin{equation} \begin{aligned} \min \enspace & \enspace \enspace |\bm{A}(\bm{M}) - \bm{A}_{\text{dense}}|, \\ \min \enspace & t_{\text{sparse}}(\bm{M}) + t_{\text{overhead}}(\bm{M}), \end{aligned} \tag{2} \end{equation} \]

2.3 实现步骤

2.3.1 离线核感知最优稀疏模式搜索

通过一个参考示例遍历搜索空间,以决定最优模式和设置。用于确定每个注意头将使用哪种稀疏模式,以及实际计算中模式的最佳设置(例如,VS模式中垂直/斜线的数量;或BS模式中top-k块的数量)

  1. 初始化或获取全局的注意力掩码
  2. 按照公式(1)计算注意力权重
  3. 初始化最佳分数和对应参数
  4. 遍历每种模式(stream_llm、vertical_and_slash、block_sparse)及其参数组合,计算分数(类似于卷积核提取特定模式的特征?)
    1. vertical_and_slash 垂直注意力:对查询-键点积进行softmax,累加所有列的注意力权重,并选择top-k列。 斜线注意力:计算所有斜线元素的和,选择top-k斜线。 合并:将垂直和斜线注意力矩阵组合。
    2. stream_llm 掩码计算:生成上下三角掩码矩阵,设置指定大小的垂直和斜线区域。 应用掩码:将掩码应用于注意力权重矩阵。
    3. block_sparse 块划分:将查询和键分成多个块,并对每个块进行池化。 选择 top-k:计算块级别的注意力权重,选择前top-k个块。 合并块稀疏注意力:将块稀疏矩阵扩展到原始维度
  5. 记录所有信息,并更新最佳分数和参数
  6. 选择每个注意力头的最佳模式

2.3.2 稀疏度指标逼近与动态稀疏注意力计算

在推理阶段,根据分配的模式和准确的输入对注意力矩阵进行在线估计,以动态确定我们的稀疏指数的空间分布。

图4: 稀疏度指标逼近与动态稀疏注意力计算
1) Vertical-Slash
  1. 生成估计的注意力矩阵 \(\boldsymbol{\hat{A}}\): 由于垂直线和斜线的连续性,我们将最后一个查询向量 \(\bm{Q}_{[-\text{last\_q}:]}\) 和键向量 \(\bm{K}\) 进行矩阵乘法
  2. \(\boldsymbol{\hat{A}}\)来确定垂直线 \(\bm{i}_v\) 和斜线 \(\bm{i}_s\) 的索引。
  3. 将垂直线和斜线的稀疏索引转换为稀疏格式 \(\bm{i}_{vs}\)
  4. 利用这些稀疏索引,我们执行注意力权重和注意力输出的块稀疏计算。
2) Block-Sparse
  1. \(\bm{Q}\)\(\bm{K}\) 应用均值池化以获得 ( ) 和 ( )。
  2. 将这两个矩阵相乘以得到估计的块级注意力权重 \(\bm{\hat{A}}\)。由于均值池化和矩阵乘法操作是可交换的,结果注意力权重大致等同于均值池化后的实际注意力权重。这使得我们能够以最小的开销近似实际注意力权重的块稀疏模式。
  3. 构建一个稀疏索引 \(\bm{i}_b\),并使用它来计算稀疏注意力权重和注意力输出。

**2.3.3 利用优化后的GPU kernel进行稀疏注意力计算*

Stream-LLM
  1. 初始化块范围和掩码:根据滑动窗口和块大小计算要处理的键和值向量的列范围。
  2. 循环处理块:遍历范围内的块,加载对应的键和值向量。
  3. 计算查询-键点积并应用掩码:计算查询和键向量的点积,并应用滑动窗口掩码确保注意力机制的方向性。
  4. 更新累加器和softmax计算:计算softmax概率并更新累加器,最终得到注意力权重的加权和。
Block-Sparse
  1. 初始化掩码和块计数:确定哪些查询向量在当前块内有效,并计算需要处理的块数量。
  2. 循环处理稀疏块:遍历每个稀疏块,加载相应的块索引和列索引。
  3. 加载键和值向量的块:使用生成的列索引加载对应的键向量和值向量块。
  4. 计算查询和键的点积,并应用因果掩码:计算查询和键向量的点积,并应用因果掩码以确保注意力机制的方向性。
  5. 更新累加器和softmax计算:通过softmax计算和累加更新累加器,最终得到注意力权重的加权和。
  6. 写回输出:将累加和除以归一化因子后写回输出。
Vertical-Slash
  1. 索引排序和范围计算
    • 根据竖线和斜线索引排序,并计算每个块的范围,确定哪些元素需要进行注意力计算。
  2. 块内并行计算
    • 在每个块内分别计算竖线部分和斜线部分的注意力分数,通过加载查询、键、值向量块并进行点积计算。
    • 计算竖线注意力分数
    1. 循环遍历块索引 block_index,在每个块内计算竖线部分的注意力分数 qk。
    2. 使用 tl.dot(q, k) 计算查询向量和键向量的点积,得到注意力分数。
    3. 应用因果掩码 causal_mask 确保仅计算当前时间步及之前的时间步,并排除掉不相关的未来时间步。
    4. 计算新的最大值 m_i_new 和比例因子 alpha 来更新注意力分数。
    5. 通过 p 加权并更新累加器 acc。
    • 计算斜线部分
    1. 在另一个循环中,通过列索引 start_n 遍历块内所有列,计算斜线部分的注意力分数和累加值。
    2. 同样使用 tl.dot(q, k) 计算查询向量和键向量的点积,并应用掩码操作 m_mask 和 n_mask 确保计算的正确性。
    3. 计算新的最大值 m_i_new 和比例因子 alpha,并通过 p 更新累加器 acc。
  3. 掩码应用和归一化
    • 应用掩码确保只计算相关部分的注意力分数,并进行最大值更新和归一化处理。
  4. 最终输出
    • 将归一化后的注意力输出写回到最终输出张量中。
获取 Vertical-Slash Index: 返回块计数、块索引、列计数和列索引
  1. 排序竖线和斜线索引
    • 对竖线索引进行增量排序 IncrementalSort(i_v)
    • 对斜线索引进行降序排序 DescendingSort(i_s)
  2. 计算块数
    • 计算块数量 ( N )
  3. 初始化输出
    • 初始化块计数 block count ( c_{} ^{N} )
    • 初始化块索引 block index ( i_{} ^{N k_v} )
    • 初始化列计数 column count ( c_{} ^{N} )
    • 初始化列索引 column index ( i_{} ^{N k_s} )
  4. 并行化处理
    • 对每个块进行遍历,找到交叉竖线和斜线的位置,记录范围并更新块索引和列索引。
  5. 合并点(竖线索引和斜线索引范围)
    • 通过合并点和范围来更新块和列信息。
获取 Vertical-Slash Flash Attention
  1. 初始化
    • 缩放因子 ( )
    • 初始化输出 ( O (0)^{S d_h} )
  2. 并行化处理(竖线部分)
    • 对每个块进行遍历,加载查询块 ( Q_{} ),初始化输出块 ( O_{} )
    • 初始化最大值 ( m ()^{B} ) 和累计量 ( l (0)^{B} )
    • 遍历块索引,加载键和值块,计算注意力分数 ( S Q_{} K_{}^T )
    • 应用掩码和缩放,更新最大值和累计量,计算加权和更新输出块 ( O_{} )
  3. 并行化处理(斜线部分)
    • 遍历列索引,加载键和值块,计算注意力分数 ( S Q_{} K_{}^T )
    • 应用掩码和缩放,更新最大值和累计量,计算加权和更新输出块 ( O_{} )
  4. 写回输出
    • 对输出块进行归一化 ( O_{} (l^{-1}) O_{} )
    • 保存最终输出 ( O_i O_{} )