跳转到主要内容
博客

使用 float8 和 FSDP2 超级训练

作者: 2024 年 11 月 25 日2025 年 5 月 5 日暂无评论

IBM:Tuan Hoang Trong, Alexei Karve, Yan Koyfman, Linsong Chu, Divya Kumari, Shweta Salaria, Robert Walkup, Praneet Adusumilli, Nirmit Desai, Raghu Ganti, Seetharami Seelam
Meta:Less Wright, Wei Feng, Vasiliy Kuznetsov, Driss Guesseous

在本博客中,我们将展示如何通过利用 FSDP2、DTensor 和 torch.compile 结合 torchao 的 float8 进行线性层更新(计算)以及 float8 all_gathers 进行权重通信,在训练中实现高达 50% 的吞吐量加速,同时保持与 FSDP1 bf16 训练相同的损失和评估基准。我们将在从小型 1.8B 模型到 405B 模型的 Meta LLaMa 模型架构系列中展示这些改进,从而使训练比以往更快。

我们将使用 Meta Llama3 架构演示这些改进,然后进行两个规模的模型质量研究:8B 模型规模下 100B 令牌,以及 70B 模型规模下 50B 令牌,这提供了 float8 和 bf16 训练损失曲线的精确比较。我们证明,与 bf16 对应版本相比,这些模型训练运行的损失曲线导致相同的损失收敛。此外,我们使用 FineWeb-edu 数据集训练了一个 3B 模型,达到 1T 令牌,并运行标准评估基准,以确保模型质量保持完整并与 bf16 运行相当。

在 IBM Research,我们计划采用这些功能进行数据消融,以提高我们在给定 GPU 预算内可以执行的实验数量。从长远来看,我们将进行更大规模的模型运行,以证明 float8 训练的端到端可行性。

什么是 Float8?

float8 模型训练格式由 NVIDIA、ARM 和 Intel 在 2022 年论文中提出,该论文证明了使用较低精度 float8 进行训练的可行性,而不会牺牲模型质量。随着 NVIDIA Hopper 系列等新型 GPU 的引入,FP8 训练变得可行,由于原生 float8 核心支持,训练吞吐量有可能提高两倍以上。实现这一承诺存在一些挑战:
(i) 在 float8 中启用核心模型操作,如 matmulattention
(ii) 在分布式框架中启用 float8 训练,以及
(iii) 在 float8 中启用 GPU 之间的权重通信。
虽然 float8 matmul 由 NVIDIA 库启用,但后两者已在 FSDP2torchao 的最新更新中提供。

在本博客中,我们使用 torchtitan 作为训练入口点,IBM 的确定性数据加载器,来自 torchaofloat8 线性层实现,以及与 FSDP2 结合的最新 PyTorch nightly 版本中的 float8 all gather。对于此次训练,我们使用 float8 逐张量(tensorwise)缩放粒度而不是逐行缩放粒度。我们利用 torch.compile 以确保获得最大的性能提升。我们正在使用 SDPA 在 bf16 中计算 attention,目前正在努力将其也转换为 float8。

实验

我们进行了各种实验以证明 float8 训练的优势。首先是确保不牺牲模型质量。为了验证这一点,我们训练了一个 8B 模型和 70B 模型数千步,并比较了 float8 和 bf16 训练运行之间的损失曲线。我们的实验在三个不同的 H100 集群上进行,分别具有 128、256 和 512 个 H100 GPU 配置,在截然不同的环境中进行,以证明可重现性。第一个集群在 Meta 的 Grand Teton 上进行定制,采用 400Gbps 定制互连;第二个是 IBM 研究集群,采用 3.2Tbps Infiniband 互连;第三个是 IBM Cloud 集群,采用 3.2Tbps RoCE 互连进行 GPU 到 GPU 通信。

首先,我们在下图中绘制了这两个模型的损失曲线比较,以证明数千步的损失一致性。

Figure 1: (a) 8B model loss parity for 2k steps, (b) 70B loss parity for 1k steps
Figure 1: (a) 8B model loss parity for 2k steps, (b) 70B loss parity for 1k steps

图 1:(a) 8B 模型 2k 步损失一致性,(b) 70B 模型 1k 步损失一致性

我们观察到,在这些不同的模型和不同的环境中,我们获得了小规模令牌的损失一致性。接下来,我们表征了从 1.8B 到 405B 四种不同模型大小的吞吐量增益。我们探索了 float8 和 bf16 训练运行的最佳批处理大小和激活检查点方案,以确定每秒令牌数/GPU (wps) 指标并报告性能增益。对于 405B 模型,我们利用 DTensor 进行与 FSDP2 结合的张量并行训练。我们所有的测量都使用 8K 的序列长度。

模型大小 wps (bf16) wps (float8) 增益百分比
1.8B 29K 35K 18%
8K 8K 10K 28%
70B 956 1430 50%
405B (TP4) 149 227 52%

表 1:相对于 bf16 的性能增益(bf16 和 float8 都使用 torch.compile)

从表 1 中我们可以观察到,对于更大的模型(70B 和 405B),增益高达 50%,而较小的模型增益约为 20% 到 30%。在进一步的实验中,我们观察到添加 float8 all_gatherfloat8 计算本身之外还带来了约 5% 的提升,这与 这篇博客中的观察结果一致。

其次,为了证明 FP8 模型的有效性,我们使用 Hugging Face 的 FineWeb-edu 数据集训练了一个遵循 Llama3 架构的 3B 模型,达到 1T 令牌。我们使用 lm-eval-harness 框架进行评估,并在下表中列出了部分结果。我们观察到 bf16 的性能略好于 float8 的分数(大约百分之一)。虽然某些分数在 bf16 下显著更好(例如,MMLU 高出 3 分),但我们预计在选择正确的超参数和更大规模的训练运行中,这些差距会消失(例如,bf16 运行的批处理大小减半,众所周知,较小的批处理大小运行可以提高评估分数)。

基准 分数 (float8) 分数 (bf16)
MMLU (5-shot) 0.26 0.29
ARC-e 0.73 0.73
ARC-c 0.43 0.46
Hellaswag 0.65 0.67
sciq 0.89 0.88
OpenBook QA 0.43 0.43
PIQA 0.76 0.76
Winogrande 0.60 0.65
平均 0.59 0.60

表 2:float8 训练模型在 FP16 中进行评估的基准分数(FineWeb 预训练达到 1T 令牌)。

最后,我们将实验扩展到 IBM Cloud 集群上的 512 个 H100 GPU。即使在 512 个 GPU 的规模下,我们也能够重现我们观察到的结果和加速。下表仅总结了大型模型(70B 和 405B)的这些结果。

模型大小 wps (bf16) wps (float8) 增益百分比
70B 960 1448 51%
405B (TP4) 152 217 43%

表 3:512 GPU 规模下相对于 bf16 的性能增益(bf16 和 float8 都使用 torch.compile)

未来工作

我们还在评估其他并行形式,例如上下文并行。我们计划评估所有这些功能,以展示其可组合性和为训练大规模模型做出选择的能力。

致谢

我们感谢 IBM Research 的 Davis Wertheimer 启用 torchtitan 运行的数据加载器,使我们能够在多次运行中以相同的顺序重播数据。我们还要感谢 IBM Cloud 为我们提供 H100 集群的早期测试访问。