跳转到主要内容
博客

使用局部感知内核设计加速 MoE 模型推理

1.0 概述

我们展示了通过实施列主序调度以改善数据局部性,我们可以将MoE(专家混合)的核心Triton GEMM(通用矩阵乘法)内核在A100上加速高达4倍,在H100 Nvidia GPU上加速高达4.4倍。本文演示了MoE GEMM的几种不同的工作分解和调度算法,并从硬件层面解释了为什么列主序调度能产生最高的加速。

仓库和代码可在以下地址获取: https://github.com/pytorch-labs/applied-ai/tree/main/kernels/triton/inference/col_major_moe_gemm

Figure 1A. Optimized Fused MoE GEMM Kernel TFLOPs on A100 for varying Batch Sizes M

图1A. 在A100上针对不同批次大小M的优化融合MoE GEMM内核TFLOPs

Figure 1B. Optimized Fused MoE GEMM Kernel TFLOPs on H100 for varying Batch Sizes M

图1B. 在H100上针对不同批次大小M的优化融合MoE GEMM内核TFLOPs

2.0 背景

OpenAI的Triton是一种与硬件无关的语言和编译器,正如我们之前的博客文章所示,它可用于加速量化工作流。我们还表明,在内核开发方面,许多与CUDA相同的学习和性能分析工具可以被利用,以提供关于Triton内核如何在幕后工作以及在延迟敏感环境中加速这些内核的后续措施的相似见解。随着Triton在生产设置中日益普及,开发人员了解开发高性能内核的常见技巧和窍门,以及这些方法对各种不同架构和工作流的通用性至关重要。因此,本文将探讨我们如何使用经典技术优化了vLLM为流行的专家混合(MoE)Mixtral模型开发的Triton内核,以及如何在Triton中实现这些技术以获得性能提升。

Mixtral 8x7B是一种稀疏专家混合语言模型。与经典的密集Transformer架构不同,每个Transformer块包含8个MLP层,其中每个MLP都是一个“专家”。当一个token流经时,一个路由器网络会选择8个专家中的2个来处理该token,然后将结果组合。同一token在每一层选择的专家都不同。因此,尽管Mixtral 8x7B共有47B参数,但在推理过程中只有13B参数是活跃的。

MoE GEMM(通用矩阵乘法)内核接收一个包含所有专家的堆叠权重矩阵,然后必须利用路由器网络产生的结果分数对应的映射数组,将每个token路由到TopK(Mixtral为2)个专家。在本文中,我们提供了在推理时间,特别是在自回归(或解码阶段)期间,高效并行化此计算的方法。

3.0 工作分解 – SplitK

我们之前已经证明,对于LLM推理中出现的矩阵问题大小,特别是在W4A16量化推理的背景下,可以通过应用SplitK工作分解来加速GEMM内核。因此,我们通过在vLLM MoE内核中实现SplitK,开始了MoE加速研究,这比数据并行方法带来了大约18-20%的加速。

这个结果表明,SplitK优化可以作为一种更公式化的方法来改进/开发推理设置中的Triton内核。为了对这些不同的工作分解建立直观理解,我们来看一个简单的例子,针对两个4×4矩阵的乘法,且SplitK=2。

在下面显示的数据并行GEMM内核中,输出矩阵的单个块的计算将由1个线程块TB0处理。

Figure 2. Data Parallel GEMM

图2. 数据并行GEMM

相比之下,在SplitK内核中,计算输出矩阵中1个块所需的工作在2个线程块TB0和TB1之间“分割”或共享。这提供了更好的负载均衡和更高的并行度。

Figure 3. SplitK GEMM

图3. SplitK GEMM

关键思想是我们已经将并行度从 MN 增加到 MN*SplitK。这种方法确实会产生一些成本,例如通过原子操作增加线程块间通信。然而,与节省其他受限GPU资源(如共享内存和寄存器)相比,这些成本微乎其微。最重要的是,SplitK策略为瘦矩阵(如MoE推理中的情况)提供了卓越的负载均衡特性,并且是解码和推理过程中常见的矩阵配置文件。

4.0 GEMM硬件调度——列主序

为了在SplitK约20%的加速基础上进一步提升,我们将研究重点放在控制Triton内核中GEMM硬件调度的逻辑上。我们对vLLM MoE内核的分析显示L2缓存命中率较低,因此我们研究了三种调度选项——列主序、行主序和分组启动。由于MoE模型的一些内在特性,例如大型专家矩阵以及在内核执行期间需要动态加载TopK(Mixtral为2)个矩阵,缓存重用/命中率成为这种优化要解决的瓶颈。

作为背景,在我们之前的博客中,我们谈到了“瓦片置换(tile swizzling)”的概念,这是一种实现更高L2缓存命中率的方法。这个概念与软件如何将GEMM调度到GPU的SMs有关。在Triton中,这个调度由pid_m和pid_n计算决定。我们的关键见解是,对于瘦矩阵乘法,列主序确保了权重矩阵B的列的最佳重用。为了说明这一点,让我们来看一下列主序计算pid_m和pid_n的代码片段:

Figure 4. Column Major ordering in PyTorch

图4. PyTorch中的列主序

从上面我们可以看出,通过这种映射,我们调度GEMM,使其按照以下顺序计算C的输出块:C(0, 0), C(1, 0), C(2, 0),…等等。为了理解其含义,我们提供了以下插图:

Activation matrix / Weight matrix
L1/L2 Cache
C - Output Matrix

图5. 列主序GEMM调度下的缓存重用模式

在上述列主序调度的简化视图中,假设对于具有瘦激活矩阵A的GEMM,整个矩阵可以放入GPU缓存中,这对于我们在MoE推理中遇到的问题规模是一个合理的假设。这使得权重矩阵B的列能够最大程度地重用,因为B的列可以用于相应的输出瓦片计算,即C(0,0)、C(1,0)和C(2,0)。相比之下,考虑行主序调度,即C(0,0)、C(0,1)、C(0,2)等。我们将不得不逐出B的列,并发出多个加载指令到DRAM以计算相同数量的输出块。

优化内核时的一个重要设计考虑是能产生最少全局加载指令的内存访问模式。这种最优内存访问模式是通过列主序调度实现的。以下结果展示了我们研究的三种调度的性能

Figure 6. Comparison of GEMM Schedules on A100 for varying Batch Sizes M

图6. 在A100上针对不同批次大小M的GEMM调度比较

列主序调度比其他模式提供了高达4倍的加速,正如我们将在下一节中展示的那样,由于数据局部性的大幅改进,它提供了最优的内存访问模式。

5.0 Nsight Compute 分析——吞吐量和内存访问模式

为了进行性能分析,我们重点关注H100在M = 2的情况。对于A100也可以进行类似的研究,因为许多相同的观察结果也适用于A100。我们注意到以下显著结果,展示了我们优化的影响。

Figure 7. H100 Memory Throughput Chart for M = 2.  Note the very large increase in the cache hit rates L1 cache hit rate (+2696%) and L2 cache hit rate (+254%).

图7. H100内存吞吐量图,M = 2。请注意L1缓存命中率(+2696%)和L2缓存命中率(+254%)的显著增加。

Figure 8. H100 Memory Instruction Statistics M = 2. Note the 49% reduction in global memory loads.

图8. H100内存指令统计 M = 2。请注意全局内存加载减少了49%。

这些统计数据表明,我们的优化达到了预期效果,这可以通过减少的缓存未命中、减少的内存访问以及由此产生的2.7倍加速来体现。更具体地说,跟踪显示L2命中率增加了2.54倍(图7),DRAM访问减少了约50%(图8)。

这些改进最终带来了延迟的降低,优化后的内核对于bs=2时速度提高了2.7倍,对于bs=512时速度提高了4.4倍。

6.0 未来工作

我们的内核在FP16中进行了测试,这展示了MoE列主序调度的数值和性能,但大多数生产模型都使用BFloat16。我们在Triton中遇到了一个限制,即tl.atomic_add不支持Bfloat16,并且遇到了启动延迟问题,这需要CUDA图支持才能用于列主序生产。在初始测试中,这转换为端到端70%的加速,但我们在端到端环境中遇到了一些专家映射不一致的问题,这在测试环境中没有体现,因此需要进一步的工作才能完全实现这些加速。\

对于未来的工作,我们打算将其转移到CUDA内核中,这将确保完全支持BFloat16并相对于Triton降低启动延迟,并可能解决专家路由不一致问题。我们之前也发布了工作,介绍了如何使用Triton GEMM内核启用GPTQ W4A16,因此后续的自然工作将包括将解量化融合到此内核中,以实现GPTQ量化推理路径。

7.0 可复现性

我们已经开源了Triton内核代码以及一个易于运行的性能基准测试,供有兴趣在自己的GPU上比较或验证性能的读者使用。

致谢

我们要感谢 Daniel Han、Raghu Ganti、Mudhakar Srivatsa、Bert Maher、Gregory Chanan、Eli Uriegas 和 Geeta Chauhan 对所呈材料的审阅,以及 vLLM 团队的 Woosuk,我们基于他的 Fused MoE 内核实现进行了构建。