简而言之:我们最近展示了 Llama4 Scout 的训练速度提升了 30.2%,且收敛效果与 bfloat16 相当。这得益于我们在 TorchAO 中使用了 MXFP8 MoE 训练原语!在给定通用矩阵乘法(GEMM)和分组 GEMM 转换为 MXFP8 的情况下,这大约是该模型训练配置理论最大可实现加速比(roofline 1.37 倍)的 81%。这些实验是在 Crusoe Cloud 上进行的。
在本博客中,我们将讨论
- 训练运行结果,以及如何使用 TorchTitan 和 TorchAO 复现这些结果
- MoE 训练中动态量化 MXFP8 分组 GEMM 的前向和反向传播的深度解析
收敛性实验及结果
我们在一个包含 64 个节点/256 个设备的 GB200 集群上使用 TorchTitan Llama4 Scout 进行的训练运行表明,其收敛效果与 bfloat16 训练基线相当。这与我们之前关于 密集模型 MXFP8 训练 的扩展实验结果一致。我们使用了以下训练配置:
- 模型:Llama4 Scout
- 数据集:C4
- 序列长度:8192
- 本地批次大小:1
- 学习率:1e-4
- 学习率调度器预热步数:2000
- 并行策略
- FSDP=256(用于注意力层、共享专家、密集层 FFN)以及 256/4=64(用于路由专家)
- EP=16(用于路由专家)
- 激活检查点模式:全量(重新计算所有中间激活值而非存储它们,以降低峰值内存需求)
- TorchTitan 中的 torch.compile 已在以下组件上启用:模型、损失函数
- mxfp8 应用于路由专家计算(分组 GEMM)
- mxfp8 应用于所有线性层,除匹配以下 FQN 的层外:output, router.gate, attention.wk, attention.wv
- 输出嵌入投影对低精度过于敏感,会对收敛产生不利影响
- Wk/Wv 太小,无法从动态 mxfp8 量化中获得明显的性能收益
版本
- torch:2.11.0.dev20260122+cu130
- torchtitan:0.2.1
- torchao:0.17.0+git41e02b5fb
展示了 3000 多步训练中与 bf16 等效收敛性的训练损失曲线
我们进行了一项长时间的收敛实验(3000 步),以评估 bfloat16 基线与 mxfp8 之间的收敛行为是否等效。为了保持集群上的时间窗口可控,我们将运行的本地批次大小设为 1。所绘出的损失曲线显示训练损失几乎完全相同。

性能基准测试
接下来,为了评估 MXFP8 可实现的性能提升,我们将本地批次大小从 1 增加到 16,这在提高稀疏激活的 MoE 中的 GEMM 效率时是很常见的。这使得端到端训练速度比相同配置的 bf16 提高了 30.2%。
| GPU 数量 | BF16 tokens/秒 | MXFP8 tokens/秒 | MXFP8 相对 BF16 加速比 |
| 256 | 5317 | 6921 | 30.2% |
推动此加速的引擎是 _to_mxfp8_then_scaled_grouped_mm 操作,对于这些形状,它比编译后的 bf16 快约 1.8 倍。将此用于路由专家,可使 MOE 层速度提升 1.43 倍,使 Llama4 Scout 的端到端训练速度比编译后的 bf16 快 1.2 倍。此外,当我们也将 MXFP8 用于共享专家线性层时,端到端训练速度达到了比编译后的 bf16 快 1.3 倍的效果。请参阅附录中的微基准测试表。
在下一节中,我们将介绍 TorchAO API,并提供关于 mxfp8 及其在缩放分组 GEMM 中应用的更多技术细节。
用于 MXFP8 MoE 训练的 TorchTitan 配置
对于上一节的结果,我们依赖 TorchTitan 作为我们的训练框架。要使用 TorchTitan 进行 MoE 的 MXFP8 训练,请查看 文档,其中详细说明了必要的配置和示例。
TorchAO MXFP8 MoE 训练 API
如果您不使用 TorchTitan,也可以直接使用 TorchAO 原语。TorchAO 最近添加了一个原型 API _to_mxfp8_then_scaled_grouped_mm,其功能正如其名:将分组 GEMM 输入(激活和权重)量化为 mxfp8,然后使用 mxfp8 操作数进行缩放分组 GEMM,并以原始精度输出结果。该原语是可微分的,因此可以直接用于训练。查看 文档以获取详细的微基准测试、不同流行模型所用形状的 roofline 分析等内容。
该原语的目标是相对于 bf16 分组 GEMM 基线实现净加速。通过将输入动态量化为 mxfp8,我们可以使用 mxfp8 缩放分组 GEMM,与 bf16 相比,它可实现高达 2 倍的 TFLOPs/秒。因此,只要我们的量化内核足够快且不会引入过多的开销,我们应该能够实现净加速,如下图所示:

动态 MXFP8 量化 + 缩放分组 GEMM 的前向和反向传播示意图
让我们巡视一下前向和反向传播,从输入激活进入路由专家的前向传播那一刻开始!
我们的起点是 MoE 层中路由专家计算之前的瞬间。需要明确的是,在 MoE 层执行的这一点上,已经完成了以下步骤:
- 路由已计算出每个 token 的专家亲和度得分(token 选择路由)
- 根据得分将 token “分配”给 Top K 专家
- 对于专家并行,通过 all-to-all 通信将 token 分派到目标专家所在的设备上
- 每个设备执行 token 洗牌/置换,使得 token 按专家分组,且组的排序顺序与专家权重的排序顺序一致
- 此洗牌操作生成一个名为 offsets 的张量,它存储了扁平化后的 2D token 组张量中每个组的结束索引。
因此,我们拥有如下所示的用于路由专家的 高精度 输入激活和权重。

MXFP8 量化
这是关键点,我们现在将动态量化 输入激活和权重到 MXFP8,得到 float8 e4m3 数据和 float8 e8m0 比例因子。
每个 1×32 的高精度输入数据块共享一个 e8m0 比例因子,该因子用于缩放数值以填充 float8 e4m3 数据类型的动态范围。

将每组比例因子写入沿行(M)有组边界的块布局中
为了在 Blackwell GPU 上进行高效的分组 GEMM,我们希望我们的内核使用该硬件的第五代张量核心(tensorcores),这是最大化这些芯片计算吞吐量(TFLOPs/秒)的关键。这些张量核心使用 tcgen05 系列 PTX 指令 编程,这是 CUDA 等熟悉的内核语言在编译过程中会转换成的汇编级指令。
特别是对于 MXFP8 分组 GEMM,我们需要使用 块缩放变体的 tcgen05.mma PTX 指令。该指令对 MXFP8 数据中 e8m0 比例因子的布局有一些特殊要求。
具体来说,比例因子必须以非传统的 块布局 驻留在张量内存(TMEM)中,可以在此处的 NVIDIA 文档中看到:

来源:https://docs.nvda.net.cn/cuda/cublas/index.html#d-block-scaling-factors-layout
因此,在我们的量化内核中,我们需要将比例因子写入此布局,以便张量核心可以使用它们。为了将比例因子从简单的行优先布局转换为这种块布局,我们需要做 3 件事:
- 填充每个 token 组的比例因子,使其能被 128×4 的分块整除。
- 在每个 128×4 分块内,应用从简单行优先到 ((32,4), 4) 布局的变换,这意味着逻辑上我们在内存中水平排列了 4 个 (32,4) 分块。你可以将其想象为一个 (32,16) 形状的张量,其中每个 16 字节的“超行”现在包含 4 行,每行来自一个分块。这些超行在内存中是连续的。
- 以块级粒度将这些分块写入“块行”优先布局,如下所示(例如,我们在继续沿着行进行下一个分块之前,将整个分块作为连续内存写入)
以下是此块布局变换前后的对比图,可以更直观地理解。更多详细信息,请参考 NVIDIA 文档。


在考虑分组时,实际内存布局如下所示。在前向传递的 2D-3D MXFP8 分组 GEMM 中,组沿着 M 维度(行)排列

转换为这种非传统布局时需要非常小心,以避免产生过多的开销,否则将抵消我们的净加速(甚至可能导致性能下降!)。
等等,还没完!这只是针对 单个比例因子 的考虑。请记住,我们正在执行一个分组 GEMM,它并行化了我们的路由专家 GEMM 并在单个内核中执行它们。每个 GEMM 的数据和比例因子 生活在同一个张量中,但所有数据和比例因子都必须单独满足我们讨论过的 tcgen05.mma PTX 指令的要求。这意味着我们不能直接对比例张量应用布局变换;相反,我们需要分别对每个组进行这种布局转换。
此外,每个专家的 token 组大小是动态的,并且仅在设备端可知!这意味着在主机端的 PyTorch 代码中,我们无法分派内核来单独转换每个组的内存布局,因为主机没有组大小信息。如果要获取这些信息,需要进行主机设备同步,这会在我们的 GPU 内核执行流中导致较大的间隙(空闲时间),这对性能来说是非常糟糕的!因此,我们必须设计一个能够完全在设备端高效执行此任务的自定义内核。
事实证明,这是一个非常有趣且不同寻常的内核开发任务——值得单独写一篇深度解析文章。与此同时,让我们进入 MXFP8 分组 GEMM 的下一步,即在动态量化输入并将比例因子写入每组块布局之后。
用于前向输出的 2D–3D MXFP8 分组 GEMM
数据和比例因子已就位在正确的内存布局中,我们终于准备好在 2D 输入激活和 3D 权重之间执行 MXFP8 分组 GEMM!这为我们提供了一个 2D 输出,其中每个 token 都已投影到隐藏维度。分解开来看,它是这样的:
反向传播;用于输入梯度计算的 2D–3D MXFP8 分组 GEMM
输入的梯度与前向输出的 2D-3D 缩放分组 GEMM 非常相似,所以这里不再赘述,因为篇幅已经很长了。具体公式为:
dgrad = dO @ weight
输出梯度与我们的前向输出具有相同的形状。它是一个形状为 (total_M, N) 的 2D 张量。权重仍然是我们的权重,一个形状为 (E, N, K) 的 3D 张量。
所以我们以在前向传播中量化输入激活相同的方式量化 2D 输出梯度,并以非常相似的方式量化我们的权重。不同之处在于它们是非转置的、行优先格式,并且我们正在将它们写入每专家列优先格式。所以内核有点不同,但这是唯一的区别。
反向传播;用于权重梯度计算的 2D–2D MXFP8 分组 GEMM
权重的梯度更值得探讨。这是因为它涉及一种完全不同类型的分组 GEMM,面临不同的挑战。计算权重梯度的公式是:
dW = dO^T @ X
这是一个形状为 (N, total_M) @ (total_M, K) 的 2D-2D 分组 GEMM。如你所见,组现在沿着 GEMM 的收缩维度!

将每组比例因子写入块格式,组沿收缩维度(K)排列
这有什么变化?好吧,一个将比例因子转换为组沿 M 维度(行)的块格式的内核,对于本例中组沿收缩或 K 维度的情况不再适用了。这是因为输入比例因子是行优先布局,因此沿着列边界将张量切分成组,从根本上改变了步长,使得步长 每组动态变化。所以我们需要一个处理每组不同步长情况的内核。在这种情况下,我们需要计算每行中 128×4 分块的数量,我们的步长将是单块步长 乘以该行中分块的数量。听起来很复杂,确实如此,但这有一个图表可以帮助可视化并使理解变得更容易:

分组 GEMM 中的每个独立 GEMM 产生一个 2D 输出,这些输出堆叠成最终的 3D 结果。分解开来看,它是这样的:

就是这样!此时,希望你对动态量化 MXFP8 分组 GEMM 的前向和反向传播的底层原理有了更好的理解!
以上就是全部内容——我们希望你阅读愉快,并记得查看 TorchAO MoE 训练文档,其中包含基准测试、示例等,助你入门!
未来工作
TorchAO 中的 MXFP8 MoE 训练仍然是一个原型功能,在将其升级为稳定版之前,我们正在积极进行一些改进,即:
- 统一密集模型和稀疏/MoE 模型的 MXFP8 训练 API:目前,TorchAO 有独立的 API 用于转换模型的 nn.Linear 层和 torch._grouped_mm 操作以使用 MXFP8。我们正在努力统一这些 API 以简化用户体验。
- 用于专家并行通信的 MXFP8:此外,除了 MXFP8 分组 GEMM 之外,我们还有用于高效专家并行 训练的自动微分函数原型,它在 all-to-all 通信之前更早地量化为 MXFP8,并在分组 GEMM 中保持为 MXFP8,从而节省网络带宽并带来加速。敬请期待后续关于此的文章!
附录
微基准测试:动态 MXFP8 量化 + MXFP8 分组 GEMM 相比 bf16 分组 GEMM 的净加速
以下是一些微基准测试,比较了用于支持 MXFP8 MoE 训练的 自动微分函数 的前向和反向传播的总持续时间,与 bf16 torch._grouped_mm 基线相比,针对的是近期 MoE 模型架构中使用的形状。
M = local_batch_size * sequence_length
G = number_of_experts_on_local_rank
N, K = expert_dimensions
Llama4 Scout 形状
| M,N,K,G | BF16 前向 + 反向 (微秒) | MXFP8 前向 + 反向 (微秒) | MXFP8 相对 BF16 加速比 |
| (128000, 8192, 5120, 1) | 43140.20 | 23867.00 | 1.808x |
| (128000, 8192, 5120, 2) | 39487.60 | 23359.00 | 1.690x |
| (128000, 8192, 5120, 4) | 39189.20 | 23945.50 | 1.637x |
| (128000, 8192, 5120, 8) | 37700.70 | 22170.60 | 1.700x |
你可以参考 文档获取在 B200 GPU 上复现这些基准测试的命令。
MoE 层基准测试
在单个 B200 上单个 MoE 层的微基准测试也显示,与 bf16 基线相比,MXFP8 实现的 MoE 层执行速度最高快 1.43 倍。
| 模型 | total_M | N | K | bf16 时间 (ms) | mxfp8 时间 (ms) | 加速比 |
| Llama4 16e | 131072 | 8192 | 5120 | 275.270 | 192.420 | 1.431x |
你可以参考 文档获取在 B200 GPU 上复现这些基准测试的命令。