博客

面向 MoE 的 MXFP8 训练:在 GB200 集群上结合 TorchAO 和 TorchTitan,使 Llama4 Scout 的训练速度较 BF16 提升 1.3 倍

简而言之:我们最近展示了 Llama4 Scout 的训练速度提升了 30.2%,且收敛效果与 bfloat16 相当。这得益于我们在 TorchAO 中使用了 MXFP8 MoE 训练原语!在给定通用矩阵乘法(GEMM)和分组 GEMM 转换为 MXFP8 的情况下,这大约是该模型训练配置理论最大可实现加速比(roofline 1.37 倍)的 81%。这些实验是在 Crusoe Cloud 上进行的。

在本博客中,我们将讨论

  1. 训练运行结果,以及如何使用 TorchTitan 和 TorchAO 复现这些结果
  2. MoE 训练中动态量化 MXFP8 分组 GEMM 的前向和反向传播的深度解析

收敛性实验及结果

我们在一个包含 64 个节点/256 个设备的 GB200 集群上使用 TorchTitan Llama4 Scout 进行的训练运行表明,其收敛效果与 bfloat16 训练基线相当。这与我们之前关于 密集模型 MXFP8 训练 的扩展实验结果一致。我们使用了以下训练配置:

  • 模型:Llama4 Scout
  • 数据集:C4
  • 序列长度:8192
  • 本地批次大小:1
  • 学习率:1e-4
  • 学习率调度器预热步数:2000
  • 并行策略
    • FSDP=256(用于注意力层、共享专家、密集层 FFN)以及 256/4=64(用于路由专家)
    • EP=16(用于路由专家)
  • 激活检查点模式:全量(重新计算所有中间激活值而非存储它们,以降低峰值内存需求)
  • TorchTitan 中的 torch.compile 已在以下组件上启用:模型、损失函数
  • mxfp8 应用于路由专家计算(分组 GEMM)
  • mxfp8 应用于所有线性层,除匹配以下 FQN 的层外:output, router.gate, attention.wk, attention.wv
    • 输出嵌入投影对低精度过于敏感,会对收敛产生不利影响
    • Wk/Wv 太小,无法从动态 mxfp8 量化中获得明显的性能收益

版本

  • torch:2.11.0.dev20260122+cu130
  • torchtitan:0.2.1
  • torchao:0.17.0+git41e02b5fb

展示了 3000 多步训练中与 bf16 等效收敛性的训练损失曲线

我们进行了一项长时间的收敛实验(3000 步),以评估 bfloat16 基线与 mxfp8 之间的收敛行为是否等效。为了保持集群上的时间窗口可控,我们将运行的本地批次大小设为 1。所绘出的损失曲线显示训练损失几乎完全相同。

性能基准测试

接下来,为了评估 MXFP8 可实现的性能提升,我们将本地批次大小从 1 增加到 16,这在提高稀疏激活的 MoE 中的 GEMM 效率时是很常见的。这使得端到端训练速度比相同配置的 bf16 提高了 30.2%

GPU 数量 BF16 tokens/秒 MXFP8 tokens/秒 MXFP8 相对 BF16 加速比
256 5317 6921 30.2%

推动此加速的引擎是 _to_mxfp8_then_scaled_grouped_mm 操作,对于这些形状,它比编译后的 bf16 快约 1.8 倍。将此用于路由专家,可使 MOE 层速度提升 1.43 倍,使 Llama4 Scout 的端到端训练速度比编译后的 bf16 快 1.2 倍。此外,当我们也将 MXFP8 用于共享专家线性层时,端到端训练速度达到了比编译后的 bf16 快 1.3 倍的效果。请参阅附录中的微基准测试表。

在下一节中,我们将介绍 TorchAO API,并提供关于 mxfp8 及其在缩放分组 GEMM 中应用的更多技术细节。

用于 MXFP8 MoE 训练的 TorchTitan 配置

对于上一节的结果,我们依赖 TorchTitan 作为我们的训练框架。要使用 TorchTitan 进行 MoE 的 MXFP8 训练,请查看 文档,其中详细说明了必要的配置和示例。

TorchAO MXFP8 MoE 训练 API

如果您不使用 TorchTitan,也可以直接使用 TorchAO 原语。TorchAO 最近添加了一个原型 API _to_mxfp8_then_scaled_grouped_mm,其功能正如其名:将分组 GEMM 输入(激活和权重)量化为 mxfp8,然后使用 mxfp8 操作数进行缩放分组 GEMM,并以原始精度输出结果。该原语是可微分的,因此可以直接用于训练。查看 文档以获取详细的微基准测试、不同流行模型所用形状的 roofline 分析等内容。

该原语的目标是相对于 bf16 分组 GEMM 基线实现净加速。通过将输入动态量化为 mxfp8,我们可以使用 mxfp8 缩放分组 GEMM,与 bf16 相比,它可实现高达 2 倍的 TFLOPs/秒。因此,只要我们的量化内核足够快且不会引入过多的开销,我们应该能够实现净加速,如下图所示:

动态 MXFP8 量化 + 缩放分组 GEMM 的前向和反向传播示意图

让我们巡视一下前向和反向传播,从输入激活进入路由专家的前向传播那一刻开始!

我们的起点是 MoE 层中路由专家计算之前的瞬间。需要明确的是,在 MoE 层执行的这一点上,已经完成了以下步骤:

  • 路由已计算出每个 token 的专家亲和度得分(token 选择路由)
  • 根据得分将 token “分配”给 Top K 专家
  • 对于专家并行,通过 all-to-all 通信将 token 分派到目标专家所在的设备上
  • 每个设备执行 token 洗牌/置换,使得 token 按专家分组,且组的排序顺序与专家权重的排序顺序一致
    • 此洗牌操作生成一个名为 offsets 的张量,它存储了扁平化后的 2D token 组张量中每个组的结束索引。

因此,我们拥有如下所示的用于路由专家的 高精度 输入激活和权重。

MXFP8 量化

这是关键点,我们现在将动态量化 输入激活和权重到 MXFP8,得到 float8 e4m3 数据和 float8 e8m0 比例因子。

每个 1×32 的高精度输入数据块共享一个 e8m0 比例因子,该因子用于缩放数值以填充 float8 e4m3 数据类型的动态范围。

将每组比例因子写入沿行(M)有组边界的块布局中

为了在 Blackwell GPU 上进行高效的分组 GEMM,我们希望我们的内核使用该硬件的第五代张量核心(tensorcores),这是最大化这些芯片计算吞吐量(TFLOPs/秒)的关键。这些张量核心使用 tcgen05 系列 PTX 指令 编程,这是 CUDA 等熟悉的内核语言在编译过程中会转换成的汇编级指令。

特别是对于 MXFP8 分组 GEMM,我们需要使用 块缩放变体的 tcgen05.mma PTX 指令。该指令对 MXFP8 数据中 e8m0 比例因子的布局有一些特殊要求。

具体来说,比例因子必须以非传统的 块布局 驻留在张量内存(TMEM)中,可以在此处的 NVIDIA 文档中看到:

来源:https://docs.nvda.net.cn/cuda/cublas/index.html#d-block-scaling-factors-layout

因此,在我们的量化内核中,我们需要将比例因子写入此布局,以便张量核心可以使用它们。为了将比例因子从简单的行优先布局转换为这种块布局,我们需要做 3 件事:

  1. 填充每个 token 组的比例因子,使其能被 128×4 的分块整除。
  2. 在每个 128×4 分块内,应用从简单行优先到 ((32,4), 4) 布局的变换,这意味着逻辑上我们在内存中水平排列了 4 个 (32,4) 分块。你可以将其想象为一个 (32,16) 形状的张量,其中每个 16 字节的“超行”现在包含 4 行,每行来自一个分块。这些超行在内存中是连续的。
  3. 以块级粒度将这些分块写入“块行”优先布局,如下所示(例如,我们在继续沿着行进行下一个分块之前,将整个分块作为连续内存写入)

以下是此块布局变换前后的对比图,可以更直观地理解。更多详细信息,请参考 NVIDIA 文档

在考虑分组时,实际内存布局如下所示。在前向传递的 2D-3D MXFP8 分组 GEMM 中,组沿着 M 维度(行)排列

转换为这种非传统布局时需要非常小心,以避免产生过多的开销,否则将抵消我们的净加速(甚至可能导致性能下降!)。

等等,还没完!这只是针对 单个比例因子 的考虑。请记住,我们正在执行一个分组 GEMM,它并行化了我们的路由专家 GEMM 并在单个内核中执行它们。每个 GEMM 的数据和比例因子 生活在同一个张量中,但所有数据和比例因子都必须单独满足我们讨论过的 tcgen05.mma PTX 指令的要求。这意味着我们不能直接对比例张量应用布局变换;相反,我们需要分别对每个组进行这种布局转换。

此外,每个专家的 token 组大小是动态的,并且仅在设备端可知!这意味着在主机端的 PyTorch 代码中,我们无法分派内核来单独转换每个组的内存布局,因为主机没有组大小信息。如果要获取这些信息,需要进行主机设备同步,这会在我们的 GPU 内核执行流中导致较大的间隙(空闲时间),这对性能来说是非常糟糕的!因此,我们必须设计一个能够完全在设备端高效执行此任务的自定义内核。

事实证明,这是一个非常有趣且不同寻常的内核开发任务——值得单独写一篇深度解析文章。与此同时,让我们进入 MXFP8 分组 GEMM 的下一步,即在动态量化输入并将比例因子写入每组块布局之后。

用于前向输出的 2D–3D MXFP8 分组 GEMM

数据和比例因子已就位在正确的内存布局中,我们终于准备好在 2D 输入激活和 3D 权重之间执行 MXFP8 分组 GEMM!这为我们提供了一个 2D 输出,其中每个 token 都已投影到隐藏维度。分解开来看,它是这样的:

反向传播;用于输入梯度计算的 2D–3D MXFP8 分组 GEMM

输入的梯度与前向输出的 2D-3D 缩放分组 GEMM 非常相似,所以这里不再赘述,因为篇幅已经很长了。具体公式为:

dgrad = dO @ weight

输出梯度与我们的前向输出具有相同的形状。它是一个形状为 (total_M, N) 的 2D 张量。权重仍然是我们的权重,一个形状为 (E, N, K) 的 3D 张量。

所以我们以在前向传播中量化输入激活相同的方式量化 2D 输出梯度,并以非常相似的方式量化我们的权重。不同之处在于它们是非转置的、行优先格式,并且我们正在将它们写入每专家列优先格式。所以内核有点不同,但这是唯一的区别。

反向传播;用于权重梯度计算的 2D–2D MXFP8 分组 GEMM

权重的梯度更值得探讨。这是因为它涉及一种完全不同类型的分组 GEMM,面临不同的挑战。计算权重梯度的公式是:

dW = dO^T @ X

这是一个形状为 (N, total_M) @ (total_M, K) 的 2D-2D 分组 GEMM。如你所见,组现在沿着 GEMM 的收缩维度!

将每组比例因子写入块格式,组沿收缩维度(K)排列

这有什么变化?好吧,一个将比例因子转换为组沿 M 维度(行)的块格式的内核,对于本例中组沿收缩或 K 维度的情况不再适用了。这是因为输入比例因子是行优先布局,因此沿着列边界将张量切分成组,从根本上改变了步长,使得步长 每组动态变化。所以我们需要一个处理每组不同步长情况的内核。在这种情况下,我们需要计算每行中 128×4 分块的数量,我们的步长将是单块步长 乘以该行中分块的数量。听起来很复杂,确实如此,但这有一个图表可以帮助可视化并使理解变得更容易:

分组 GEMM 中的每个独立 GEMM 产生一个 2D 输出,这些输出堆叠成最终的 3D 结果。分解开来看,它是这样的:

就是这样!此时,希望你对动态量化 MXFP8 分组 GEMM 的前向和反向传播的底层原理有了更好的理解!

以上就是全部内容——我们希望你阅读愉快,并记得查看 TorchAO MoE 训练文档,其中包含基准测试、示例等,助你入门!

未来工作

TorchAO 中的 MXFP8 MoE 训练仍然是一个原型功能,在将其升级为稳定版之前,我们正在积极进行一些改进,即:

  • 统一密集模型和稀疏/MoE 模型的 MXFP8 训练 API:目前,TorchAO 有独立的 API 用于转换模型的 nn.Linear 层和 torch._grouped_mm 操作以使用 MXFP8。我们正在努力统一这些 API 以简化用户体验。
  • 用于专家并行通信的 MXFP8:此外,除了 MXFP8 分组 GEMM 之外,我们还有用于高效专家并行 训练的自动微分函数原型,它在 all-to-all 通信之前更早地量化为 MXFP8,并在分组 GEMM 中保持为 MXFP8,从而节省网络带宽并带来加速。敬请期待后续关于此的文章!

附录

微基准测试:动态 MXFP8 量化 + MXFP8 分组 GEMM 相比 bf16 分组 GEMM 的净加速

以下是一些微基准测试,比较了用于支持 MXFP8 MoE 训练的 自动微分函数 的前向和反向传播的总持续时间,与 bf16 torch._grouped_mm 基线相比,针对的是近期 MoE 模型架构中使用的形状。

M = local_batch_size * sequence_length

G = number_of_experts_on_local_rank

N, K = expert_dimensions

Llama4 Scout 形状

M,N,K,G BF16 前向 + 反向 (微秒) MXFP8 前向 + 反向 (微秒) MXFP8 相对 BF16 加速比
(128000, 8192, 5120, 1) 43140.20 23867.00 1.808x
(128000, 8192, 5120, 2) 39487.60 23359.00 1.690x
(128000, 8192, 5120, 4) 39189.20 23945.50 1.637x
(128000, 8192, 5120, 8) 37700.70 22170.60 1.700x

你可以参考 文档获取在 B200 GPU 上复现这些基准测试的命令。

MoE 层基准测试

在单个 B200 上单个 MoE 层的微基准测试也显示,与 bf16 基线相比,MXFP8 实现的 MoE 层执行速度最高快 1.43 倍。

模型 total_M N K bf16 时间 (ms) mxfp8 时间 (ms) 加速比
Llama4 16e 131072 8192 5120 275.270 192.420 1.431x

你可以参考 文档获取在 B200 GPU 上复现这些基准测试的命令。