1.0 摘要
我们展示了,通过实施列主序调度来改善数据局部性,我们可以将针对 MoE(专家混合模型)的核心 Triton GEMM(通用矩阵乘法)内核在 A100 上加速高达 4 倍,在 H100 英伟达 GPU 上加速高达 4.4 倍。本文展示了几种不同的 MoE GEMM 工作分解和调度算法,并在硬件层面揭示了为何列主序调度能带来最高的速度提升。
代码库和代码可在此处获取:https://github.com/pytorch-labs/applied-ai/tree/main/kernels/triton/inference/col_major_moe_gemm。
图 1A. 不同批量大小 M 下 A100 上的优化融合 MoE GEMM 内核 TFLOPs 性能
图 1B. 不同批量大小 M 下 H100 上的优化融合 MoE GEMM 内核 TFLOPs 性能
2.0 背景
OpenAI 的 Triton 是一种硬件无关的语言和编译器,正如我们之前的博客文章所示,可用于加速量化工作流。我们还表明,在内核开发方面,CUDA 的许多相同学习经验和性能分析工具都可以用来深入了解 Triton 内核的底层工作原理以及在延迟敏感环境中加速这些内核的后续措施。随着 Triton 在生产环境中得到越来越多的采用,开发者理解开发高性能内核的常用技巧和诀窍,以及这些方法对各种不同架构和工作流的普适性非常重要。因此,本文将探讨我们如何使用经典技术优化 vLLM 为流行的专家混合模型 (MoE) Mixtral 开发的 Triton 内核,以及如何在 Triton 中实现这些技术以获得性能提升。
Mixtral 8x7B 是一种稀疏的专家混合语言模型。与经典的密集 Transformer 架构不同,每个 Transformer 块包含 8 个 MLP 层,其中每个 MLP 都是一个“专家”。当一个 token 流经模型时,一个路由网络会选择 8 个专家中的哪 2 个来处理该 token,然后将结果进行组合。对于同一个 token,不同层选择的专家可能会不同。因此,虽然 Mixtral 8x7B 总共有 470 亿个参数,但在推理过程中只有 130 亿个参数是活跃的。
MoE GEMM(通用矩阵乘法)内核接收包含所有专家的堆叠权重矩阵,然后必须利用路由网络结果分数生成的映射数组,将每个 token 路由到 TopK(Mixtral 为 2 个)专家。在本文中,我们提供了在推理时,特别是自回归(或解码阶段)期间,有效并行化此计算的方法。
3.0 工作分解 - SplitK
我们之前已经表明,对于 LLM 推理中遇到的矩阵问题规模,特别是在 W4A16 量化推理的背景下,可以通过应用SplitK 工作分解来加速 GEMM 内核。因此,我们通过在vLLM MoE 内核中实现 SplitK,开始了我们的 MoE 加速研究,这比数据并行方法提高了约 18-20% 的速度。
这一结果表明,SplitK 优化可以作为在推理设置中改进/开发 Triton 内核的一种更规范方法的一部分。为了帮助理解这些不同的工作分解,我们来看一个简单的 4x4 矩阵乘法示例,其中 SplitK=2。
在下图所示的数据并行 GEMM 内核中,输出矩阵的一个块的计算将由 1 个线程块 TB0 处理。
图 2. 数据并行 GEMM
相比之下,在 SplitK 内核中,计算输出矩阵中 1 个块所需的工作被“分割”或由 2 个线程块 TB0 和 TB1 共享。这提供了更好的负载均衡和更高的并行性。
图 3. SplitK GEMM
关键思想是我们将并行性从 MN 增加到 MN*SplitK。这种方法确实会产生一些开销,例如通过原子操作增加线程块间的通信。然而,与共享内存和寄存器等其他受限 GPU 资源的节省相比,这些开销微不足道。最重要的是,SplitK 策略为瘦矩阵(如 MoE 推理中的情况)提供了卓越的负载均衡特性,并且是解码和推理过程中的常见矩阵形态。
4.0 GEMM 硬件调度 - 列主序
为了在 SplitK 带来的约 20% 提速基础上进一步改进,我们将研究重点放在控制 Triton 内核中 GEMM 硬件调度的逻辑上。我们对 vLLM MoE 内核的性能分析显示其 L2 缓存命中率较低,因此我们研究了三种调度选项:列主序、行主序和分组启动。由于 MoE 模型的一些固有特性,例如庞大的专家矩阵,以及在内核执行期间需要动态加载 TopK(Mixtral 为 2 个)矩阵,缓存重用/命中率成为瓶颈,这正是本次优化要解决的问题。
作为背景知识,在我们的上一篇博客文章中,我们探讨了“tile swizzling”的概念,这是一种提高 L2 缓存命中率的方法。此概念与软件如何将 GEMM 调度到 GPU 的 SMs(流式多处理器)有关。在 Triton 中,此调度由 pid_m 和 pid_n 计算决定。我们的关键洞察是,对于瘦矩阵乘法,列主序可以确保最优地重用权重矩阵 B 的列。为了说明这一点,我们来看一个列主序计算 pid_m 和 pid_n 的片段:
图 4. PyTorch 中的列主序
从上图可以看出,通过这种映射,我们调度 GEMM,使得我们按照以下顺序计算输出块 C:C(0, 0), C(1, 0), C(2, 0),... 等等。为了理解其影响,我们提供了以下图示:
图 5. 列主序 GEMM 调度的缓存重用模式
在上面简化的列主序调度视图中,假设对于一个激活矩阵 A 为瘦矩阵的 GEMM,整个矩阵可以放入 GPU 缓存中,这对于我们在 MoE 推理中遇到的问题规模来说是合理的假设。这使得权重矩阵 B 的列能够被最大程度地重用,因为 B 的同一列可以被 C(0,0), C(1, 0) 和 C(2, 0) 等相应的输出块计算重复使用。反之,考虑行主序调度,计算顺序为 C(0,0), C(0,1), C(0, 2) 等。我们将不得不驱逐 B 的列,并向 DRAM 发出多次加载指令来计算相同数量的输出块。
优化内核时的一个重要设计考虑因素是采用能够产生最少全局加载指令的内存访问模式。这种最优内存访问模式可以通过列主序调度实现。以下结果展示了我们研究的三种调度的性能:
图 6. 不同批量大小 M 下 A100 上 GEMM 调度的比较
列主序调度比其他模式提供了高达 4 倍的加速,正如我们将在下一节中展示的,由于数据局部性的大幅改善,它提供了最优的内存访问模式。
5.0 Nsight Compute 分析 - 吞吐量和内存访问模式
为了进行性能分析,我们重点关注 H100 上 M = 2 的情况。A100 也可以进行类似的研究,因为许多相同的观察结果同样适用。我们注意到以下显著结果,展示了我们优化的影响。
图 7. M = 2 时 H100 内存吞吐量图表。注意缓存命中率的大幅提升:L1 缓存命中率 (+2696%) 和 L2 缓存命中率 (+254%)。
图 8. M = 2 时 H100 内存指令统计。注意全局内存加载减少了 49%。
这些统计数据表明,我们的优化达到了预期效果,这体现在缓存未命中减少、内存访问减少以及随之而来的 2.7 倍加速。更具体地说,跟踪显示 L2 命中率提高了 2.54 倍(图 7),DRAM 访问减少了约 50%(图 8)。
这些改进最终带来了延迟的降低,优化后的内核在 bs=2 时速度提高了 2.7 倍,在 bs=512 时速度提高了 4.4 倍。
6.0 未来工作
我们的内核在 FP16 中进行了测试,展示了 MoE 列主序调度的数值精度和性能,但大多数生产模型使用 BFloat16。我们在 Triton 中遇到了 tl.atomic_add 不支持 BFloat16 的限制,并且遇到了启动延迟问题,这需要 cuda graph 支持才能在生产环境中使用列主序。在初步测试中,这转化为端到端加速 70%,但是我们在端到端环境中遇到了一些专家映射不一致的问题,这些问题在测试环境中没有反映出来,因此需要进一步工作才能完全实现这些加速。\
对于未来工作,我们打算将其移植到 CUDA 内核中,以确保完全支持 BFloat16 并降低相对于 Triton 的启动延迟,并可能解决专家路由不一致问题。我们之前也发表过关于使用 Triton GEMM 内核支持 GPTQ W4A16 的工作,因此自然的后续工作将包括将去量化融合到此内核中,以实现 GPTQ 量化推理路径。
7.0 可复现性
我们已经开源了 Triton 内核代码以及一个易于运行的性能基准测试,供有兴趣在自己的 GPU 上比较或验证性能的读者使用。
致谢
我们感谢 Daniel Han、Raghu Ganti、Mudhakar Srivatsa、Bert Maher、Gregory Chanan、Eli Uriegas 和 Geeta Chauhan 对所呈现材料的审阅,以及 vLLM 团队的 Woosuk,我们在他实现的融合 MoE 内核的基础上进行了开发。