博客

加速 Mamba2 的内核融合

总结

在本文中,我们将讨论如何使用融合 Triton 内核(fused Triton kernel)优化 Mamba-2 状态空间对偶(SSD)模块。该优化在 NVIDIA A100 和 H100 GPU 上实现了 1.50 倍至 2.51 倍 的加速。为了实现这一目标,我们通过精细的同步机制将所有五个 SSD 内核融合为了一个 Triton 内核。据我们所知,这是首个针对全部五个 SSD 内核的端到端 Triton 融合方案。这减少了内核启动开销,并避免了冗余的内存操作,从而在所有输入规模下提升了内核运行速度。本文余下部分将介绍我们融合 SSD 内核的方法、仍然存在的瓶颈、基准测试结果,以及我们计划将其开源以造福社区的方案。

图 1. 融合 SSD Triton 内核在 A100 和 H100 上的加速效果

背景

Mamba-2 是一种基于状态空间对偶(SSD)框架的序列模型。作为原版 Mamba 模型的优化继任者,它将结构化状态空间模型(SSM)与基于注意力机制的 Transformer 相结合。Mamba 类模型的一个关键优势是能够扩展到长序列。Mamba 的状态空间机制随上下文长度呈线性扩展。实际上,输入序列长度翻倍只会使 Mamba 的计算和内存需求大约翻倍,而自注意力机制则会使其需求变为原来的四倍。这使得 Mamba-2 对于 128K token 及以上的超长上下文场景极具吸引力。

IBM 的 Granite 4.0 模型系列最近采用了混合架构,将 Mamba-2 块与 Transformer 块相结合。在 Granite 4.0 中,每 1 个注意力层对应 9 个 Mamba-2 层,以高效处理长距离上下文。随着 Mamba-2 成为此类模型的核心部分,优化其性能对于更快的推理至关重要。Mamba-2 计算的核心是 SSD 模块,它取代了每一层中的注意力机制。原始的 Mamba2 SSD 实现主要受限于内存带宽和延迟,并涉及中间数据的写入与读取,因此存在改进空间。在本文中,我们专注于利用优化后的融合内核来加速 SSD 的预填充(prefill)操作。

Mamba2 操作

构成典型 Mamba2 块的操作列于 表 1 中。我们专注于融合这五个 SSD 内核,因为它们在概念上表现为同一个 SSD 操作,尽管正如后文所述,进一步的融合(如卷积和层归一化)也是可能的。

 

层归一化 (Layernorm) 有助于数值稳定性
输入投影 (In Projection) 将输入投影到 SSD 通道/尺寸
深度卷积 (Depthwise Convolution) 混合最后几个 token
SSD 分块累加 (SSD Chunk Cumsum) 计算每个 token 的 dt 以及分块内的累积衰减
SSD 分块状态 (SSD Chunk State) 独立计算该分块末尾的状态
SSD 状态传递 (SSD State Passing) 计算每个分块末尾的全局状态
SSD BMM (批矩阵乘法) 计算输入的每个分块如何影响输出 y 的相应分块
SSD 分块扫描 (SSD Chunk Scan) 根据输入的相应分块及前一个分块的全局状态,计算 y 的每个分块
层归一化 (Layernorm) 有助于数值稳定性
输出投影 (Out Projection) 将输出投影到模型的隐藏维度

表 1. Mamba2 操作

为什么需要内核融合?

在预填充阶段(即 token 生成前对提示词或输入序列的前向传播)中,Mamba-2 的 SSD 模块执行五个 GPU 内核的流水线。在原始实现中,这五个内核在 GPU 上顺序运行。
然而,依次启动多个小内核会产生巨大的开销,并阻碍 GPU 在阶段间高效重用数据。通过应用内核融合,我们可以获得几个关键益处:

  • 消除内核启动开销: 一次启动代替五次,减少了 CPU-GPU 同步和调度延迟。
  • 提高缓存局部性: 一个阶段产生的数据在同一个线程块内被下一个阶段立即消费,提高了缓存命中率并减少了全局内存流量。
  • 重叠计算: 融合内核的不同部分可以并行执行(在独立的情况下),从而更好地利用 GPU 资源。

我们的解决方案将所有五个内核融合为一个 Triton 内核,使得该层整个 SSD 预填充计算都在一次 GPU 启动内完成。

高效的内核融合技术

与简单的矩阵乘法 + 激活函数融合不同,SSD 融合非常复杂,因为计算跨越了多个步骤且存在复杂的依赖关系。原始实现依赖于内核间的隐式同步,而在我们融合一切后,这种同步消失了。在本节中,我们将讨论为什么这一点很重要,以及我们如何使融合在实践中可行。

Mamba2 SSD 的五个步骤最初实现为五个独立的内核:Chunk Cumsum、BMM、Chunk State、State Passing 和 Chunk Scan,它们在固定大小的 token 分块上运行。下图说明了这些内核之间的依赖关系。

图 2. Mamba2 SSD 预填充内核图

State Passing 步骤在分块之间存在依赖关系,原始的 State Passing 内核通过在线程块内循环遍历分块并跨线程块拆分状态通道以实现并行处理来解决此问题。借助此 State Passing 循环以及内核启动间的隐式全局同步,原始内核处理了所有依赖关系。

当我们尝试将所有五个内核融合到一次启动中时,真正的技术挑战出现了。一旦融合,我们就失去了原始内核所依赖的隐式全局同步,因此必须显式管理分块内和跨分块的依赖关系。大部分依赖关系存在于同一分块的不同步骤之间,因此对于最大的三个内核(Chunk State、State Passing 和 Chunk Scan),这些分块内依赖可以通过在同一个线程块上运行特定分块的所有步骤来处理。这也使我们能够将步骤间的中间数据保留在寄存器或 L1 缓存(每个 SM 私有)中,因为数据将在同一个线程块中使用。

然而,这种方法既不可能也不正确。原始的 State Passing 内核具有上述循环,这使其线程块网格无法与原始的 Chunk State 和 Chunk Scan 内核匹配。此外,为每个分块使用单独的线程块将消除通过在单个线程块内循环遍历分块所提供的自然同步和正确性。

为了使融合成为可能,我们将 State Passing 循环在分块间的迭代拆分到不同的线程块中,以便线程块网格匹配。我们通过使用原子操作(atomics)对这些线程块进行排序来获得正确性。这是一种序列化形式,表面上看起来效率低下,但可以通过与另外两个部分重叠来缓解。

例如,如果我们并行运行 8 个分块,我们会预期 State Passing 序列化会导致约 8 倍的本地减速。然而,融合后的 State Passing 只是三个大步骤中的一小部分,特别是由于它不再需要从全局内存中读取状态(它已经在融合的 Chunk State 中处于线程块内了)。

根据阿姆达尔定律(Amdahl’s law),我们预期运行时间变为 (State Passing 分数) * 8 + (1 – State Passing 分数) * 1。例如,如果 State Passing 步骤仅占排除同步后的组合时间的 1/7,我们将得到 (1/7) * 8 + (6/7) * 1 = 2,这意味着整体减速 2 倍。然而,这并没有考虑重叠。由于 State Passing 的同步可以与 Chunk State 和 Chunk Scan 的计算重叠,减速大约为:

State Passing 计算时间 + max(其他计算时间, State Passing 同步时间)

= 1/7 + max(6/7, 1/7 * 7) = 1.14 倍

如果 State Passing 占总运行时间的比例更小,或者同时处理的分块更少,我们理论上可以在除前几个分块之外的所有分块中避免任何序列化减速。

图 3. State Passing 开销重叠

图 3 展示了理论上的同步延迟,这对于并行运行的前几个分块来说很高,但在所有后续分块中趋于较低的开销。我们可以看到,虽然分块 8 依赖于分块 7,但它只需忙等待 1 个单位的时间而不是 8 个,因为分块 0 的 Chunk Scan 和分块 8 的 Chunk State 与分块 1-6 的 State Passing 发生了重叠。在实践中,NVIDIA Nsight Compute 基准测试显示,少于 3% 的 warp 停顿(空闲线程时间)是由 State Passing 同步引起的,这意味着序列化延迟被隐藏了。

BMM 和 Chunk Cumsum 步骤相对于其他三个步骤极快。BMM 沿 ngroups 而非 nheads 拆分工作,Chunk Cumsum 则由其线程块为效率处理多个 head。为简单起见,我们为这两个步骤启动单独的线程块(前几个线程块处理它们),并让其他三个步骤的线程块通过原子操作等待其 BMM 和 Chunk Cumsum 依赖项。

当线程块开始执行内核时,它被分配处理 Chunk Cumsum 步骤,除非所有 Chunk Cumsum 工作已被分配完毕。同样,如果没有未分配的 Chunk Cumsum 工作,该线程块将被分配到 BMM 步骤(如果可用)。在两个快速步骤全部被分配给线程块后,后续线程块各自开始在 Chunk State 中处理一个分块,在 State Passing 中处理同一个分块,并在 Chunk Scan 后输出该分块。

虽然内核融合提高了数据重用并加速了 SSD,但要实现最大性能还需要额外的优化。这些包括:重新排列线程块以隐藏序列化延迟、为加载/存储添加缓存提示以优先考虑重用数据、将特殊情况移出融合内核以减少寄存器压力、改变某些中间数据类型、调整分块大小以及重构操作以减少延迟。这些优化技术将在附录 A 中详细描述。

剩余瓶颈

在本节中,我们使用 Nsight Compute 分析了优化后的融合 SSD 内核的瓶颈,以检查最终的利用率、停顿模式和资源权衡。

总体而言,我们可以查看融合内核的计算和内存利用率,以了解是什么限制了该内核。

图 4. A100 Nsight Compute 摘要

图 5. H100 Nsight Compute 摘要

我们可以看到,总体融合 SSD 计算利用率约为 40-50%,内存利用率约为 65-75%。由于初始加载/存储延迟和其他开销,不可能达到 100% 的利用率,但在优化良好的内核中通常可以达到至少 80%。作为参考,Mamba2 中使用的 H100 和 A100 矩阵乘法达到了 85-96% 的计算利用率。由于 SSD 内核中的计算和内存利用率都不理想,瓶颈比单纯的内存带宽或计算吞吐量更复杂。

我们可以查看 warp 状态统计信息,以查看 warp 在停顿什么。“Selected”表示 warp 执行了一条新指令,但“Stall Long Scoreboard”和“Stall Barrier”表示 warp 正处于空闲状态,等待 L2/VRAM 或进行同步。

图 6. 融合 SSD 内核在 H100 上的 Warp 状态统计

有几种方法可以减少这些停顿的影响并提高计算或内存利用率:

  1. 提高占用率 (Occupancy)
  2. 提高指令级并行度 (Instruction-Level Parallelism)
  3. 优化代码以使用更少的同步和内存操作或更好地缓存数据

占用率

现代 NVIDIA GPU 每个 warp 调度器有 12-16 个 warp(32 个线程一组),每个调度器可以在每个周期发布一条新指令。如果我们每个调度器只有 1 个 warp,那么每次 warp 停顿时我们都会浪费周期。然而,如果我们每个调度器有 16 个 warp,那么每个 warp 即使停顿 15/16 的时间也不会使硬件处于空闲状态。占用率是实际填充了活跃 warp 的可用 warp 插槽的比例。提高占用率有助于隐藏内存和指令延迟,提高 GPU 利用率。

图 7. 融合 SSD 内核在 H100 上的占用率

在当前的配置下,该融合内核的占用率仅为 25%,受到寄存器和共享内存的限制。虽然我们可以增加 warp 数量并减少每个线程的寄存器来提高占用率,但在实践中这会降低性能,这很可能是由于同步成本增加和寄存器压力增大导致的。

指令级并行度

指令级并行度意味着设计/优化代码以减少指令间的立即依赖关系,允许 warp 在上一条指令尚未完成时运行后续指令。这提供了与提高占用率相同的隐藏延迟的好处,但不需要更多的 warp。

减少同步和数据传输

由于 warp 通常在等待加载/存储内存或屏障,我们可以通过减少屏障的数量,或通过更好的缓存或不同的分块大小减少总数据传输来提高性能。

不幸的是,这三种优化技术可能会直接冲突并引入权衡。GPU 中的每个 SM 都有有限的寄存器和共享内存,因此如果每个线程块使用过多,占用率就会下降。我们可以通过分阶段加载数据来提高指令级并行度,但这需要更多的寄存器和共享内存,导致占用率降低。我们也可以改变块大小以减少传输的总数据或提高缓存命中率,但这同样需要更多资源并降低占用率。

这就是为什么融合内核没有极高的内存或计算利用率的原因。

内存利用率详情

图 8. 融合 SSD 内核在 H100 上的内存图表

从该图表中我们可以看到,报告的 65–75% 的内存利用率大部分来自对 L2 缓存的读取。这些读取可能包括 (i) 适合 L2 的张量,(ii) 在多个线程块间重用的张量,(iii) 线程块间的状态传输,以及 (iv) 自然通过 L2 的 VRAM 读取。由于 L1 缓存是每个 SM 私有的,且在线程块间不一致,因此将此流量转移到 L1 是不可行的。同样,绕过 L2 进行 VRAM 流量也没有帮助,因为所有全局内存访问都通过 L2。

这张内存图表表明,除了次优的内存利用率外,该内核实际上受限于 L2 而非 DRAM。因此,进一步优化将需要 (1) 提高内存利用率,(2) 调整块大小/配置,或 (3) 进行彻底的算法更改。

逐行停顿

Nsight Compute 分析显示了逐行的 warp 停顿,帮助我们检查 warp 停顿是否出于合理原因。正如预期,融合内核中的大多数 warp 停顿来自于加载数据、同步和计算,原子操作和跨分块同步的开销很小。更多细节请参见附录 B。

基准测试

我们在典型的推理场景中对 Triton 内核进行了基准测试,包括 batch size 1-32、序列长度从 1K 到 256K token,以及 fp16 状态。这些图表突显了我们的内核相对于基准未融合内核的加速效果。

图 9. NVIDIA A100 融合内核加速图

图 10. NVIDIA H100 融合内核加速图

融合后的 SSD 内核在 SSD 部分比未融合实现快 1.50 倍 - 2.51 倍。在短序列长度(特别是 batch=1 时),内核启动的开销使得融合内核优势明显,但这些固定成本在更长的序列中被摊薄。在更长的序列中,融合内核较少的数据移动变得更加有利,因为缓存抖动(cache thrashing)增加。对于像 Mamba-2 2.7B 这样 batch=1 和 seq=128K 的模型,在 NVIDIA A100 和 H100 GPU 上,SSD 的加速转化为大约 8-13% 的端到端加速。在较短的序列长度下,端到端加速在 1K 上下文时可以达到 ~20%,这很可能是由于减少了内核启动开销。

准确性和正确性

融合内核总体上是准确且正确的,但融合内核和参考解决方案的输出之间存在细微差异。这些差异取决于它运行的 GPU 以及某些计算的精度。融合内核在内部对某些原始内核使用 fp32 的计算使用了 fp16,因为这带来了约 16% 的加速。此外,原始内核支持 fp32 或 fp16 状态,但我们报告的加速是针对 fp16 状态的。融合内核仍然支持相同的中间数据类型和 fp32 状态。在本节中,我们将解释这些不同数据类型配置在准确性和性能方面的权衡。

表 2 中,我们报告了输出 y 张量的准确性,即与原始内核输出匹配的元素百分比。我们使用无阈值(元素必须完全匹配)、1e-3 的绝对和相对容差(小阈值)以及 1e-2(中等阈值)进行测试。在该表中,“精确数据类型”指所有计算均使用与原始内核相同的数据类型,而“宽松数据类型”指部分计算使用 fp16。融合内核和原始内核在每列中均使用相同的状态数据类型运行。

fp32 状态

精确数据类型

fp16 状态

精确数据类型

fp32 状态

宽松数据类型

fp16 状态

宽松数据类型

匹配 @ atol,rtol=0 99.696% 99.337% 67.307% 66.823%
匹配 @ atol,rtol=1e-3 100.000% 100.000% 99.819% 99.743%
匹配 @ atol,rtol=1e-2 100.000% 100.000% 100.000% 100.000%

表 2. H100 准确性表

浮点加法不是完全结合的,因此我们不能期望输出张量的所有元素在 0 阈值下匹配。即使是不同的 Triton 启动配置也可能导致同一内核的输出存在极小差异。对于“精确数据类型”(fp16 和 fp32 状态),输出在所有实际目的上都是相同的,因此该内核在最敏感的模型中也能工作。对于“宽松数据类型”(我们在加速图中使用),我们可以看到约 1/3 的元素与原始内核的输出不完全匹配。然而,如果我们允许 1e-3 的紧密阈值,超过 99.7% 的输出元素是匹配的。此外,在常用的 atol=1e-2, rtol=1e-2 (1%) 容差下,所有配置均实现 >99.9995% 的准确性,实际上等同于 100%。在实际应用中,我们预计“宽松数据类型”具有无法区分的准确性。

图 11. H100 fp32 与 fp16 准确性图

图 11 中,我们展示了当状态为 fp32 而非 fp16 时加速比如何变化。当状态为 fp32 时,融合内核和原始内核在 chunk_size=256 时都更快。这代表了为了更小的状态张量而进行的更高计算的权衡。对于 fp32 状态,融合内核的加速比 fp16 状态要小,这很可能是因为计算和数据移动的平衡不同。

其他架构

融合后的 SSD 内核不仅限于 Mamba-2。它也直接适用于线性注意力(linear attention),因为当 A = 1 时,SSD 公式简化为线性注意力更新。在这种特殊情况下,融合内核可以进一步简化和优化以获得更好的性能。

新的 GPU 特性

融合 SSD 内核目前尚未使用诸如 Hopper GPU 上的 Tensor Memory Accelerator (TMA) 和线程块集群(thread block clusters),或者 Blackwell GPU 中的张量内存等更新的 GPU 特性。这些特性可以极大地降低寄存器压力,这将加速 SSD 并可能实现更快的 Triton 配置(例如,更大的块大小)。线程块集群对于广播加载 SSD 内核中跨一组 head 共享的 C、B 和 CB 矩阵特别有用。如果需要,这可以在新 GPU 上带来进一步的加速。

进一步融合:卷积和层归一化

在这个融合的 SSD 内核中,我们融合了 5 个原始的 SSD 内核。然而,SSD 之前的卷积和 SSD 之后的层归一化也是有吸引力的融合候选者,因为融合每一个都将消除内核之间的一次完整的读取和写入。由于卷积是深度卷积(无通道混合),SSD 可以沿 seqlen 维度加载 d_conv 额外数据,并加载卷积权重以在寄存器或共享内存中执行卷积。

我们进行了一些融合层归一化的实验,但效果有限。有两种融合层归一化的方法:

  1. 单独启动层归一化线程块。这些线程块可以等待相应的 SSD 线程块完成,然后从 L2 缓存读取输出 y,而不是 VRAM。
  2. 在 head 之间同步 SSD 线程块,交换归一化值,并在寄存器或共享内存中计算层归一化。

方法 2 非常慢,因为 SSD 线程块在同步时停滞,且在等待时没有其他工作可做。方法 1 有效,但从 L2 读取而不是 VRAM 并没有提供像寄存器/共享内存那样多的好处。到目前为止,加速远低于理论极限,考虑到增加的复杂性,目前尚不清楚进一步优化是否值得。

关于模型设计的洞察

通过五个 SSD 内核的优化融合,Mamba2 预填充现在比以前更便宜。这改变了 Mamba2 层的运行时间-准确性权衡,这可能使扩大 Mamba2 层的大小和数量成为新 LLM 中的最佳平衡。更多设计洞察包括:

  • 计算强度:目前的融合内核在最快的分块大小时计算利用率较低,因此我们或许可以承受稍复杂的运算。虽然我们可以通过增加分块大小来增加计算强度,但这也会增加所需的寄存器和其他资源,导致整体减速。
  • 状态精度:在融合内核和原始内核中,State Passing 步骤必须是串行的而不是并行的。虽然存在次线性延迟的并行扫描算法,但在实践中,它们可能比 Mamba2 中使用的串行版本慢得多。因此,最小化 State Passing 计算占总延迟的比例对于隐藏序列化延迟至关重要。如果状态可以保持在低精度(如 fp16),这将显著帮助融合内核。如果没有快速的 State Passing 步骤,我们可能需要沿其他维度(如 headdim)拆分线程块,这将整体拖慢融合内核。
  • VRAM 与 L2 权衡:由于融合内核的 L2 带宽利用率高于 VRAM 带宽利用率,在线程块之间共享更少数据的成本较低。如果一个架构的性能从较小的组中获益良多,那么增加的 VRAM 读取对性能的负面影响可能比原始内核更小。另一方面,诸如 TMA 多播加载之类的新 GPU 特性可以减少 L2 带宽利用率,从而加速 SSD 并减少这种不平衡。

vLLM 集成

为了支持带有初始状态但无需填充的变长序列,vLLM 引入了“伪分块(pseudo chunks)”的概念。任何包含多个序列 token 的分块都有多个伪分块,每个序列对应一个。5 个内核中的大多数功能相同,State Passing 在新序列开始时加载初始状态。然而,Chunk Scan 有一个更大的线程块网格,它遍历的是伪分块而不是分块。为了在融合内核中支持这一点,我们有一个 for 循环来处理当前分块中的所有伪分块。vLLM Chunk Scan 根据伪分块在实际分块中的开始位置偏移其读取和写入。我们改为使用基于序列索引的掩码,因为掩码提供了加速。偏移和掩码在运行时读写相同数量的数据,但掩码对于编译器来说可能更可预测、对齐更好或只是更简单。vLLM 融合内核仍在集成中,但显示出了类似的加速效果。

结论

总之,我们将 Mamba-2 SSD 预填充的五个 Triton 内核融合为一个,SSD 本身获得了 2 倍的加速,转化为 ~8–20% 的端到端推理加速。这显著提高了使用 Mamba-2 层模型的吞吐量。我们很高兴将这些内核改进集成到开源项目中,以便社区可以轻松地利用 Mamba-2 模型进行更快的推理。随着该融合 SSD 内核进入 Mamba 代码库和 vLLM 等推理框架,请关注后续更新。

附录 A:优化详情

线程块顺序

State Passing 步骤导致序列化。对于给定的 head,除一个线程块外,所有线程块都停滞等待前一个分块准备就绪。当我们的 GPU 并发运行约 256-1024 个线程块但只有一个取得进展时,我们会获得显著的减速。部分序列化被 Chunk State 步骤的延迟所隐藏,因为后续分块可能仍在计算 Chunk State 而不是在 State Passing 中停滞,但这还不够。我们在 SSD 中既有代表领域并行性(独立工作)的 nheads 维度,也有 batch 维度。与其在进入下一个之前为一个特定的 batch 和 head 启动线程块,不如为多个 (batch, head) 组合启动线程块。如果我们为同一个分块启动 n 个不同的 (batch, head) 组合,然后再进入下一个分块,我们的序列化将以 n 倍数下降(不再只有一个线程块取得进展,而是 n 个线程块取得进展)。这个 n 必须仔细平衡,因为如果太大,我们就会失去传递状态的 L2 缓存局部性,如果太小,线程块就会停滞。作为一个简单的启发式方法,我们在进入下一个分块之前启动所有 nheads 的线程块,但在 batch 维度上前进之前完成所有分块。对于具有更多或更少 head 或显著不同维度的模型,更复杂的线程块顺序可能涉及显式结合 nheads 和 batch,然后将其拆分为内部和外部维度,内部维度在下一个分块之前启动。

缓存提示

诸如 Mamba2 SSD 等操作的输入和输出张量通常太大,无法放入缓存。例如,对于 128 个 head、64 维、fp16 的 16k 上下文 Mamba2 SSD,输入和输出各消耗 16k * 128 * 64 * 2B = 256 MiB。典型的 GPU L2 缓存为 40-50 MiB。因此,在内核运行期间,部分数据将从 L2 缓存中驱逐。

由于大部分输出张量不适合 L2 缓存,因此不值得使用 L2 缓存容量来存储输出以尝试加速下一次操作。我们可以使用缓存提示(cache hint)来指示输出张量具有最低的缓存优先级。通常,一旦我们在内核中最后一次访问数据,我们就可以将其标记为缓存的低优先级。对于频繁重用的数据,如 CB(在组内的 head 间共享),我们可以使用高优先级缓存提示来减少被驱逐的几率。

我们还可以通过指定“发布(release)”语义来避免在某些同步原子操作期间刷新 L1 缓存。这告诉编译器,之前写入的数据必须在原子操作之前全局可见(例如,如果我们正在设置一个“就绪”标志),但此线程不需要使任何缓存失效。

条件分离

在 State Passing 步骤中,我们有两个特殊情况:读取初始状态而不是前一个分块的全局状态,以及写入最终状态而不是写入全局状态张量。虽然在概念上这些特殊情况应该只涉及交换读/写的基本指针,但初始和最终状态条件增加了寄存器压力并降低了融合内核的速度。为了解决这个问题,我们可以在融合 SSD 内核之外处理这些特殊情况。如果我们用 nchunks + 1 替换状态张量中的 nchunks 维度,我们可以将初始状态复制到第 0 个分块,并将最终状态从最后一个分块中复制出来。这些复制使用 PyTorch 的切片赋值语法完成,这导致了运行时或启动开销可忽略不计的小内核。

中间数据类型

对于某些计算,例如在 Chunk Scan 中将 A 衰减应用于 B,我们可以在计算中使用 fp16 而不是 fp32。这也通过仅对比例进行向下转换来替换向上转换 B 和向下转换结果的过程,从而减少了转换指令。

编译时掩码

Triton 要求线程块中张量块的维度是编译时已知的 2 的幂。这迫使所有存储和加载操作都在 2 的幂块上运行,这些块可能无法精确整除目标张量。因此,我们使用掩码覆盖整个张量,但避免读取或写入越界数据(或下一个数据块)。这些掩码与张量块的维度相同。然而,这些掩码并不总是必要的,因为像 headdim 这样的模型维度通常可以被块大小整除,并且在不同输入之间不会改变。Triton 支持 tl.constexpr 编译时参数,并通过 @triton.heuristics 基于其他参数设置它们。因此,我们可以在运行时根据 headdim 是否能被块大小整除,自动启用或禁用掩码的 headdim 维度。虽然这发生在“运行时”,但实际上它只在该模型的初始 JIT 编译期间发生一次。

分块大小

Mamba2 SSD 算法每个 token 的计算复杂度渐近为常数(计算随序列长度线性扩展),但它有一个以二次方计算的分块大小基准情况。在分块之间,使用线性算法,但在分块内,使用二次算法。更多细节,请参见 https://tridao.me/blog/2024/mamba2-part1-model/#state-space-duality

最佳分块大小代表了所需的更高计算和资源与更高的硬件利用率和更少的中间状态之间的权衡。在原始未融合内核中,Mamba2-2.7B 的最佳分块大小为 256。然而,使用新的融合内核,同一模型的最佳分块大小现在为 128。这个较小的分块大小还有一个好处,即减少了寄存器压力,使内核对启用掩码或对中间结果使用更高精度等微小变化不太敏感。

目前,Mamba2 模型的惯例是在模型的配置中指定分块大小。然而,由于最佳分块大小因原始内核与融合内核而异,使用启发式方法或自动调整分块大小可能更好。这可能并不简单,因为 SSD 内核周围的代码可能假定特定的分块大小。

比例乘法操作数

对于 Chunk State,我们可以等效地将 A 衰减应用于 X 而不是 B,因为要缩放的维度是 X 和 B 矩阵乘法的内部维度。本质上,我们做 (X * A[None, :]) @ B 而不是 (X @ (A[:, None] * B))。这更快,可能是因为更相似的布局导致较少的寄存器数据移动。例如,由于所需的 Tensor Core 数据布局,每个线程可能已经拥有了与其 X 值相乘所需的 A 值,但为了缩放 B,我们可能不得不以不同的布局加载并重新调整数据回所需的 Tensor Core 布局。

附录 B:停顿原因摘要

如果我们查看 NVIDIA Nsight Compute 中的源代码,可以看到 H100 上融合内核中每一行代码和汇编指令的 warp 停顿。假设内核和块大小是最优的,warp 停顿可以揭示潜在的优化领域。

  1. 为了确保正确性,我们使用原子加法按递增顺序获取线程块 id。这占总 warp 停顿的约 3%。
  2. 融合内核的 Chunk Cumsum 和 BMM 部分都非常快,因此它们各自导致的 warp 停顿不到 2%。
  3. 原子检查 Chunk Cumsum 和 BMM 线程块是否已为该 Chunk State 线程块准备好数据,占 warp 停顿的约 1.5%。
  4. Chunk State 在加载 dA、X,尤其是 B 时,占总 warp 停顿的约 12%。它在与缩放和使用 Tensor Core 相关的屏障中也有约 7% 的停顿。
  5. 尽管沿分块序列化,State Passing 在同步(包括等待前一个分块)上的停顿不到 3%。加载前一个状态不会导致明显的停顿,但更新状态和存储会导致约 6% 的停顿在等待共享内存或屏障。
  6. 对于 Chunk Scan 中前一个状态的贡献,加载 C 约占 5% 的加载停顿,prev_states 约占 3% 的屏障停顿,计算约占 8% 的屏障、加载(用于缩放)和指令依赖停顿。
  7. 当前分块在 Chunk Scan 中的贡献在加载数据时有约 13% 的停顿,在计算(包括缩放)时有约 18% 的停顿。
  8. 残差(由 D 缩放)占加载、共享内存和计算总停顿的约 6%。

总体而言,这些停顿是有合理原因的,不容易通过优化消除。