IBM:Tuan Hoang Trong, Alexei Karve, Yan Koyfman, Linsong Chu, Divya Kumari, Shweta Salaria, Robert Walkup, Praneet Adusumilli, Nirmit Desai, Raghu Ganti, Seetharami Seelam
Meta:Less Wright, Wei Feng, Vasiliy Kuznetsov, Driss Guesseous
在本篇博客中,我们将展示如何实现高达 50% 的吞吐量加速,同时在训练中保持与 FSDP1 bf16 训练相当的损失值和评估基准。我们通过利用 FSDP2、DTensor 和 torch.compile,结合 torchao 的 float8 线性层更新(计算)以及用于权重通信的 float8 all_gather 来实现这一加速。我们展示了这些改进在多种 Meta LLaMa 模型架构规模(从 1.8B 的小模型到 405B 的超大模型)上的表现,使训练速度达到前所未有的水平。
我们使用 Meta Llama3 架构演示了这些改进,并在两种规模下进行了模型质量研究:8B 模型规模下的 100B token 训练,以及 70B 模型规模下的 50B token 训练,从而提供了 float8 与 bf16 训练损失曲线的精确对比。我们证明,与 bf16 对照组相比,这些模型训练运行的损失曲线实现了完全一致的收敛。此外,我们使用 FineWeb-edu 数据集将一个 3B 模型训练到了 1T token,并运行了标准评估基准,以确保模型质量保持完整,且与 bf16 运行结果相当。
在 IBM 研究院,我们计划将这些功能应用于数据消融实验,以提高在给定 GPU 预算下可执行的实验数量。从长远来看,我们将进行更大规模的模型运行,以演示 float8 训练的端到端可行性。
什么是 Float8?
用于模型训练的 float8 格式由 NVIDIA、ARM 和 Intel 在一份 2022 年的论文中提出,证明了使用低精度 float8 进行训练且不牺牲模型质量的可行性。随着 NVIDIA Hopper 系列等新型 GPU 的推出,FP8 训练变得可行,得益于原生的 float8 张量核心支持,训练吞吐量有望提升 2 倍以上。要实现这一目标,还存在几个挑战:
(i) 在 float8 中启用核心模型操作,如 matmul 和 attention;
(ii) 在分布式框架中启用 float8 训练;以及
(iii) 在 float8 中启用 GPU 间的权重通信。
虽然 float8 matmul 已由 NVIDIA 库启用,但后两项功能是在最近的 FSDP2 和 torchao 更新中提供的。
在本博客中,我们使用 torchtitan 作为训练入口,结合 IBM 的确定性数据加载器、来自 torchao 的 float8 线性层实现,以及来自最新 PyTorch 每日构建版(nightlies)并配合 FSDP2 使用的 float8 all gather。在本次训练中,我们使用的是 float8 按张量(tensorwise)缩放粒度,而非按行(rowwise)缩放。我们利用 torch.compile 来确保获得最大的性能提升。目前,我们正使用 SDPA 以 bf16 计算 attention,并正在努力将其迁移到 float8。
实验
我们进行了各种实验来证明 float8 训练的优势。首先是确保模型质量不会受损。为此,我们训练了 8B 和 70B 模型数千步,并比较了 float8 和 bf16 训练运行的损失曲线。我们的实验在三个不同的 H100 集群上进行,分别是 128、256 和 512 个 H100 GPU 配置,环境各异,以证明可重复性。第一个集群是 Meta 基于 Grand Teton 定制的,具有 400Gbps 自定义互联;第二个是具有 3.2Tbps Infiniband 互联的 IBM 研究集群;第三个是具有 3.2Tbps RoCE 互联用于 GPU 间通信的 IBM Cloud 集群。
首先,我们在下图中绘制了这两种模型的损失曲线对比图,以展示数千步内的损失对等性。


图 1:(a) 8B 模型 2k 步损失对等性,(b) 70B 模型 1k 步损失对等性
我们观察到,在这些不同的模型和环境下,我们在小规模 token 训练中获得了损失对等性。接下来,我们刻画了从 1.8B 到 405B 四种不同模型规模的吞吐量增益。我们为 float8 和 bf16 训练运行探索了最佳 batch size 和激活检查点方案,以确定 token/秒/GPU (wps) 指标并报告性能增益。对于 405B 模型,我们利用 DTensor 进行 FSDP2 张量并行训练。我们所有的测量均使用 8K 序列长度。
| 模型大小 | wps (bf16) | wps (float8) | 百分比增益 |
| 1.8B | 29K | 35K | 18% |
| 8K | 8K | 10K | 28% |
| 70B | 956 | 1430 | 50% |
| 405B (TP4) | 149 | 227 | 52% |
表 1:相对于 bf16 的性能增益(bf16 和 float8 均使用 torch.compile)
从表 1 中我们可以观察到,较大模型(70B 和 405B)的增益高达 50%,较小模型也有约 20% 到 30% 的增益。在进一步的实验中,我们观察到添加 float8 all_gather 在计算本身之外带来了约 5% 的提升,这与此博客中的观察结果一致。
其次,为了证明 FP8 模型的有效性,我们使用 Hugging Face 的 FineWeb-edu 数据集,按照 Llama3 架构训练了一个 3B 模型,进行了 1T token 的训练。我们使用 lm-eval-harness 框架进行了评估,并在下表中展示了部分结果。我们观察到 bf16 的性能略优于 float8(约 1%)。虽然某些分数在 bf16 下表现更好(例如 MMLU 高出 3 分),但我们预计随着正确超参数的选择以及更大规模的训练,这些差距将会消失(例如,bf16 运行的 batch size 减半,而众所周知,较小的 batch size 运行可以提高评估分数)。
| 基准测试 | 分数 (float8) | 分数 (bf16) |
| MMLU (5-shot) | 0.26 | 0.29 |
| ARC-e | 0.73 | 0.73 |
| ARC-c | 0.43 | 0.46 |
| Hellaswag | 0.65 | 0.67 |
| sciq | 0.89 | 0.88 |
| OpenBook QA | 0.43 | 0.43 |
| PIQA | 0.76 | 0.76 |
| Winogrande | 0.60 | 0.65 |
| 平均值 | 0.59 | 0.60 |
表 2:评估时运行于 FP16 的 float8 训练模型的基准得分(在 1T token FineWeb 预训练下)。
最后,我们将实验扩展到 IBM Cloud 集群上的 512 个 H100 GPU。我们能够重现即使在 512 GPU 规模下观察到的结果和加速效果。我们在下表中总结了大型模型(70B 和 405B)的结果。
| 模型大小 | wps (bf16) | wps (float8) | 百分比增益 |
| 70B | 960 | 1448 | 51% |
| 405B (TP4) | 152 | 217 | 43% |
表 3:512 GPU 规模下相对于 bf16 的性能增益(bf16 和 float8 均使用 torch.compile)
未来工作
我们也在研究评估其他形式的并行化,例如上下文并行(Context Parallelism)。我们计划评估所有这些功能,以展示其组合性以及为训练大规模模型做出选择的能力。
致谢
感谢 IBM 研究院的 Davis Wertheimer 为 torchtitan 运行启用了数据加载器,使我们能够在多次运行中以相同的顺序重放数据。我们还要感谢 IBM Cloud 为我们提供了 H100 集群的早期测试访问权限。