跳转到主要内容

混合专家 (MoE) 是一种流行的大型语言模型 (LLM) 模型架构。尽管它通过每个 token 激活更少的参数来减少训练和推理中的计算量,但它在实现最佳计算效率方面带来了额外的挑战,伴随着高内存和通信压力,以及处理模型动态性和稀疏性的复杂性。在此,我们介绍一种新的 MoE 推理解决方案,MetaShuffling,它使我们能够高效部署 Llama 4 模型进行生产推理。

Llama 4 架构

Llama 4 Scout 和 Maverick 模型已正式发布。Scout / Maverick 具有一个共享专家和 16 / 128 个路由专家,采用无丢弃 token 选择路由和每个 MoE 层的 Top-1 选择。此外,共享专家和路由专家都使用带有 3 个线性层的 SwiGLU 激活函数。有关模型的更多信息,请参阅Llama 4 群:原生多模态 AI 创新新纪元的开始

核心概念

有多种常见的解决方案来处理 MoE 层中引入的动态性和稀疏性问题。在此,我们演示了使用 Top-1 选择的 token 选择路由的不同解决方案。

上图显示了填充设计。每个框代表一个 token,黄色/绿色代表具有不同路由专家的有效 token,灰色代表填充 token。第二步中的每行框代表不同的路由专家。Ti 代表数据并行组当前 rank 中的第 i 个 token。

  • 填充:在这种方法中,我们将每个专家的激活填充到最大序列长度并运行单个批量矩阵乘法 (BMM)。它导致:
    • 持有填充物会增加内存。
    • 处理填充会增加延迟。请注意,可以通过锯齿状核避免处理填充,但当专家数量较多时,锯齿状核也可能导致高开销。
  • 切片:在这种方法中,我们根据每个专家的精确序列长度切片激活,并运行多个矩阵乘法 (MM)。它避免了填充问题,但会导致:
    • 核效率降低,因为对小形状重复启动核。
    • 设备利用率降低,因为对动态形状频繁进行主机和设备同步,以及额外的核启动开销,因为它与图捕获机制(例如 CUDAGraph 和 torch.compile)不兼容。

  • 拼接:在这种方法中,我们进一步在切片后拼接激活并运行单个分组矩阵乘法 (GMM)。它避免了切片中的核效率问题,但仍会导致:
    • 设备利用率降低,因为它仍然需要主机和设备同步,并且仍然与图捕获机制不兼容。

为了进一步改进解决方案,我们提出了一种基于洗牌的机制

  • 洗牌:在这种方法中,我们直接对 token 进行排序,使路由的 token 按路由专家的 ID 排序。通过这样做,不会引入填充或拆分,分配给同一专家的 token 被存储在一起,可以在 GroupedGEMM 内部一起处理。它提供了一个密集模型接口,并避免了上述所有问题。
    • 不填充,因为激活仍然是密集张量。
    • 没有主机和设备同步,因为激活仍然是静态形状的张量。

我们基于此设计构建了一个端到端 MoE 推理解决方案,MetaShuffling

运行时设计

单 GPU 推理不并行

上图是单 GPU 推理(无模型并行)的整体运行时设计。请注意,为了优化性能,SwiGLU 激活的第一个和第三个线性层合并为 GroupedGEMM13 / GEMM13。

  • 实心深蓝色/橙色框表示路由/共享专家流上的 Tensor Core 密集型核。
  • 实心浅蓝色/橙色框表示路由/共享专家流上的 CUDA 核或内存流量密集型核。
  • 红色箭头表示激活张量的数据流。
  • 绿色箭头表示元数据张量的数据流。

所有元数据张量都放置在设备上。没有阻塞设备到主机同步。所有核都背靠背启动,没有气泡。该图仅显示数据流,不代表实际的分析轨迹。

核接口和数据流

  • RoutingScores:处理路由分数计算的函数或融合核。
    • 输入:input_tokens: [T, D] (T: token 数量;D: 特征维度);router_weights: [D, E] (E: 专家数量);router_biases: [E];
    • 输出:routing_scores: [T, E];scaling_factors: [T, E];
  •  IndexShuffling:处理索引洗牌和排序的融合核。我们将在核设计部分介绍优化实现。
    • 输入:routing_scores: [T, E];K(Top-K 路由的阈值);
    • 输出:routed_token_indices: [K * T];routed_expert_indices: [K * T];routed_token_counts_per_expert: [E];
  • GatherMul:根据排序的索引对 token 进行洗牌和缩放的融合核。
    • 输入:input_tokens: [T, D];routed_token_indices: [K * T];routed_expert_indices: [K * T];scaling_factors: [T, E];
    • 输出:scaled_routed_tokens: [K * T, D]
  • GroupedGEMM优化后的 GroupedGEMM 核,处理 M 维度上批量数据的设备内形状信息,无限制。我们将在核设计部分介绍优化实现。
    • 输入:tokens: [K * T, D];weights: [E, D, HD] (HD: 隐藏维度);routed_token_counts_per_expert: [E];
    • 输出:tokens: [K * T, HD]
  • GEMM:一个优化的 GEMM 核。与密集模型接口类似。
  • 非线性:处理非线性的融合核。与密集模型接口类似。
  • ScatterAdd:一个优化的核,根据排序的索引反转 token 洗牌,并直接对共享专家输出执行 scatter add,而无需实例化未洗牌的张量。
    • 输入:shared_output_tokens: [T, D];routed_output_tokens: [K * T, D];routed_token_indices: [K * T]; 
    • 输出:combined_output_tokens: [T, D]

请注意,如果应用了量化,则激活量化核会融合到前置的非 GEMM 核中,这意味着对于 GroupedGEMM13 融合到GatherMul,对于 GroupedGEMM2 等融合到NonLinearity

请注意,如果使用较大的 K * T,GatherMul 和 ScatterAdd 操作可以进一步融合到后续/前置的 GroupedGEMM 操作中,这应该在序言/尾声中作为全局内存到共享内存/寄存器或共享内存到全局内存的步骤完成,但是,它在核设计级别上增加了与张量核心执行重叠的额外挑战。此外,融合 ScatterAdd 需要共享专家在路由专家之前完成,如果这些核可以用来隐藏 AlltoAll 延迟,这可能不是一个好的设计选择。

单主机推理的张量并行

上图是带有张量并行(TP)的单主机推理的整体运行时设计。与单 GPU 推理相比,额外步骤是

  • 实心浅薄荷色框表示网络流量密集型通信核。

所有元数据张量仍放置在设备上,没有设备到主机同步。所有核都背靠背启动,没有气泡。该图仅显示数据流,不代表实际的分析轨迹。

工作负载分片和附加核

与单 GPU 推理用例相比,没有引入额外的自定义核。对于 GEMM、GroupedGEMM 和非线性核,激活和权重都按不同维度共享到 1/TP,计算/内存开销也共享到 1/TP。

如果只应用了张量并行,最后一步应该是 AllReduce。或者,如果张量并行与序列并行一起应用,则为 ReduceScatter。

多主机推理的专家并行

为了实现专家并行(EP),我们将数据并行维度从路由专家中交换为路由专家内部的专家并行维度。请注意,张量并行可以进一步与专家并行交换,以获得更好的 GEMM 效率,但会增加路由不平衡的风险,但我们不会在本博客中介绍此设计。

如果启用了专家并行和 token 选择路由,那么我们必须决定是使用密集张量还是使用静态形状,因为路由到不同专家组的 token 数量是动态的。

  • 当优先使用 eager 模式以避免运行未填充的 AlltoAll 造成的网络流量和内存空间浪费时,我们使用密集张量和动态形状。
  • 当优先使用图模式以避免因 CPU 启动开销和通过 CUDAGraph 运行的设备到主机同步而导致的 GPU 气泡时,我们使用稀疏张量和静态形状。

请注意,通过自定义 AlltoAll 实现也可以避免填充激活造成的网络流量浪费,但我们不会在本博客中介绍任何关于自定义通信或通信与计算融合核的主题。

上图是多主机推理(带张量并行和专家并行)的整体运行时设计。与单主机推理(带张量并行)相比。

  • 实心红色箭头表示节点内通信。
  • 实心紫色箭头表示节点间通信。

核接口和数据流

对于新增的基于专家并行的通信,我们使用 3 步 All2All 通信来交换形状和 token

  • 第一次 A2A:交换关于路由到每个专家的 token 数量的设备内元数据张量,即 `routed_token_counts_per_expert: [E]`,这是从IndexShuffling核生成的输出。
  • 第二次 A2A:将 token 从基于数据并行交换到基于专家并行,根据路由分发到不同的 EP rank。
  • 第三次 A2A:将 token 从基于专家并行交换到基于数据并行,根据路由从不同的 EP rank 组合。

此外,我们还增加了 2 个额外的洗牌核和 1 个特殊的散列核

  • CombineShuffling (密集或填充):将收到的 token 从 rank 优先顺序重新洗牌为专家优先顺序。随后的T*表示从所有对等方收到的 token 总数,可以进一步解释为具有routed_token_counts_per_rank_per_expert 张量形状信息的锯齿维度。
    • 输入received_tokens: [T*, D](首先按 dp rank 排序,然后按专家索引排序);routed_token_counts_per_rank_per_expert: [EP, E // EP];
    • 输出:reshuffled_tokens: [T*, D](首先按专家索引排序,然后按 dp rank 排序);routed_token_counts_per_expert: [E // EP];
  • SplitShuffling (密集或填充):CombineShuffling相反的过程。将要发送的 token 从专家优先顺序重新洗牌为 rank 优先顺序。
    • 输入reshuffled_tokens: [T*, D](首先按专家索引排序,然后按 dp rank 排序);routed_token_counts_per_rank_per_expert: [EP, E // EP];
    • 输出:to_send_tokens: [T*, D](首先按 dp rank 排序,然后按专家索引排序);
  • ScatterAdd(填充):从填充张量中分散添加验证 token。
    • 输入:shared_output_tokens: [T, D];received_padded_routed_output_tokens: [EP, K*T, D]; routed_token_indices: [K * T]; routed_token_counts_per_expert: [E];
    • 输出:combined_output_tokens: [T, D]

我们将在“图形模式下带静态形状的填充通信”部分详细演示上述核。

即时模式下的非填充动态形状通信

运行时行为的高级图。不同组件的实际运行时可能因软件和硬件而异。

最小化动态形状的使用

由于每个 MoE 层上的路由是动态的,因此所需的最小设备/主机同步是每层一次。为了实现这一点,我们延迟了 `send_sizes` 的 D2H 复制,并将其与 `recv_sizes` 连接起来,通过一次 D2H 复制将它们一起传输。这将设备/主机同步减少到每层一次。

最小化动态形状的负面影响

为了进一步隐藏设备/主机同步开销,我们将共享专家进一步分为两部分。

  • 我们将第一部分在路由后立即分派,但在分派 A2A 之前。然后,当发生设备/主机同步时,设备仍然忙于运行共享专家。
  • 我们在 MoE 之后但在结合 A2A 之前分派第二部分。这将进一步有助于重叠第二个 A2A。

图形模式下带静态形状的填充通信

最大限度地减少填充的使用

采用无丢弃 token 选择设计,路由到任何单个专家的最大可能 token 数为 T。但是,如果我们通过专家并行分片将多个专家分组并放置在单个 GPU 上,对于 TopK 路由,

  • 路由到 1 个专家的最大 token 数为 T。
  • 路由到 2 个专家的最大 token 数为 2 * T。
  • 路由到 K 个专家的最大 token 数为 K * T。
  • 路由到 K + 1 个专家的最大 token 数仍然是 K * T。

因此,路由到 N 个专家的专家组的最大 token 数将上限为 min(N, K) * T 个 token。

对于 Top1 路由,路由到任何大小专家组的 token 数量将始终上限为 T 个 token,并且分配和容纳动态 token 所需的最小内存为 EP * T 个 token,因为有 EP 个专家组。

为了实现最小所需的填充,我们直接使用 AllGather 从不同的 EP rank 收集所有活动 token,然后通过自定义核在本地拆分和重新洗牌路由的 token。激活大小被压缩到 1 / (E // EP),这对应于内存和网络流量的减少。

上图显示了填充设计。每个框代表一个 token,蓝色/绿色代表具有专家分配的有效 token,灰色代表填充 token。RiTj代表专家并行组中第 i 个 rank 的第 j 个 token。

最小化填充的负面影响

即使填充减少到最小允许值,我们还确保填充仅引起内存空间(分配)和网络流量(通信),而不会引起冗余计算(GroupedGEMM / NonLinear)、冗余内存带宽(CombineShuffling / SplitShuffling / ScatterAdd),通过使用设备上的形状信息`routed_token_counts_per_expert`或`routed_token_counts_per_rank_per_expert`。

激活概念性解释

最重要的是,

  • 当所有 EP rank 中的活动 token 总数较少时,这样做很重要,以避免在GroupedGEMM中激活冗余专家并导致额外的内存流量。
  • 当所有 EP rank 中的活动 token 总数较大时,这样做也很重要,以避免将GroupedGEMM从内存受限转换为计算受限。

CombineShuffling:在 AllGather 之后,将分配给当前 EP rank 的 token 从专家优先顺序重新洗牌为 rank 优先顺序。未分配的 token 不会被复制,并且张量末尾剩余的已分配内存空间保持不变。

SplitShuffling:分配给当前 EP rank 的 token在 AlltoAll 之前,将从 rank 优先顺序重新洗牌为专家优先顺序。未分配的 token 不会复制,重新洗牌后的张量将以交错方式存储填充。

ScatterAdd(填充):每个 EP rank 最终会收到从所有其他 rank 计算出的激活,它会知道哪些是有效 token,哪些是填充 token,然后只读取有效 token 进行 scatter_add。

通信去重

不同的张量并行 rank 在第一次 GroupedGEMM 之前和第二次 GroupedGEMM 之后具有相同的激活,因此相同的 token 会在节点之间重复交换。

我们启用了通信去重,以通过引入额外的节点内通信,将节点间通信工作负载均匀分布到不同的 rank。例如 DP2/TP8/EP2

  • 对于即时模式下的第一次 AlltoAll,将T*D节点间 AlltoAll 拆分为T*D/8节点间 AlltoAll 和T*D节点内 AllGather。

  • 对于即时/图模式下的第二次 AlltoAll,将T*D节点间 AlltoAll 拆分为T*D/8节点内 ReduceScatter 和T*D/8节点间 AlltoAll。

  • 对于图模式下的第一次 AllGather,将2*T*D节点间 AlltoAll 拆分为2*T*D/8节点间 AllGather 和2*T*D节点内 AllGather。

核设计

我们实现了 10 多个自定义核来支持MetaShuffling MoE 推理设计,适用于 Nvidia H100 GPU 和 AMD MI300X GPU 上的不同用例。我们将所有计算核开源为 PyTorch 算子,位于FBGEMM 生成式 AI 核库中。我们希望它能帮助用户在其首选框架和加速器中高效服务 Llama 4 模型,例如 vLLM / SGLang。在本博客中,我们将重点介绍 2 个最有趣的核设计,它们是提高推理性能的关键:GroupedGEMMIndexShuffling

分组 GEMM

我们为 BF16 / FP16 / FP8 Rowwise 实现了基于 Triton 的 GroupedGEMM 核。

接口

def grouped_gemm_fp8_rowwise(
	x: torch.Tensor, 		# shape: [M, K]
	w: torch.Tensor, 		# shape: [G*N, K]
	m_sizes: torch.Tensor, 	# shape: [G]
	x_scales: torch.Tensor,	# shape: [M]
	w_scales: torch.Tensor, 	# shape: [G*N]
) -> torch.Tensor:               # shape: [M, N]
	...

该接口与单个 GEMM 非常相似,它接受单个 LHS、单个 RHS 张量,并产生单个输出。从运行时角度来看,没有动态性或稀疏性。

然而,核动态地使用`m_sizes`数据分割 LHS 张量的`M`维度,并静态地使用`m_sizes`形状分割 RHS 张量的`N`维度。这种设计有几个优点:

  • 不同批次的 M 中没有额外的填充或对齐要求。因此 `m_sizes` 可以存储任何非负值,只要其总和不超过 `M`。
  • “m_sizes”可以是零值,以跳过加载未激活专家的权重。
  • `m_sizes` 的总和可以小于 `M`,以跳过对末尾填充 token 的计算而无需额外开销。
  • m_sizes,或 LHS 激活的拆分,对于设备是已知的,但对于主机是未知的。因此,它支持动态路由信息,而不会导致设备到主机的同步。

工作负载分区

我们采用持久化核设计,为每个 SM 启动 1 个 CTA,并让所有 CTA 交错地运行所有已分区磁贴。从概念上讲,工作负载分区如下所示。


def partition_workload(G: int, Ms: List[int], N: int):
	partitions = []
	for g in range(G):
		for n in range(0, N, BLOCK_N):
			for m in range(0, Ms[g], BLOCK_M):
				partitions.append((g, m, n))
	paritions_per_cta = [[] for _ in NUM_SMS]
	for i, part in enumerate(partitions):
		paritions_per_cta[i % NUM_SMS].append(part)

分区是在设备端运行时动态计算的,开销很小。然而,通过这样做,我们可以实现:

  • 不同 SM 之间工作负载均衡。
  • 启动开销小,因为每个 SM 只会启动 1 个 CTA。
  • L2 缓存命中率高。工作负载分区的顺序确保权重/激活最有可能从 HBM 加载一次并缓存到 L2。因为相同权重/激活磁贴的使用几乎总是同时/连续地发生在不同 SM 上。

带 Warp 专用化的持久核

我们采用了基于主机端张量映射的激活和权重加载,以及可选的基于设备端张量映射的输出存储,以减少 Hopper GPU 上的内存传输开销。对于连续存储格式的激活,我们可以使用单个主机端 TMA (张量内存加速器) 描述符来加载激活并屏蔽掉属于其他 token 的 token。但是,我们需要创建多个设备端 TMA 描述符来存储输出,而无需动态屏蔽支持。

我们采用了基于 warp 专用化的核设计,使核以真正的持久方式运行,即每个 SM 在 3 个 warp 组(1 个生产者和 2 个消费者)之间切换。这种设计使 TMA 引擎、Tensor Core 和 CUDA Core 执行相互重叠,利用异步 TMA 指令和 WGMMA(异步 Warpgroup 级矩阵乘法累加)指令以及共享内存上的内存屏障。我们得到了 Meta Triton 编译器团队的大力帮助来实现这一点。只有通过 warp 专用化才能隐藏序言和尾声,因为传统的软件流水线方法无法处理复杂的控制流和指针追逐。

索引洗牌

我们实现了基于 CUDA/HIP 的索引洗牌核。

接口

def index_shuffling(
	scores: torch.Tensor,			        # shape: [T, E]
):
	token_counts: torch.Tensor = ...		# shape: [E]
	expert_indices: torch.Tensor = ...	        # shape: [T]
	token_indices: torch.Tensor = ...		# shape: [T]
	return token_counts, expert_indices, token_indices

该核获取所有 token 在所有专家上的路由分数,找出每个 token 被路由到的特定专家,重新排列 token 索引,使所有路由到同一专家的 token 连续放置,并返回:

  • `token_counts`:每个专家路由的 token 数量。它将作为输入传递给上面讨论的GroupedGEMM核。
  • `expert_indices`:每个洗牌 token 所属的专家索引。它将作为输入传递给上面讨论的GatherMul核。
  • `token_indices`:每个洗牌 token 所属的原始 token 索引。它将作为输入传递给上面讨论的GatherMulScatterAdd核。

协同核

我们采用了协同核设计,将核分为两个主要阶段:top-k 归约阶段和桶排序阶段,中间有一个全局同步。

  • 1. 加载分数: 
    • 它将路由分数块从全局内存 (HBM) 加载到共享内存 (SMEM),并将相关的专家索引也存储到 SMEM 上。
  • 2. 归约: 
    • 在 SMEM 上对 E 维度执行 TopK 归约。对于 Llama 4 用例,它执行 ArgMax 排序作为 Top1 归约,其中包括在 SMEM 上对分数和相关专家索引进行 2D 并行树归约。在不同的树归约阶段之间,
      • 所有线程将同时对 SMEM 上的多个 token 进行归约。
      • 每个线程将顺序对 SMEM 上的多个 token 进行归约。
  • 3. 计数与存储缓冲区:  
    • 它迭代磁贴上的所有 token,从 SMEM 获取选定的专家索引,将其存储到 HBM 上的缓冲区 (`buf_expert_index`),并对 HBM 上的输出计数器 (`token_counts`) 执行 `atomicAdd` 操作。 
    • 有趣的是,`atomicAdd` 操作将返回内存位置上先前的值,这表示 token 在组内的位置,我们将此值存储在缓冲区 (`buf_local_token_index`) 中,并用它来确定所有 token 的全局顺序。
  • 重复 1-3 步,直到处理完分配给 CTA 的所有 token。
  • 4. 全局同步: 
    • 它对 HBM 上的全局计数器执行 `atomicAdd` 操作。之后,所有 CTA 将等待,直到全局计数器达到 token 总数,并由 `st.release` + `ld.acquire` 屏障保护先前的存储操作和随后的加载操作,以确保正确性。
  • 5. 扫描: 
    • 它执行简单的加载和 `token_counts` 的前缀和,并将其转换为 SMEM 上的 `token_counts_cumsums`。
  • 6. 加载缓冲区并存储输出: 
    • 它迭代分配给此 CTA 的所有 token。对于每个 token,它从`buf_expert_index`加载分配给该 token 的专家索引,然后通过以下两者的总和计算洗牌后的新 token 索引: 
      • 在它之前属于先前专家的 token 数量,使用 SMEM 张量 `token_counts_cumsums`。
      • 在它之前属于同一专家的 token 数量,使用 HBM 张量 `buf_local_token_index`。
    • 然后,它只需直接将 `expert_indices` 和 `token_indices` 输出存储在洗牌后的新 token 索引处。

性能

示例核性能

我们的设置使用了 H100 80GB SMX5 HBM3 700W SKU、Python 3.12 和 CUDA 12.8。单个 H100 的理论峰值 HBM 内存带宽为 3.35 TB/s。

分组 GEMM

预填充性能

下表显示了 Llama 4 Scout 和 Maverick 单主机服务上核的预填充性能。实验设置假设 token 总数为 16,384,并进行了张量并行分片。

精度 G M N K 时间

(微秒)

计算

(TFlops)

内存

(GB/秒)

BF16 16 1,024 2,048 5,120 523.85 655.90 1,088.90
BF16 16 1,024 5,120 1,024 294.95 582.46 1,251.39
BF16 128 128 2,048 5,120 975.41 352.26 2,992.82
BF16 128 128 5,120 1,024 510.78 336.35 3,021.86
FP8 16 1,024 2,048 5,120 286.89 1,197.64 1,111.10
FP8 16 1,024 5,120 1,024 182.41 941.83 1,471.62
FP8 128 128 2,048 5,120 517.16 664.40 2,887.28
FP8 128 128 5,120 1,024 290.25 591.90 2,947.93

注:G表示组数。M表示每组的token数。N表示每组的输出特征维度。K表示每组的输入特征维度。FP8表示FP8行向缩放(激活按token缩放,权重按通道缩放),具有快速累加。量化核未包含在基准测试中。缩放未包含在内存带宽计算中。使用旋转缓冲区和CUDAGraphs进行基准测试。

解码性能

下表显示了 Llama 4 Scout 和 Maverick 单主机服务上核的解码性能。实验设置假设 token 总数为 128,并进行了张量并行分片。

精度 G M N K 时间

(微秒)

计算

(TFlops)

内存

(GB/秒)

BF16 16 8 2,048 5,120 112.54 23.85 2,997.82
BF16 16 8 5,120 1,024 60.00 22.37 2,822.50
BF16 128 1 2,048 5,120 861.22 3.12 3,119.07
BF16 128 1 5,120 1,024 433.15 3.10 3,102.26
FP8 16 8 2,048 5,120 59.81 44.88 2,824.60
FP8 16 8 5,120 1,024 34.86 38.50 2,447.64
FP8 128 1 2,048 5,120 440.53 6.09 3,049.44
FP8 128 1 5,120 1,024 225.14 5.96 2,987.15

索引洗牌

下表显示了 Llama 4 Scout 和 Maverick 单主机服务上核的性能,并与原生 PyTorch 实现进行了比较。

token 数量 专家数量 索引洗牌(微秒) 未融合操作(微秒) 加速比
128 16 5.08 36.71 722.53%
128 128 8.95 34.36 384.05%
2048 16 7.06 57.10 808.51%
2048 128 13.53 69.84 516.18%
4096 16 7.42 68.96 929.98%
4096 128 18.89 87.46 463.09%
8192 16 9.26 123.94 1339.16%
8192 128 30.56 165.21 540.71%

注意:使用旋转缓冲区和 CUDAGraphs 进行基准测试。

示例跟踪分析

Llama 4 Scout BF16 解码

这是使用我们的MetaShuffling MoE 推理解决方案对 64 个 token 的 Llama 4 Scout BF16 进行解码的示例跟踪。 

  • MoE 的总内存流量(忽略激活)
    • 路由器:5120x16x2 = 163,840 字节
    • 共享专家:(2048×5120 + 5120×1024)x2 = 31,457,280 字节
    • 路由专家:16x(2048×5120 + 5120×1024)x2 = 503,316,480 字节
    • 总计:163,840 + 31,457,280 + 503,316,480 = 534,937,600 字节
    • MoE 的总执行时间为 197.456us。实现的内存带宽为 534,937,600 / (197.456 * 10^-6) = 2,709,148,367,231 字节/秒 ≈ 2.71 TB/s,是 H100 80GB SMX5 HBM3 理论峰值 HBM 内存带宽 3.35 TB/s 的 80.90%。

以下是跟踪中不同组件的细分。

首先是路由和共享专家的分解。这两个组件在两个不同的流上并行运行,以实现更好的资源利用率。

对于路由器流(红色框标记)

  • 1. 路由器 GEMM:基于 CuBLAS 的 GEMM,采用 split-k 设计。它启动 2 个核,第二个核是归约核。
  • 2. Sigmoid (路由器激活):PyTorch 原生 Sigmoid。
  • 3. IndexShuffling:基于 FBGEMM 的索引洗牌,采用协同核设计。可以看作是 topk、bincount 和 sort 3 个操作的融合。它启动 2 个核,第一个核是设置核。
  • 4. GatherMul:基于 FBGEMM 的 gather 缩放。可以看作是 3 个操作的融合:gather (token)、gather (分数) 和 mul 操作。

对于共享专家流(橙色框标记)

  • 5. 共享专家 GEMM13:基于 CuBLAS 的 GEMM,采用 split-k 设计。它启动 2 个核,第二个核是归约核。
  • 6. SwiGLU:融合的 SwiGLU。可以看作是 2 个操作的融合:sigmoidmul
  • 7. SharedExpert GEMM2:基于 CuBLAS 的 GEMM。

第二部分是路由专家的细分。该组件专门在一个流上运行,以让 GroupedGEMM 核完全拥有所有 SM。

对于路由专家流(红色框标记)

  • 8. RoutedExperts GroupedGEMM13:基于 FBGEMM 的 GroupedGEMM,采用持久核设计。 
  • 9. SwiGLU:融合的 SwiGLU。如 6 中所述。
  • 10. RoutedExperts GroupedGEMM2:基于 FBGEMM 的 GroupedGEMM,采用持久核设计,并在尾声中融合了 scatter add。

解码步骤在稠密张量上使用 CUDAGraph 以静态形状运行。

Llama 4 Maverick FP8 预填充

以下是使用我们的MetaShuffling MoE 推理解决方案的 Llama 4 Maverick FP8 预填充跟踪示例,包含 5000 个 token。请注意路由专家使用 FP8 行向缩放,路由器和共享专家使用 BF16。

与解码跟踪相比

  • 它使用单个流来避免路由器和共享专家之间的核交互。由于核正在处理足够大的问题规模以饱和计算资源,因此额外的重叠只会导致争用,尤其是在 L2 缓存上。
  • 它在具有静态形状的密集张量上运行,但在即时模式下。由于核执行时间足够长,并且没有设备/主机同步,因此核可以背靠背启动而没有气泡。

在此,我们重点介绍这两个跟踪之间的核差异,除了执行时间。

  • 路由器 GEMMSharedExpertGEMM13:基于 CuBLAS 的 GEMM,不使用 split-k 设计。因此它只启动 1 个核而不是 2 个。

  • 4. GatherMul (FP8 行向量化): 基于 FBGEMM 的 gather 缩放和量化。可以看作是 8 个操作的融合:gather (token)、gather (分数)、mul、max、divide、mul、clamp 和 typecast。
  • 9. SwiGLU(FP8 行向量化):融合的 SwiGLU 和量化。可以看作是 7 个操作的融合:sigmoidmul、max、divide、mul、clamp 和 typecast。

总结

我们逐步采取以下措施来优化我们的 MoE 解决方案的推理性能

    • 通过避免主机和设备同步来提高设备级利用率。
    • 通过移除填充或避免处理填充来减少资源浪费。
    • 通过积极的核融合来减少核启动和 I/O 开销。
    • 通过各种核优化来提高计算和内存效率,将性能推向硬件极限。
    • 通过计算、内存流量或网络流量密集型核的并发执行来提高硬件组件级别的利用率,但同时避免不必要的争用。

单主机服务

我们使用 1000 个带随机提示的请求,对 Llama 4 Maverick 和 Llama 4 Scout 的单主机服务性能进行了基准测试,采用我们内部基于MetaShuffling的 MoE 推理堆栈。我们在 8xH100 主机上运行 Maverick(FP8)和 Scout(BF16),最大批处理大小为 64。我们的设置使用了 H100 80GB SMX5 HBM3 700W SKU、Python 3.12 和 CUDA 12.8。我们已将用于MetaShuffling MoE 推理堆栈的所有计算核开源到 FBGEMM,并提供了一个MetaShuffling的示例实现作为参考。

为保持最佳精度,我们对 Llama 4 Maverick 的路由专家使用了 FP8 精度。对注意力线性层、注意力、共享专家、路由器和 KV 缓存使用了 BF16 精度。

 

我们对 Llama 4 Scout 的所有线性层(注意力线性层、共享专家、路由器和路由专家)、注意力以及 KV 缓存都使用 BF16 精度进行了基准测试。

最后,我们希望社区能够不断打破记录,提高 Llama 4 模型服务的效率,并期待报告更好的数据。

致谢

我们衷心感谢Jing Zhang, Ying Zhang, 和 Manman Ren 提供的技术评审和指导。

我们还要感谢Bradley Davis, Yan Cui, Rengan Xu, Josh Fromm, Jiawen Liu, Sarunya Pumma, Jie Wang, Xinfeng Xie, Benson Ma, Michael Shu, Bingzhe Liu, Jingyi Yang, Min Si, Pavan Balaji, Dhruva Kaushal. 对此项目的贡献。