Meta:Less Wright、Hamid Shojanazeri、Vasiliy Kuznetsov、Daniel Vega-Myhre、Gokul Nadathur、Will Constable、Tianyu Liu、Tristan Rice、Driss Guessous、Josh Fromm、Luca Wehrstedt、Jiecao Yu Crusoe:Ethan Petersen、Martin Cala、Chip Smith
我们与 Crusoe.AI 合作,获得了他们在冰岛的新 2K H200 集群的访问权限,这使我们能够通过利用 TorchTitan 的 HSDP2 和 TorchAO 的新 float8 rowwise,展示在规模上将训练加速 34 – 43%,并与 BF16 具有可比的收敛性和稳定性。
在这篇文章中,我们将详细介绍 H200 与 PyTorch 新的 Float8 行式训练以及 TorchTitan 的 FSDP2/HSDP2 和大规模 CP 的协同作用。
背景 – 什么是 H200?
H200 是“增强型”H100,提供与 H100 完全相同的计算能力,但有两个额外的改进。
- 更大的全局内存,141GiB HBM3e vs 标准 80GiB HBM3
- 内存带宽快约 43%,达到 4.8TB/s vs 3.35 TB/s。更快的内存传输对训练速度有巨大影响,尤其是对于 PyTorch 的 AsyncTP。
什么是 PyTorch Float8 行式?
Float 8 行式是 Float8 的更精细粒度分辨率,与之前的“张量级”Float8 不同。它旨在确保更精细的精度,以支持更大规模的工作负载,这些工作负载随着规模的扩大和训练的进展,往往对量化变得更加敏感。
Float8 行式有两个关键改进
- 现在每行都保持自己的缩放因子,而不是整个张量的单个缩放因子,从而提高了量化精度。每行更精细的缩放有助于减少离群值的影响(极端值迫使量化缩放因子拉伸并降低正态分布值的精度),从而确保更好的精度。
- 缩放因子本身现在通过向下舍入到最接近的 2 的幂来实现。这已被证明有助于减少在乘以/除以缩放因子时的量化误差,并确保大值在正向和反向传递中都缩放到相同的值。
请注意,其他大型模型已在 2K 规模下使用 Float8 进行了训练,结合了 1x128 组式和 128x128 块式,并采用 2 的幂次缩放因子。它们的目标相同,都是为了提高 Float8 的精度以支持大规模训练。
因此,Float8 行式提供了类似的前景,可以在非常大规模的训练中启用 Float8,但我们希望提供在规模上稳定性和收敛性的证据,Crusoe H200 2k 集群上的训练为此提供了初步验证。
展示 Float8 行式损失收敛与 BF16 在 1600 和 1920 GPU 规模下的表现
为了验证可比的损失收敛性,我们使用 TorchTitan 和 Llama3 70B 在 1920 和 1600 (1.6k) GPU 规模下分别进行了两次运行。1.6K GPU 运行设置为 2.5k 次迭代,使用 TorchTitans 的 HSDP2 和上下文并行以启用 2D 并行。
损失收敛测试使用 Titan 的确定性模式运行——这种模式有效地冻结了运行之间大多数潜在的变异来源,从而有助于确保唯一实质性变化是我们想要测试的,即 BF16 与 Float8 行式的损失收敛和损失曲线。
请注意,确定性模式也会降低训练速度,因为各种内核不会进行自动调整以最大化吞吐量(否则我们可能会在运行之间使用不同的内核并引入方差)。
完成了两次运行,一次使用 BF16,另一次使用 Float8 行式。
两次运行都完成了分配的 2.5k 迭代,没有出现问题,展示了 Crusoe 集群的稳定性,其中 FP8 恰好在 24 小时内完成,BF16 在 31 小时 19 分钟后完成。
数据类型 | 时间/迭代数 | 损失 |
BF16 | 24 小时 | 3.15453 |
Float8 行式 | 24 小时 | 2.86386 |
BF16 | 31 小时 19 分钟 / 2.5K | 2.88109 |
Float8 行式 | 24 小时 / 2.5K | 2.86386 |
在 24 小时标记处,Float8 完成了 2.5K 次迭代,展示了 float8 训练的相对加速(即使在确定性模式下)。在 24 小时标记处,对于相同的 24 小时大规模训练时间,Float8 在损失方面相对于 BF16 实现了 +9.21% 的相对改进。
31 小时 19 分钟后,BF16 运行最终完成了 2.5k 次迭代。
最终损失数值
BF16 = 2.88109 Float8 = 2.86386
从损失曲线上,我们观察到在前三分之一和最后三分之一处曲线非常相似,然后中间有一个湍流区域,两者都显示出相似的尖峰,但尖峰的相对时间略有偏差。
因此,我们可以看到 PyTorch 的 Float8 行式提供了相似的收敛性,但在相同的训练时间下,速度提高了 33% 以上。
Float8 行式的长期训练稳定性
除了展示可比的收敛性,我们还希望展示 Float8 的长期训练稳定性,因此我们启动了一个 4 天、15K 次迭代的 256 规模运行。
如上图所示,Float8 训练运行超过 100 小时,没有出现任何问题,凸显了 Float8 行式的长期稳定性。
TorchTitan 中的确定性
为了验证确定性并查看较长运行中的尖峰是否来自规模,我们还进行了一次较小的运行,包括 2 次 BF16 运行和 1 次 Float8 运行,规模为 256,并且仅使用 HSDP2(即不使用 2D 上下文并行)。
在这种情况下,两个 BF16 运行都具有相同的曲线和最终损失,并且我们观察到所有三次运行都有相似的尖峰区域。
在 2K 迭代标记处,Float8 和 BF16 几乎在相同的点结束。
BF16 *2 = 3.28538
Float8 行式 = 3.28203
上述结果证实,CP 和规模 (2k) 都不是造成损失尖峰的原因,因为我们在 256 规模上也看到了类似的效果。损失尖峰最可能的解释可能是数据集中的内容分布。
为了确定性,实验使用序列化的 C4 数据集(未打乱)运行,这意味着尖峰可能来自数据集中遇到新内容。
Float8 行式在不同规模下的实际加速
我们在不同的 GPU 规模下进行了较短的运行,以了解 Float8 Rowwise 在集群规模扩大时,在训练加速方面的表现。从 960 到 1920,Float8 继续提供令人印象深刻的训练加速,与 BF16 相比,性能提升了 34-43%。我们还注意到,从 1k 到 2k GPU 的扩展可能会引入通信开销,我们观察到 BF16 的吞吐量下降了 4%。
如上所示的在规模上进行的长时间训练运行,Float8 行式提供了显著的加速,具有相同甚至略微改善的损失终点,同时在 1920 (DeepSeek) 规模下提供了 34% 的加速。
如何在训练中使用 Float8 行式?
Float8 行式现在可供您在大规模训练中使用。它封装在 TorchAO 的最新构建(0.9 及更高版本)中,如果您想快速启动并运行,它已原生集成到 TorchTitan 中。
在 TorchTitan 中激活 Float8 行式
首先,在您的模型 .toml 文件中启用模型转换器,将 nn.linears 热插拔到 float8 线性层中——参见第 29 行
其次,指定“rowwise”float8 方案——参见第 72 行
请注意,您有三种“recipe_name”选择
- rowwise,这是推荐的默认值,
- tensorwise(旧式 float8)和
- rowwise_with_gw_hp。
gw_hp 行式选项在反向传播过程中将权重的梯度保持在 BF16 精度,这可以进一步提高 Float8 对极其敏感的工作负载的精度。但是,如果模型中大多数矩阵乘法的尺寸较小(H100 上估计的临界点大约在 13-16K 维度),它反而可能比通用行式性能更好。
因此,虽然我们推荐将 rowwise 作为默认值,但在您的模型上与 gw_hp 进行比较可能值得一试,以验证哪个能提供最佳性能,并且具有更高的精度。
通过使用 # 切换模型转换器,您可以直接比较 BF16 和 Float8 行式之间的训练加速,以了解您自己训练的潜在加速。
未来更新
我们将推出一个额外的更新,展示并行管道和异步分布式检查点的多项改进,敬请期待。