IBM:Tuan Hoang Trong, Alexei Karve, Yan Koyfman, Linsong Chu, Divya Kumari, Shweta Salaria, Robert Walkup, Praneet Adusumilli, Nirmit Desai, Raghu Ganti, Seetharami Seelam
Meta:Less Wright, Wei Feng, Vasiliy Kuznetsov, Driss Guesseous
在本篇博客中,我们将展示如何实现高达 50% 的吞吐量加速,同时在训练中达到与 FSDP1 bf16 训练相同的损失和评估基准。我们通过利用 FSDP2、DTensor 和 torch.compile
,结合 torchao 的 float8 进行线性层更新(计算),以及 float8 all_gathers 进行权重通信,从而实现这一加速。我们展示了这些改进在不同规模的 Meta LLaMa 模型架构上的效果,从小型 1.8B 模型到 405B 模型,使得训练比以往更快。
我们使用 Meta Llama3 架构展示了这些改进,并在两个规模上进行了模型质量研究:使用 8B 模型进行 100B token 的训练,以及使用 70B 模型进行 50B token 的训练,这提供了 float8 和 bf16 训练损失曲线的精确比较。我们展示了在这些模型训练运行中,损失曲线与 bf16
对应运行的损失收敛结果相同。此外,我们使用 FineWeb-edu 数据集将一个 3B 模型训练到 1T token,并运行标准评估基准,以确保模型质量完好无损,并可与 bf16
运行进行比较。
在 IBM Research,我们计划采用这些功能来进行数据消融,以在给定的 GPU 预算内增加我们可以进行的实验数量。从长远来看,我们将进行更大规模的模型运行,以展示 float8
训练的端到端可行性。
什么是 Float8?
float8
格式用于模型训练,由 NVIDIA、ARM 和 Intel 在 2022 年的一篇论文中提出,该论文证明了在不牺牲模型质量的情况下使用低精度 float8 进行训练的可行性。随着 NVIDIA Hopper 系列等新型 GPU 的引入,FP8 训练变得可行,由于原生支持 float8 Tensor Core,训练吞吐量有望提高 2 倍以上。实现这一承诺存在一些挑战:
(i) 在 float8
中启用核心模型操作,如 matmul
和 attention
,
(ii) 在分布式框架中启用 float8
训练,以及
(iii) 在 float8
中启用 GPU 之间的权重通信。
虽然 float8
matmul
已通过 NVIDIA 库启用,但后两项功能已在 FSDP2
和 torchao
的最新更新中提供。
在本篇博客中,我们使用 torchtitan 作为训练的入口,IBM 的确定性数据加载器,来自 torchao 的 float8
线性层实现,以及来自最新的 PyTorch nightlies 的 float8 all gather
,并结合 FSDP2。对于本次训练,我们使用 float8 按张量(张量级)缩放粒度,而不是按行。我们利用 torch.compile
来确保获得最大的性能增益。我们使用 SDPA 在 bf16
中计算 attention
,目前也正在努力将其迁移到 float8。
实验
我们进行了各种实验,以展示 float8 训练的优势。首先是确保模型质量不被牺牲。为了验证这一点,我们训练了一个 8B 模型和 70B 模型,运行数千步,并比较 float8 和 bf16 训练运行的损失曲线。我们的实验在三个不同的 H100 集群上进行,配置包括 128、256 和 512 个 H100 GPU,它们处于非常不同的环境中,以展示可重复性。第一个集群是 Meta 的 Grand Teton 定制集群,采用 400Gbps 定制互连;第二个是 IBM 研究集群,采用 3.2Tbps Infiniband 互连;第三个是 IBM Cloud 集群,采用 3.2Tbps RoCE 互连进行 GPU 到 GPU 通信。
首先,我们在下图绘制了这两种模型的损失曲线比较,以展示在数千步训练中达到的损失一致性。
图 1:(a) 8B 模型在 2k 步训练中的损失一致性,(b) 70B 模型在 1k 步训练中的损失一致性
我们观察到,在这些不同的模型和环境中,对于小规模的 token,我们获得了损失一致性。接下来,我们对从 1.8B 到 405B 四种不同模型大小的吞吐量增益进行了表征。我们探索了 float8 和 bf16 训练运行的最佳批量大小和激活检查点方案,以确定 tokens/秒/GPU (wps) 指标并报告性能增益。对于 405B 模型,我们利用了 DTensor
与 FSDP2 结合进行张量并行训练。我们在所有测量中都使用了 8K 的序列长度。
模型大小 | wps (bf16) | wps (float8) | 百分比增益 |
1.8B | 29K | 35K | 18% |
8B | 8K | 10K | 28% |
70B | 956 | 1430 | 50% |
405B (TP4) | 149 | 227 | 52% |
表 1:相对于 bf16 的性能增益(bf16 和 float8 均使用 torch.compile)
我们从表 1 观察到,对于大型模型(70B 和 405B),增益高达 50%,而小型模型的增益大约在 20% 到 30% 之间。在进一步的实验中,我们观察到添加 float8
all_gather
在 float8
计算本身之外还能带来约 5% 的提升,这与本博客中的观察结果一致。
其次,为了展示 FP8 模型的有效性,我们使用 Hugging Face 的 FineWeb-edu 数据集,按照 Llama3 架构训练了一个 3B 模型,训练时长为 1T token。我们使用 lm-eval-harness
框架进行评估,并在下表列出了部分结果。我们观察到 bf16
的性能略好于 float8
的得分(大约好一个百分点)。虽然在某些得分上 bf16
明显更好(例如,MMLU 高 3 分),但我们预计在选择合适的超参数并在更大规模的训练运行中,这些差距会消失(例如,bf16
运行的批量大小是 float8
的一半,众所周知,较小的批量大小运行可以提高评估得分)。
基准 | 得分 (float8) | 得分 (bf16) |
MMLU (5-shot) | 0.26 | 0.29 |
ARC-e | 0.73 | 0.73 |
ARC-c | 0.43 | 0.46 |
Hellaswag | 0.65 | 0.67 |
sciq | 0.89 | 0.88 |
OpenBook QA | 0.43 | 0.43 |
PIQA | 0.76 | 0.76 |
Winogrande | 0.60 | 0.65 |
平均值 | 0.59 | 0.60 |
表 2:使用 FP16 进行评估的 float8 训练模型的基准得分(在 1T token 的 FineWeb 预训练后)。
最后,我们将实验扩展到 IBM Cloud 集群上的 512 个 H100 GPU。即使在 512 个 GPU 的规模下,我们也能够重现我们观察到的结果和加速效果。下表总结了这些结果,仅针对大型模型(70B 和 405B)。
模型大小 | wps (bf16) | wps (float8) | 百分比增益 |
70B | 960 | 1448 | 51% |
405B (TP4) | 152 | 217 | 43% |
表 3:在 512 GPU 规模下,相对于 bf16 的性能增益(bf16 和 float8 均使用 torch.compile)
未来工作
我们还在评估其他形式的并行性,例如上下文并行 (Context Parallelism)。我们计划评估所有这些功能,以展示它们的可组合性以及在训练大规模模型时进行选择的能力。
致谢
我们感谢 IBM Research 的 Davis Wertheimer 提供了 torchtitan 运行所需的数据加载器,使我们能够在多次运行中按相同顺序回放数据。我们还要感谢 IBM Cloud 提供了 H100 集群的早期测试访问。