跳转到主要内容
博客

使用 PyTorch FSDP 和 Torch.compile 最大化训练吞吐量

作者: 2024 年 5 月 21 日2024 年 11 月 13 日暂无评论

最近,我们展示了如何使用 FSDP 和选择性激活检查点,在 A100 GPU 上训练 7B 模型时实现 57% 的 MFU(模型浮点运算利用率)。我们还展示了如何训练出一个高质量模型,并将其作为 Granite 7B 基础模型 在 Hugging Face Hub 上以 Apache v2.0 许可证开源。

我们继续通过利用 torch.compile 来提高 GPU 利用率。结合 torch.compile 和我们之前工作中的选择性激活检查点,我们在 A100 GPU 上为 7B 模型实现了 68% 的 MFU!torch.compile 将各种模型大小的训练 MFU 提高了 10% 到 23%。

本博客分为三个部分:(1)为使用 torch.compile 进行训练而解决的挑战,(2)compile 与 no-compile 的数值一致性,以及(3)MFU 报告。

我们已将所有代码开源并更新到 fms-fsdp 仓库 中。我们还在与 Meta 的 PyTorch 团队合作,将这些贡献给新发布的 torch titan 仓库,用于预训练。

使用 torch.compile 的挑战

torch.compile 是一种图编译技术,可以提高 GPU 利用率。有关 torch.compile 工作原理的详细信息,我们建议读者参考最近的 PyTorch 论文和相关教程。让 torch.compile 表现出色的一个关键挑战是最大程度地减少(或消除)图中断。我们最初从 Meta 提供的 Llama 实现开始,但编译它导致了太多图中断,从而降低了训练吞吐量。

模型架构的几个部分必须进行修改,其中最重要的是位置嵌入层(RoPE)。典型的 RoPE 实现使用复数,这在测试时 torch.compile 尚不支持。我们使用 einops 实现了 RoPE,同时保持与原始模型架构实现的一致性。我们必须正确缓存频率,这样就不会在 RoPE 实现中遇到图中断。

编译 FSDP 模型确实会导致图中断,Meta 的 PyTorch 团队正在努力消除这些中断。然而,截至 PyTorch 2.3,这些图中断发生在 FSDP 单元边界,并且不会显著影响吞吐量。

使用自定义内核时,我们需要通过将其 API 暴露给 torch.compile 来包装每个内核。这包括指示哪些参数是原地修改的、如何修改以及它们的返回值将根据输入具有什么形状和步长。在我们的案例中,SDPA Flash attention 已适当集成,我们能够让该内核与 torch.compile 协同工作,且没有图中断。

我们还注意到,当数据量从 2T 增加到 6T tokens 时,数据加载器成为了瓶颈。一个主要原因是,我们之前在数据加载器中天真地实现了文档混洗,每个 worker 都维护一个混洗文档指针列表。

对于更大的数据集,这些指针列表增长到每个 worker 数十万条目。在这种规模下维护指针列表变得非常昂贵,以至于 CPU 争用扼杀了我们的训练吞吐量。我们使用 线性同余生成器 重新实现了文档混洗,而无需任何指针列表。LCG 是一种伪随机数生成器算法,它实现了一个在群体上的随机游走,提供了无放回抽样。

我们利用相同的思想来生成从有序到混洗文档索引的隐式双射映射。这使我们能够将那些烦人的数十万指针列表缩小为 LCG 的单个整数状态。这消除了 80% 的瓶颈,并显着提升了我们的性能。我们将专门撰写一篇博客,详细介绍我们高性能预训练数据加载器的所有细节。

torch.compile 与 torch.no-compile 的数值一致性

我们之前在 compile 和 no-compile 选项训练时观察到一致性问题,其中一个与 SDPA 的使用有关。经过 Meta 和 IBM 的 PyTorch 团队几天紧张的调试会话后,我们能够在 PyTorch compile 和 no-compile 模式之间实现一致性。为了记录和验证这种一致性,我们采用一个 1.4B 大小的 mini-Llama 模型架构,并以四种变体对其进行 100B tokens 的训练——no-compile,compile 不带激活检查点,compile 带选择性激活检查点,以及 compile 带完全激活检查点。

我们绘制了这些选项的损失曲线和梯度范数如下:

Figure 1: Loss curve and gradient norm for various compile options

图 1:各种编译选项的损失曲线和梯度范数

此外,我们运行了 lm-evaluation-harness,并比较了各种模型在不同基准上的分数,观察到 compile 和 no-compile 之间没有重大差异,如下所示。

Figure 2: lm-evaluation-harness comparison of various benchmarks between compile and no-compile

图 2:lm-evaluation-harness 对 compile 和 no-compile 之间各种基准的比较

从所有这些结果中我们观察到,compile 及其所有变体与 no-compile 选项相等,从而证明了 compile 和 no-compile 之间的一致性。

MFU 报告

最后,像我们之前的博客一样,我们在两个集群上计算了四种不同模型大小的 MFU。一个集群是 128 个 A100 GPU,具有 400 Gbps 的节点间连接;另一个集群是 464 个 H100 GPU,具有 3.2 Tbps 的节点间连接。除了 compile,我们还使用了 之前博客 中介绍的选择性激活检查点。结果记录在下表中。

模型大小批次大小无编译 MFU编译 MFU增益百分比 (%)
7B20.570.6820
13B20.510.6017
34B20.470.5415
70B20.500.5510

表 1:Llama2 模型架构在 128 个 A100 80GB GPU 上,采用 400Gbps 节点间互连的编译和非编译 MFU 结果

模型大小批次大小无编译 MFU编译 MFU增益百分比
7B20.370.4521
13B20.350.4323
34B20.320.3819
70B20.320.3819

表 2:Llama2 模型架构在 464 个 H100 80GB GPU 上,采用 3.2Tbps 节点间互连的编译和非编译 MFU 结果

我们还在 448 个 GPU 上使用 Llama2 7B 架构进行了内部生产运行。使用 compile 和选择性激活检查点,全局批次大小为 3.7M,我们在 13 天 10 小时内训练了 4T token!

在训练期间,数据中心冷却系统不得不启动额外的空调,我们的训练团队也收到了警报,因为我们有效地使用了 GPU ☺

从表 1 和表 2 中一个关键的观察是 MFU 数值并非随模型大小线性扩展。我们正在积极调查两种可能的解释,一是随着模型大小的增加 FSDP 的可伸缩性以及何时需要启用张量并行以更有效地使用 GPU,二是批次大小,可以进一步增加以获得更好的 MFU。我们计划探索 FSDP v2 和选择性操作符检查点以及张量并行功能,以研究 FSDP 随模型大小的缩放定律。

未来工作

我们计划开始测试 FSDP v2,它将作为 PyTorch 2.4 的一部分发布。FSDP2 提供了按参数分片和选择性操作符检查点功能,这可能会提供更好的内存-计算权衡。

我们还与 Meta 的 PyTorch 团队合作,评估新的异步检查点功能,该功能可以通过减少写入检查点的时间来进一步提高 GPU 利用率。

我们正在探索扩展目前在推理中使用的各种 Triton 内核,以执行反向操作,从而获得超越仅推理的加速。

最后,随着 fp8 使用的最新工作不断涌现,我们计划探索如何使用这种承诺 2 倍加速的新数据类型进一步加速模型训练。

致谢

有几个团队参与了实现这一证明点,我们要感谢 Meta 和 IBM 的所有团队。特别地,我们向 Meta PyTorch 分布式和编译器团队以及 IBM Research 表示感谢。

多位人员广泛参与了实现 torch.compile 与我们模型数值一致性的工作,我们希望感谢参与这项工作的关键人员;Meta 的 Animesh Jain 和 Less Wright,以及 IBM Research 的 Linsong Chu、Davis Wertheimer、Brian Vaughan、Antoni i Viros Martin、Mudhakar Srivatsa 和 Raghu Ganti。

特别感谢 Stas Bekman,他提供了大量反馈并帮助改进了这篇博客。他们的见解对于突出优化训练和探索进一步增强的关键方面非常有价值。