博客

使用 float8 和 FSDP2 超级训练

作者: 2024年11月25日2025年5月5日暂无评论

IBM:Tuan Hoang Trong, Alexei Karve, Yan Koyfman, Linsong Chu, Divya Kumari, Shweta Salaria, Robert Walkup, Praneet Adusumilli, Nirmit Desai, Raghu Ganti, Seetharami Seelam
Meta:Less Wright, Wei Feng, Vasiliy Kuznetsov, Driss Guesseous

在本篇博客中,我们将展示如何实现高达 50% 的吞吐量加速,同时在训练中保持与 FSDP1 bf16 训练相当的损失值和评估基准。我们通过利用 FSDP2、DTensor 和 torch.compile,结合 torchao 的 float8 线性层更新(计算)以及用于权重通信的 float8 all_gather 来实现这一加速。我们展示了这些改进在多种 Meta LLaMa 模型架构规模(从 1.8B 的小模型到 405B 的超大模型)上的表现,使训练速度达到前所未有的水平。

我们使用 Meta Llama3 架构演示了这些改进,并在两种规模下进行了模型质量研究:8B 模型规模下的 100B token 训练,以及 70B 模型规模下的 50B token 训练,从而提供了 float8 与 bf16 训练损失曲线的精确对比。我们证明,与 bf16 对照组相比,这些模型训练运行的损失曲线实现了完全一致的收敛。此外,我们使用 FineWeb-edu 数据集将一个 3B 模型训练到了 1T token,并运行了标准评估基准,以确保模型质量保持完整,且与 bf16 运行结果相当。

在 IBM 研究院,我们计划将这些功能应用于数据消融实验,以提高在给定 GPU 预算下可执行的实验数量。从长远来看,我们将进行更大规模的模型运行,以演示 float8 训练的端到端可行性。

什么是 Float8?

用于模型训练的 float8 格式由 NVIDIA、ARM 和 Intel 在一份 2022 年的论文中提出,证明了使用低精度 float8 进行训练且不牺牲模型质量的可行性。随着 NVIDIA Hopper 系列等新型 GPU 的推出,FP8 训练变得可行,得益于原生的 float8 张量核心支持,训练吞吐量有望提升 2 倍以上。要实现这一目标,还存在几个挑战:
(i) 在 float8 中启用核心模型操作,如 matmulattention
(ii) 在分布式框架中启用 float8 训练;以及
(iii) 在 float8 中启用 GPU 间的权重通信。
虽然 float8 matmul 已由 NVIDIA 库启用,但后两项功能是在最近的 FSDP2torchao 更新中提供的。

在本博客中,我们使用 torchtitan 作为训练入口,结合 IBM 的确定性数据加载器、来自 torchaofloat8 线性层实现,以及来自最新 PyTorch 每日构建版(nightlies)并配合 FSDP2 使用的 float8 all gather。在本次训练中,我们使用的是 float8 按张量(tensorwise)缩放粒度,而非按行(rowwise)缩放。我们利用 torch.compile 来确保获得最大的性能提升。目前,我们正使用 SDPA 以 bf16 计算 attention,并正在努力将其迁移到 float8。

实验

我们进行了各种实验来证明 float8 训练的优势。首先是确保模型质量不会受损。为此,我们训练了 8B 和 70B 模型数千步,并比较了 float8 和 bf16 训练运行的损失曲线。我们的实验在三个不同的 H100 集群上进行,分别是 128、256 和 512 个 H100 GPU 配置,环境各异,以证明可重复性。第一个集群是 Meta 基于 Grand Teton 定制的,具有 400Gbps 自定义互联;第二个是具有 3.2Tbps Infiniband 互联的 IBM 研究集群;第三个是具有 3.2Tbps RoCE 互联用于 GPU 间通信的 IBM Cloud 集群。

首先,我们在下图中绘制了这两种模型的损失曲线对比图,以展示数千步内的损失对等性。

Figure 1: (a) 8B model loss parity for 2k steps, (b) 70B loss parity for 1k steps
Figure 1: (a) 8B model loss parity for 2k steps, (b) 70B loss parity for 1k steps

图 1:(a) 8B 模型 2k 步损失对等性,(b) 70B 模型 1k 步损失对等性

我们观察到,在这些不同的模型和环境下,我们在小规模 token 训练中获得了损失对等性。接下来,我们刻画了从 1.8B 到 405B 四种不同模型规模的吞吐量增益。我们为 float8 和 bf16 训练运行探索了最佳 batch size 和激活检查点方案,以确定 token/秒/GPU (wps) 指标并报告性能增益。对于 405B 模型,我们利用 DTensor 进行 FSDP2 张量并行训练。我们所有的测量均使用 8K 序列长度。

模型大小 wps (bf16) wps (float8) 百分比增益
1.8B 29K 35K 18%
8K 8K 10K 28%
70B 956 1430 50%
405B (TP4) 149 227 52%

表 1:相对于 bf16 的性能增益(bf16 和 float8 均使用 torch.compile)

从表 1 中我们可以观察到,较大模型(70B 和 405B)的增益高达 50%,较小模型也有约 20% 到 30% 的增益。在进一步的实验中,我们观察到添加 float8 all_gather 在计算本身之外带来了约 5% 的提升,这与此博客中的观察结果一致。

其次,为了证明 FP8 模型的有效性,我们使用 Hugging Face 的 FineWeb-edu 数据集,按照 Llama3 架构训练了一个 3B 模型,进行了 1T token 的训练。我们使用 lm-eval-harness 框架进行了评估,并在下表中展示了部分结果。我们观察到 bf16 的性能略优于 float8(约 1%)。虽然某些分数在 bf16 下表现更好(例如 MMLU 高出 3 分),但我们预计随着正确超参数的选择以及更大规模的训练,这些差距将会消失(例如,bf16 运行的 batch size 减半,而众所周知,较小的 batch size 运行可以提高评估分数)。

基准测试 分数 (float8) 分数 (bf16)
MMLU (5-shot) 0.26 0.29
ARC-e 0.73 0.73
ARC-c 0.43 0.46
Hellaswag 0.65 0.67
sciq 0.89 0.88
OpenBook QA 0.43 0.43
PIQA 0.76 0.76
Winogrande 0.60 0.65
平均值 0.59 0.60

表 2:评估时运行于 FP16 的 float8 训练模型的基准得分(在 1T token FineWeb 预训练下)。

最后,我们将实验扩展到 IBM Cloud 集群上的 512 个 H100 GPU。我们能够重现即使在 512 GPU 规模下观察到的结果和加速效果。我们在下表中总结了大型模型(70B 和 405B)的结果。

模型大小 wps (bf16) wps (float8) 百分比增益
70B 960 1448 51%
405B (TP4) 152 217 43%

表 3:512 GPU 规模下相对于 bf16 的性能增益(bf16 和 float8 均使用 torch.compile)

未来工作

我们也在研究评估其他形式的并行化,例如上下文并行(Context Parallelism)。我们计划评估所有这些功能,以展示其组合性以及为训练大规模模型做出选择的能力。

致谢

感谢 IBM 研究院的 Davis Wertheimer 为 torchtitan 运行启用了数据加载器,使我们能够在多次运行中以相同的顺序重放数据。我们还要感谢 IBM Cloud 为我们提供了 H100 集群的早期测试访问权限。