混合专家模型(MoE)是大型语言模型(LLM)中一种主流的模型架构。尽管它通过在每个 token 上激活较少数量的参数来降低训练和推理时的计算量,但在应对高内存带宽和通信压力,以及处理模型动态性和稀疏性所带来的复杂性时,仍面临实现最佳计算效率的挑战。为此,我们推出了一种全新的 MoE 推理解决方案 —— MetaShuffling,它使我们能够高效地部署 Llama 4 模型并进行生产环境推理。
Llama 4 架构

Llama 4 Scout 和 Maverick 模型现已正式发布。Scout / Maverick 模型包含一个共享专家(shared expert)以及 16 / 128 个路由专家(routed experts),每个 MoE 层均采用无丢弃 token 选择路由(dropless token-choice routing)和 Top-1 选择机制。此外,共享专家和路由专家均使用带有 3 层线性层的 SwiGLU 激活函数。有关该模型的更多信息,请参考 Llama 4 家族:原生多模态 AI 创新的新时代开端。
核心概念
针对 MoE 层引入的动态性和稀疏性问题,目前有多种常见的解决方案。在此,我们展示几种针对 Top-1 选择的 token 选择路由的不同方案。

上图展示了填充(Padding)设计的逻辑。每个方框代表一个 token,黄色/绿色代表分配给不同路由专家的有效 token,灰色代表填充 token。第二步中的每一行方框代表不同的路由专家。Ti 表示来自数据并行组当前 rank 的第 i 个 token。
- 填充(Padding):在这种方法中,我们将激活值填充到每个专家的最大序列长度,并执行单次批量矩阵乘法(BMM)。它导致:
- 内存占用增加,用于存储填充数据。
- 处理填充数据导致延迟增加。虽然可以通过锯齿核(jagged kernels)避免处理填充,但在专家数量较多时,锯齿核也可能带来较高的开销。

- 切片(Slicing):在这种方法中,我们将激活值切片为每个专家对应的精确序列长度,并运行多个矩阵乘法(MM)。它避免了填充带来的问题,但导致:
- 核函数(kernel)效率降低,这是由于在小形状(small shapes)上重复启动核函数所致。
- 设备利用率降低,这是由动态形状导致的主机与设备频繁同步,以及额外的核函数启动开销造成的,且它与图形捕获机制(如 CUDAGraph 和 torch.compile)不兼容。

- 拼接(Concatenation):在这种方法中,我们在切片后进一步拼接激活值,并运行单次分组矩阵乘法(GMM)。它避免了切片中的核函数效率问题,但仍然导致:
- 设备利用率降低,因为它仍需进行主机与设备同步,且同样不兼容图形捕获机制。
为了进一步优化方案,我们提出了一种基于洗牌(shuffling)的机制


- 洗牌(Shuffling):在这种方法中,我们直接对 token 进行排序,使路由后的 token 按路由专家 ID 排列。通过这种方式,无需进行填充或拆分,且分配给相同专家的 token 被存储在一起,可以在 GroupedGEMM 中一起处理。它提供了一个稠密(dense)模型接口,并避免了上述所有问题。
- 没有填充,因为激活值保持为稠密张量。
- 无需主机与设备同步,因为激活值保持为静态形状张量。
基于此设计,我们构建了一个端到端的 MoE 推理解决方案:MetaShuffling。
运行时设计
单 GPU 推理的非并行模式

上图是无模型并行下单 GPU 推理的总体运行时设计。注意,为了优化性能,SwiGLU 激活的第一层和第三层线性层被合并为 GroupedGEMM13 / GEMM13。
- 深蓝色/橙色实心框表示路由专家/共享专家流上繁重的 Tensor Core 核函数。
- 浅蓝色/橙色实心框表示路由专家/共享专家流上繁重的 CUDA Core 或内存流量密集型核函数。
- 红色箭头表示激活张量的数据流。
- 绿色箭头表示元数据张量的数据流。
所有元数据张量均放置在设备上。不存在阻塞性的设备到主机同步。所有核函数均连续启动,无空闲等待(bubbles)。图示仅展示数据流向,并非实际性能分析追踪的模拟。
核函数接口与数据流
- RoutingScores(路由分数):处理路由分数计算的函数或融合核函数。
- 输入:input_tokens: [T, D](T:token 数量;D:特征维度);router_weights: [D, E](E:专家数量);router_biases: [E];
- 输出:routing_scores: [T, E];scaling_factors: [T, E];
- IndexShuffling(索引洗牌):处理索引洗牌和排序的融合核函数。我们将在“核函数设计”章节介绍其优化实现。
- 输入:routing_scores: [T, E];K(Top-k 路由的阈值);
- 输出:routed_token_indices: [K * T];routed_expert_indices: [K * T];routed_token_counts_per_expert: [E];
- GatherMul(收集乘法):基于排序后的索引对 token 进行洗牌并缩放的融合核函数。
- 输入:input_tokens: [T, D];routed_token_indices: [K * T];routed_expert_indices: [K * T];scaling_factors: [T, E];
- 输出:scaled_routed_tokens: [K * T, D]
- GroupedGEMM:一个经过优化的 GroupedGEMM 核函数,可无限制地处理 M 维度上关于批次的设备内形状信息。我们将在“核函数设计”章节介绍其优化实现。
- 输入:tokens: [K * T, D];weights: [E, D, HD](HD:隐藏维度);routed_token_counts_per_expert: [E];
- 输出:tokens: [K * T, HD]
- GEMM:经过优化的 GEMM 核函数。接口与稠密模型相似。
- NonLinearity(非线性):处理非线性运算的融合核函数。接口与稠密模型相似。
- ScatterAdd(分散加法):经过优化的核函数,它基于排序后的索引反转 token 洗牌,并直接将结果累加到共享专家输出中,无需物化一个未经洗牌的张量。
- 输入:shared_output_tokens: [T, D];routed_output_tokens: [K * T, D];routed_token_indices: [K * T];
- 输出:combined_output_tokens: [T, D]
注意:如果应用了量化,激活量化核函数会融合到前面的非 GEMM 核函数中。这意味着在 GroupedGEMM13 中融合到 GatherMul,在 GroupedGEMM2 中融合到 NonLinearity 等。
注意:如果使用较大的 K * T,GatherMul 和 ScatterAdd 操作可以进一步融合到后续/前置的 GroupedGEMM 操作中。这应当在序言/尾声中完整实现为全局内存到共享内存/寄存器,或共享内存到全局内存的步骤。然而,这在核函数设计层面为与 Tensor Core 执行重叠带来了额外的挑战。此外,融合 ScatterAdd 需要共享专家先于路由专家完成,如果这些核函数被用于隐藏 AlltoAll 的延迟,这可能不是一个好的设计选择。
单机推理的张量并行

上图是具有张量并行(TP)能力的单机推理的总体运行时设计。与单 GPU 推理相比,额外的步骤是:
- 浅薄荷绿实心框表示网络流量密集型通信核函数。
同样,所有元数据张量均放置在设备上,没有设备到主机的同步。所有核函数连续启动,无空闲等待。图示仅展示数据流向,并非实际性能分析追踪的模拟。
工作负载分片与额外核函数
与单 GPU 推理用例相比,未引入额外的自定义核函数。对于 GEMM、GroupedGEMM 和非线性核函数,激活值和权重均沿不同维度分片至 1/TP,计算/内存开销也被分摊至 1/TP。
如果仅应用张量并行,最后一步应为 AllReduce。如果同时应用张量并行与序列并行,则为 ReduceScatter。
多机推理的专家并行
为了实现专家并行(EP),我们将数据并行维度从路由专家中交换出来,作为路由专家内部的专家并行维度。注意,张量并行可以进一步与专家并行交换,以在增加路由不平衡风险的情况下获得更好的 GEMM 效率,但本博客不涉及此设计。
如果通过 token 选择路由启用了专家并行,则必须决定使用稠密张量还是静态形状,因为路由到不同专家组的 token 数量是动态的。
- 当倾向于使用 eager 模式时,我们使用稠密张量和动态形状,以避免运行未填充(unpadded)的 AlltoAll 导致的浪费网络流量和内存空间。
- 当倾向于使用图模式(graph mode)时,我们使用稀疏张量和静态形状,以避免运行 CUDAGraph 时因 CPU 启动开销和设备到主机同步而产生的 GPU 空闲等待。
注意:使用填充激活值所浪费的网络流量也可以通过自定义 AlltoAll 实现来避免,但本博客不涉及任何关于自定义通信或通信与计算融合核函数的主题。

上图是具有张量并行和专家并行能力的跨机推理的总体运行时设计。与具有张量并行能力的单机推理相比。
- 红色实心箭头表示节点内通信。
- 紫色实心箭头表示跨节点通信。
核函数接口与数据流
对于新增的基于专家并行的通信,我们使用 3 次 All2All 通信来交换形状和 token。
- 第一次 A2A:交换关于路由到每个专家的 token 数量的设备内元数据张量,即 IndexShuffling 核函数生成的输出 `routed_token_counts_per_expert: [E]`。
- 第二次 A2A:从数据并行基础交换 token 到专家并行基础,根据路由调度到不同的 EP rank。
- 第三次 A2A:从专家并行基础交换 token 到数据并行基础,根据路由从不同的 EP rank 汇总。
此外,我们添加了 2 个额外的洗牌核函数和 1 个特殊的散列核函数。
- CombineShuffling(稠密或填充):将接收到的 token 从“按 rank 排序”重新洗牌为“按专家排序”。后续的 T* 表示从所有对等点接收到的 token 总数,它可以被进一步解释为一个具有来自 routed_token_counts_per_rank_per_expert 张量的形状信息的锯齿维度。
- 输入:received_tokens: [T*, D](先按 DP rank 排序,再按专家索引排序);routed_token_counts_per_rank_per_expert: [EP, E // EP];
- 输出:reshuffled_tokens: [T*, D](先按专家索引排序,再按 DP rank 排序);routed_token_counts_per_expert: [E // EP];
- SplitShuffling(稠密或填充):CombineShuffling 的反向过程。将要发送的 token 从“按专家排序”重新洗牌为“按 rank 排序”。
- 输入:reshuffuled_tokens: [T*, D](先按专家索引排序,再按 DP rank 排序);routed_token_counts_per_rank_per_expert: [EP, E // EP];
- 输出:to_send_tokens: [T*, D](先按 DP rank 排序,再按专家索引排序);
- ScatterAdd(填充):从填充张量中将验证 token 进行分散加法运算。
- 输入:shared_output_tokens: [T, D];received_padded_routed_output_tokens: [EP, K*T, D];routed_token_indices: [K * T];routed_token_counts_per_expert: [E];
- 输出:combined_output_tokens: [T, D]
我们将在“图模式下的静态形状填充通信”章节详细演示上述核函数。
Eager 模式下动态形状的非填充通信


运行时行为的高层逻辑图。不同组件的实际运行时可能会因软件和硬件而异。
最小化动态形状的使用
由于路由在每个 MoE 层都是动态的,所需的最小设备/主机同步频率是每层一次。为了实现这一点,我们推迟了 `send_sizes` 的 D2H 拷贝,并将其与 `recv_sizes` 拼接,以便通过单次 D2H 拷贝进行传输。这减少了每层所需的设备/主机同步次数。
最小化对动态形状的负面影响
为了进一步隐藏设备/主机同步开销,我们将共享专家拆分为两个部分。
- 我们在路由后、派遣 A2A 之前分发第一部分。这样,当设备/主机同步发生时,设备仍处于运行共享专家的繁忙状态。
- 我们在 MoE 之后、合并 A2A 之前分发第二部分。这将进一步有助于重叠第二个 A2A 的开销。
图模式下的静态形状填充通信

最小化填充的使用
通过无丢弃 token 选择设计,路由到任何单一专家的最大 token 数为 T。然而,如果我们将多个专家组合在一起,并通过专家并行分片将它们放置在单个 GPU 上,对于 TopK 路由:
- 路由到 1 个专家的最大 token 数为 T。
- 路由到 2 个专家的最大 token 数为 2 * T。
- …
- 路由到 K 个专家的最大 token 数为 K * T。
- 路由到 K + 1 个专家的最大 token 数仍为 K * T。
- …
因此,路由到由 N 个专家组成的专家组的最大 token 数将被限制为 min(N, K) * T。
对于 Top1 路由,无论专家组大小,路由到专家组的 token 数始终被限制为 T,并且动态 token 分配所需的最少内存是 EP * T,因为共有 EP 个专家组。
为了实现最小填充,我们直接使用 AllGather 从不同的 EP rank 收集所有活跃 token,然后通过自定义核函数在本地拆分和重新洗牌路由后的 token。激活值大小被压缩至 1 / (E // EP),这对应于内存和网络流量的减少。

上图展示了填充设计。每个方框代表一个 token,蓝色/绿色代表具有专家分配的有效 token,灰色代表填充 token。RiTj 代表专家并行组中第 i 个 rank 的第 j 个 token。
最小化填充带来的负面影响
尽管填充已减少到最低限度,我们也确保填充仅造成内存空间(分配)和网络流量(通信)的影响,而不会引起冗余计算(GroupedGEMM / NonLinear)或冗余内存带宽(CombineShuffling / SplitShuffling / ScatterAdd),具体通过使用设备内的形状信息 `routed_token_counts_per_expert` 或 `routed_token_counts_per_rank_per_expert` 来实现。

激活的概念性解释
最重要的是:
- 当所有 EP rank 上的活跃 token 总数很少时,这样做对于避免在 GroupedGEMM 中激活冗余专家并引起额外内存流量至关重要。
- 当所有 EP rank 上的活跃 token 总数很大时,这样做对于避免将 GroupedGEMM 从内存受限转变为计算受限也至关重要。
CombineShuffling:分配给当前 EP rank 的 token 在 AllGather 之后立即从“按专家排序”重新洗牌为“按 rank 排序”。未分配的 token 不会被复制,且张量末尾剩余的已分配内存空间保持不变。

SplitShuffling:分配给当前 EP rank 的 token 在 AlltoAll 之前立即从“按 rank 排序”重新洗牌为“按专家排序”。未分配的 token 不会被复制,重新洗牌后的张量以交错方式存储填充数据。

ScatterAdd(填充):每个 EP rank 最终接收到来自所有其他 rank 计算出的激活值,它会识别出哪些是有效 token,哪些是填充 token,然后仅读取有效 token 进行 scatter_add 操作。

通信去重
不同的张量并行 rank 在第一次 GroupedGEMM 之前和第二次 GroupedGEMM 之后具有相同的激活值,因此相同的 token 会在节点间重复交换。

我们启用了通信去重功能,将跨节点通信工作负载均匀分配给不同的 rank,同时引入额外的节点内通信。以 DP2/TP8/EP2 为例:
- 对于 eager 模式下的第一次 AlltoAll,将 T*D 的跨节点 AlltoAll 拆分为 T*D/8 的跨节点 AlltoAll 和 T*D 的节点内 AllGather。

- 对于 eager / 图模式下的第二次 AlltoAll,将 T*D 的跨节点 AlltoAll 拆分为 T*D/8 的节点内 ReduceScatter 和 T*D/8 的跨节点 AlltoAll。

- 对于图模式下的第一次 AllGather,将 2*T*D 的跨节点 AlltoAll 拆分为 2*T*D/8 的跨节点 AllGather 和 2*T*D 的节点内 AllGather。
核函数设计
我们实现了超过 10 个自定义核函数,以支持在 Nvidia H100 和 AMD MI300X GPU 上运行不同用例的 MetaShuffling MoE 推理设计。我们将所有计算核函数作为 PyTorch 算子开源在 FBGEMM 生成式 AI 核函数库中。我们希望它能帮助用户以他们首选的框架和加速器高效服务 Llama 4 模型,例如 vLLM / SGLang。在本博客中,我们将聚焦于对提升推理性能至关重要的两个最有趣的核函数设计:GroupedGEMM 和 IndexShuffling。
GroupedGEMM
我们为 BF16 / FP16 / FP8 行列量化实现了基于 Triton 的 GroupedGEMM 核函数。
接口
def grouped_gemm_fp8_rowwise(
x: torch.Tensor, # shape: [M, K]
w: torch.Tensor, # shape: [G*N, K]
m_sizes: torch.Tensor, # shape: [G]
x_scales: torch.Tensor, # shape: [M]
w_scales: torch.Tensor, # shape: [G*N]
) -> torch.Tensor: # shape: [M, N]
...
其接口与单 GEMM 非常相似,接收单一 LHS、单一 RHS 张量并产生单一输出。从运行时角度看,没有动态性或稀疏性。
然而,该核函数使用 `m_sizes` 数据动态拆分 LHS 张量的 M 维度,并使用 `m_sizes` 的形状静态拆分 RHS 张量的 N 维度。此设计具有多项优势:
- 在不同 M 批次内无需额外的填充或对齐要求。因此 `m_sizes` 可以存储任何非负值,只要其总和不超过 `M`。
- `m_sizes` 可以为零,以跳过加载未激活专家的权重。
- `m_sizes` 的总和可以小于 `M`,以跳过对末尾填充 token 的计算,且无额外开销。
- `m_sizes`(即 LHS 激活的分片)对于设备已知但对于主机未知。因此,它支持动态路由信息而无需引入设备到主机的同步。
工作负载分区
我们采用持久化核函数(persistent kernel)设计,每个 SM 启动 1 个 CTA,并让所有 CTA 以交错方式运行所有分区 tile。从概念上讲,工作负载分区如下所示。

def partition_workload(G: int, Ms: List[int], N: int):
partitions = []
for g in range(G):
for n in range(0, N, BLOCK_N):
for m in range(0, Ms[g], BLOCK_M):
partitions.append((g, m, n))
paritions_per_cta = [[] for _ in NUM_SMS]
for i, part in enumerate(partitions):
paritions_per_cta[i % NUM_SMS].append(part)
分区在运行时于设备端动态计算,开销很小。然而,通过这样做,我们可以实现:
- 跨不同 SM 的均衡工作负载。
- 较小的启动开销,因为每个 SM 仅启动 1 个 CTA。
- 高 L2 缓存命中率。工作负载分区的顺序确保了权重/激活值很可能从 HBM 加载一次并缓存在 L2 中。因为相同权重/激活 tile 的使用几乎总是由不同的 SM 并发/连续发生。
具有 warp 专门化的持久化核函数

我们在 Hopper GPU 上采用了基于主机端张量映射的激活和权重加载,以及可选的基于设备端张量映射的输出存储,以减少内存传输开销。利用激活的连续存储格式,我们可以使用单一主机端 TMA(张量内存加速器)描述符来加载激活,并屏蔽掉属于其他专家的 token。然而,我们需要创建多个设备端 TMA 描述符来存储输出,而无需动态屏蔽支持。
我们采用了基于 warp 专门化(warp specialization)的核函数设计,使核函数以真正的持久化方式运行,即每个 SM 在 3 个 warp 组(1 个生产者和 2 个消费者)之间切换。此设计使 TMA 引擎、Tensor Core 和 CUDA Core 的执行能够彼此重叠,利用异步 TMA 指令以及带共享内存内存屏障的 WGMMA(异步 Warpgroup 级矩阵乘加)指令。我们获得了 Meta Triton 编译器团队的巨大帮助来启用它。只有通过 warp 专门化才能隐藏序言和尾声,因为传统的软件流水线方法无法处理带有指针追踪的复杂控制流。
IndexShuffling(索引洗牌)
我们实现了基于 CUDA / HIP 的索引洗牌核函数。
接口
def index_shuffling(
scores: torch.Tensor, # shape: [T, E]
):
token_counts: torch.Tensor = ... # shape: [E]
expert_indices: torch.Tensor = ... # shape: [T]
token_indices: torch.Tensor = ... # shape: [T]
return token_counts, expert_indices, token_indices
该核函数获取所有 token 在所有专家上的路由分数,找出每个 token 路由到的特定专家,对 token 索引进行重新排序,使得路由到同一专家的所有 token 被连续放置,并返回:
- `token_counts`:路由到每个专家的 token 数量。它将被输入到上述 GroupedGEMM 核函数中。
- `expert_indices`:每个洗牌后 token 所属的专家索引。它将被输入到上述 GatherMul 核函数中。
- `token_indices`:每个洗牌后 token 的原始 token 索引。它将被输入到上述 GatherMul 和 ScatterAdd 核函数中。
协作核函数
我们采用了协作核函数设计,将核函数分为两个主要阶段:Top-k 归约(reduction)阶段和桶排序(bucket sort)阶段,中间进行一次全局同步。

- 1. 加载分数:
- 它将一块路由分数从全局内存(HBM)加载到共享内存(SMEM),并随之将关联的专家索引存储在 SMEM 上。
- 2. 归约:
- 在 E 维度上对 SMEM 执行 TopK 归约。对于 Llama 4 用例,它执行作为 Top1 归约的 ArgMax 排序,其中包括在 SMEM 上对分数和关联专家索引进行的 2D 并行树归约。在不同的树归约阶段之间:
- 所有线程将并发地在 SMEM 上处理多个 token 的归约。
- 每个线程将顺序地在 SMEM 上处理多个 token 的归约。
- 在 E 维度上对 SMEM 执行 TopK 归约。对于 Llama 4 用例,它执行作为 Top1 归约的 ArgMax 排序,其中包括在 SMEM 上对分数和关联专家索引进行的 2D 并行树归约。在不同的树归约阶段之间:
- 3. 计数与存储缓冲区:
- 它迭代 tile 上的所有 token,从 SMEM 获取选定的专家索引,将其存储到 HBM 上的缓冲区(`buf_expert_index`),并在 HBM 上的输出计数器(`token_counts`)上执行 `atomicAdd` 操作。
- 有趣的是,`atomicAdd` 操作会返回内存位置之前的值,这表明了该 token 在组内的位置,我们将该值存储在缓冲区(`buf_local_token_index`)中,并用它来确定所有 token 之间的全局顺序。
- 重复 1-3 步,直到分配给 CTA 的所有 token 都被处理完毕。
- 4. 全局同步:
- 它在 HBM 的全局计数器上执行 `atomicAdd` 操作。此后,所有 CTA 将等待全局计数器达到 token 总数,并使用 `st.release` + `ld.aquire` 屏障来保护前面的存储操作和后续的加载操作,以确保正确性。
- 5. 扫描(Scan):
- 它执行简单的加载和 `token_counts` 前缀和,并将其转换为 SMEM 上的 `token_counts_cumsums`。
- 6. 加载缓冲区与存储输出:
- 它迭代所有分配给此 CTA 的 token。对于每个 token,它从 `buf_expert_index` 加载该 token 所属的专家索引,然后将洗牌后的新 token 索引计算为以下两项之和:
- 其之前属于先前专家的 token 数量,使用 SMEM 张量 `token_counts_cumsums`。
- 其之前属于相同专家的 token 数量,使用 HBM 张量 `buf_local_token_index`。
- 之后,它直接将 `expert_indices` 和 `token_indices` 输出存储在洗牌后的新 token 索引位置。
- 它迭代所有分配给此 CTA 的 token。对于每个 token,它从 `buf_expert_index` 加载该 token 所属的专家索引,然后将洗牌后的新 token 索引计算为以下两项之和:
性能
核函数性能示例
我们的配置使用 H100 80GB SMX5 HBM3 700W SKU、Python 3.12 和 CUDA 12.8。单个 H100 上的理论峰值 HBM 内存带宽为 3.35 TB/s。
GroupedGEMM
Prefill(预填充)性能
下表显示了 Llama 4 Scout 和 Maverick 单机服务中的核函数预填充性能。实验设置假设总 token 数为 16,384,并采用张量并行分片。
| 精度 | G | M | N | K | 时间
(us) |
Compute
(TFlops) |
Memory
(GB/s) |
| BF16 | 16 | 1,024 | 2,048 | 5,120 | 523.85 | 655.90 | 1,088.90 |
| BF16 | 16 | 1,024 | 5,120 | 1,024 | 294.95 | 582.46 | 1,251.39 |
| BF16 | 128 | 128 | 2,048 | 5,120 | 975.41 | 352.26 | 2,992.82 |
| BF16 | 128 | 128 | 5,120 | 1,024 | 510.78 | 336.35 | 3,021.86 |
| FP8 | 16 | 1,024 | 2,048 | 5,120 | 286.89 | 1,197.64 | 1,111.10 |
| FP8 | 16 | 1,024 | 5,120 | 1,024 | 182.41 | 941.83 | 1,471.62 |
| FP8 | 128 | 128 | 2,048 | 5,120 | 517.16 | 664.40 | 2,887.28 |
| FP8 | 128 | 128 | 5,120 | 1,024 | 290.25 | 591.90 | 2,947.93 |
注意:G 表示组数。M 表示每组的 token 数。N 表示每组的输出特征维度。K 表示每组的输入特征维度。FP8 表示 FP8 行列量化(激活上的每个 token 缩放和权重上的每个通道缩放)以及快速累加。量化核函数未包含在基准测试中。缩放因子未包含在内存带宽计算中。使用旋转缓冲区和 CUDAGraphs 进行基准测试。
Decode(解码)性能
下表显示了 Llama 4 Scout 和 Maverick 单机服务中的核函数解码性能。实验设置假设总 token 数为 128,并采用张量并行分片。
| 精度 | G | M | N | K | 时间
(us) |
Compute
(TFlops) |
Memory
(GB/s) |
| BF16 | 16 | 8 | 2,048 | 5,120 | 112.54 | 23.85 | 2,997.82 |
| BF16 | 16 | 8 | 5,120 | 1,024 | 60.00 | 22.37 | 2,822.50 |
| BF16 | 128 | 1 | 2,048 | 5,120 | 861.22 | 3.12 | 3,119.07 |
| BF16 | 128 | 1 | 5,120 | 1,024 | 433.15 | 3.10 | 3,102.26 |
| FP8 | 16 | 8 | 2,048 | 5,120 | 59.81 | 44.88 | 2,824.60 |
| FP8 | 16 | 8 | 5,120 | 1,024 | 34.86 | 38.50 | 2,447.64 |
| FP8 | 128 | 1 | 2,048 | 5,120 | 440.53 | 6.09 | 3,049.44 |
| FP8 | 128 | 1 | 5,120 | 1,024 | 225.14 | 5.96 | 2,987.15 |
IndexShuffling(索引洗牌)
下表显示了 Llama 4 Scout 和 Maverick 单机服务中的核函数性能,与原生 PyTorch 实现进行了对比。
| Num Tokens | Num Experts | IndexShuffling (us) | Unfused Ops (us) | 加速比 |
| 128 | 16 | 5.08 | 36.71 | 722.53% |
| 128 | 128 | 8.95 | 34.36 | 384.05% |
| 2048 | 16 | 7.06 | 57.10 | 808.51% |
| 2048 | 128 | 13.53 | 69.84 | 516.18% |
| 4096 | 16 | 7.42 | 68.96 | 929.98% |
| 4096 | 128 | 18.89 | 87.46 | 463.09% |
| 8192 | 16 | 9.26 | 123.94 | 1339.16% |
| 8192 | 128 | 30.56 | 165.21 | 540.71% |
注意:使用旋转缓冲区和 CUDAGraphs 进行基准测试。
Trace(追踪)分析示例
Llama 4 Scout BF16 解码
这是使用我们的 MetaShuffling MoE 推理解决方案对 64 个 token 的 Llama 4 Scout BF16 进行解码的示例追踪。

- MoE 的总内存流量为(忽略激活值):
- 路由器:5120x16x2 = 163,840 字节
- 共享专家:(2048×5120 + 5120×1024)x2 = 31,457,280 字节
- 路由专家:16x(2048×5120 + 5120×1024)x2 = 503,316,480 字节
- 总合计:163,840 + 31,457,280 + 503,316,480 = 534,937,600 字节
- MoE 的总执行时间为 197.456us。实现的内存带宽为 534,937,600 / (197.456 * 10^-6) = 2,709,148,367,231 字节/秒 ~= 2.71 TB/s,即达到 H100 80GB SMX5 HBM3 理论峰值 HBM 内存带宽(3.35 TB/s)的 80.90%。
以下是追踪中不同组件的细分。

首先是路由和共享专家的细分。这两个组件在 2 个不同的流上同时运行,以实现更好的资源利用。
对于路由器流(标有红色框):
- 1. Router GEMM:基于 CuBLAS 的 split-k 设计 GEMM。它启动 2 个核函数,第二个核函数为归约核函数。
- 2. Sigmoid(路由激活):PyTorch 原生 sigmoid。
- 3. IndexShuffling:基于 FBGEMM 的索引洗牌,采用协作核函数设计。它可视为 topk、bincount 和 sort 这 3 个操作的融合。它启动 2 个核函数,第一个核函数为设置核函数。
- 4. GatherMul:基于 FBGEMM 的收集缩放。它可视为 gather(token)、gather(分数)和 mul 这 3 个操作的融合。
对于共享专家流(标有橙色框):
- 5. SharedExpert GEMM13:基于 CuBLAS 的 split-k 设计 GEMM。它启动 2 个核函数,第二个核函数为归约核函数。
- 6. SwiGLU:融合的 SwiGLU。它可视为 sigmoid 和 mul 这 2 个操作的融合。
- 7. SharedExpert GEMM2:基于 CuBLAS 的 GEMM。

其次是路由专家的细分。该组件在 1 个流上排他性运行,让 GroupedGEMM 核函数完全占有所有 SM。
对于路由专家流(标有红色框):
- 8. RoutedExperts GroupedGEMM13:基于 FBGEMM 的持久化 GroupedGEMM。
- 9. SwiGLU:融合的 SwiGLU。同上文 6。
- 10. RoutedExperts GroupedGEMM2:基于 FBGEMM 的持久化 GroupedGEMM,在尾声中与 scatter add 融合。
解码步骤在使用 CUDAGraph 的静态形状稠密张量上运行。
Llama 4 Maverick FP8 预填充
这是使用我们的 MetaShuffling MoE 推理解决方案对 5000 个 token 的 Llama 4 Maverick FP8 进行预填充的示例追踪。注意路由专家使用 FP8 行列量化,路由器和共享专家使用 BF16。

与解码追踪相比:
- 它使用单个流来避免路由器和共享专家之间核函数的相互干扰。由于核函数正在处理足够大的问题规模以饱和计算资源,额外的重叠只会导致资源争用,尤其是在 L2 缓存上。
- 它在静态形状的稠密张量上运行,但在 eager 模式下。由于核函数执行时间足够长,且没有设备/主机同步,核函数可以连续启动而无空闲等待。
在此,我们重点介绍这两个追踪之间的核函数差异,执行时间除外。
- Router GEMM 和 SharedExpertGEMM13:基于 CuBLAS 且不使用 split-k 设计。因此它启动 1 个核函数而不是 2 个。

- 4. GatherMul (FP8 行列量化):基于 FBGEMM 的收集缩放和量化。它可视为 gather(token)、gather(分数)、mul、max、divide、mul、clamp 和 typecast 这 8 个操作的融合。
- 9. SwiGLU (FP8 行列量化):融合的 SwiGLU 和量化。它可视为 sigmoid、mul、max、divide、mul、clamp 和 typecast 这 7 个操作的融合。

总结
我们采取以下步骤逐步优化 MoE 解决方案的推理性能:
-
- 提高设备级利用率,通过避免主机与设备同步。
- 减少资源浪费,通过移除填充或避免处理填充。
- 减少核函数启动和 I/O 开销,通过积极的核函数融合。
- 提高计算和内存效率,通过各种核函数优化,将性能推向硬件极限。
- 提高硬件组件级利用率,通过计算、内存流量或网络流量密集型核函数的并发执行,同时避免不必要的争用。
单机服务
我们使用 1000 个随机提示的请求,对我们内部基于 MetaShuffling 的 MoE 推理栈进行了 Llama 4 Maverick 和 Llama 4 Scout 的单机服务性能基准测试。我们在 8xH100 主机上运行 Maverick(使用 FP8)和 Scout(使用 BF16),最大批处理量为 64。我们的配置使用 H100 80GB SMX5 HBM3 700W SKU、Python 3.12 和 CUDA 12.8。我们已将 MetaShuffling MoE 推理栈中使用的所有计算核函数开源在 FBGEMM 上,并提供了一个 MetaShuffling 实现示例作为参考。
为了保持最佳精度,我们对 Llama 4 Maverick 路由专家使用 FP8 精度,对注意力线性层、注意力、共享专家、路由器和 KV 缓存使用 BF16 精度。
我们对 Llama 4 Scout 的所有线性层(注意力线性、共享专家、路由器和路由专家)、注意力及 KV 缓存均使用了 BF16 精度。
最后,我们希望社区能不断刷新 Llama 4 模型服务的效率记录,并期待有更好的数据出现。
致谢
感谢 Jing Zhang、Ying Zhang 和 Manman Ren 为本项目提供的技术评审和指导。
我们还要感谢 Bradley Davis, Yan Cui, Rengan Xu, Josh Fromm, Jiawen Liu, Sarunya Pumma, Jie Wang, Xinfeng Xie, Benson Ma, Michael Shu, Bingzhe Liu, Jingyi Yang, Min Si, Pavan Balaji, Dhruva Kaushal 对本项目的贡献。


