跳转到主要内容
博客

使用 PyTorch FSDP 和 Torch.compile 最大化训练吞吐量

作者: 2024 年 5 月 21 日2024 年 11 月 13 日暂无评论

最近,我们展示了如何使用 FSDP 和选择性激活检查点来实现 A100 GPU 上 7B 模型训练的 57% MFU(模型浮点运算利用率)。我们还展示了如何训练出高质量模型,我们已将其开源为 Hugging Face Hub 上的 Granite 7B 基础模型,采用 Apache v2.0 许可证。

我们通过利用 torch.compile 继续改进 GPU 利用率。结合 torch.compile 和我们先前工作中的选择性激活检查点,我们为 A100 GPU 上的 7B 模型实现了 68% 的 MFU!torch.compile 将各种模型尺寸的训练 MFU 提高了 10% 到 23%。

本博客分为三个部分:(1) 使用 torch.compile 训练时遇到的挑战,(2) compile 与 no-compile 的数值一致性,以及 (3) MFU 报告。

我们已将所有代码开源并更新到 fms-fsdp 仓库中。我们还与 Meta 的 PyTorch 团队合作,将这些贡献到新发布的用于预训练的 torch titan 仓库中。

使用 torch.compile 的挑战

torch.compile 是一种图编译技术,可提高 GPU 利用率。有关 torch.compile 工作原理的详细信息,我们建议读者参考最近的 PyTorch 论文和相关教程。让 torch.compile 表现良好的一个关键挑战是最大程度地减少(或消除)图中断。我们最初从 Meta 提供的 Llama 实现开始,但编译它导致了太多的图中断,从而降低了训练吞吐量。

模型架构的几个部分需要修复,其中最重要的是位置嵌入层 (RoPE)。典型的 RoPE 实现使用复数,而这在测试时 torch.compile 尚不支持。我们使用 einops 实现了 RoPE,同时保持了与原始模型架构实现的一致性。我们必须正确缓存频率,以便在 RoPE 实现中不会出现图中断。

编译 FSDP 模型确实会导致图中断,Meta 的 PyTorch 团队正在努力消除这些中断。然而,截至 PyTorch 2.3,这些图中断发生在 FSDP 单元边界,并且不会显著影响吞吐量。

使用自定义内核时,我们需要通过将其 API 公开给 torch.compile 来包装每个内核。这包括指示哪些参数是原地修改的、如何修改它们,以及它们的返回值将根据输入具有什么形状和步长。在我们的案例中,SDPA Flash attention 已适当集成,我们能够让该内核与 torch.compile 一起工作,而不会出现图中断。

我们还注意到,当数据量从 2T 增加到 6T tokens 时,数据加载器成为了瓶颈。其中一个关键原因是,之前我们天真地在数据加载器中实现了文档混洗,每个 worker 都维护一个混洗过的文档指针列表。

随着数据集的增大,这些指针列表每个 worker 增长到数十万个条目。维护如此规模的指针列表变得非常昂贵,以至于 CPU 争用限制了我们的训练吞吐量。我们使用 线性同余生成器重新实现了文档混洗,而无需任何指针列表。LCG 是一种伪随机数生成器算法,它在总体上实现随机游走,提供不重复采样。

我们利用相同的思想生成从有序文档索引到混洗文档索引的隐式双射映射。这使我们能够将那些令人讨厌的数十万个指针列表缩小为 LCG 的单个整数状态。这消除了 80% 的瓶颈,并显着提升了我们的性能。我们将专门撰写一篇博客,详细介绍我们高性能的预训练数据加载器。

torch.compile 与 torch.no-compile 的数值一致性

我们之前观察到使用 compile 和 no-compile 选项进行训练时存在一致性问题,其中一个与 SDPA 的使用有关。经过 Meta 和 IBM 的 PyTorch 团队几天紧张的调试会话后,我们能够在 PyTorch 的 compile 和 no-compile 模式之间实现一致性。为了记录和验证这种一致性,我们采用了一个 1.4B 大小的 mini-Llama 模型架构,并以四种变体训练它达到 100B tokens——no-compile,不带激活检查点的 compile,带选择性激活检查点的 compile,以及带完全激活检查点的 compile。

我们绘制了这些选项的损失曲线和梯度范数如下

Figure 1: Loss curve and gradient norm for various compile options

图 1:各种编译选项的损失曲线和梯度范数

此外,我们运行了 lm-evaluation-harness 并比较了各种模型在不同基准上的分数,没有观察到编译和不编译之间存在重大差异,如下图所示。

Figure 2: lm-evaluation-harness comparison of various benchmarks between compile and no-compile

图 2:编译与不编译之间各种基准的 lm-evaluation-harness 比较

从所有这些结果中,我们观察到带所有变体的编译选项与不编译选项是等效的,从而证明了编译与不编译之间的一致性。

MFU 报告

最后,像我们之前的博客一样,我们计算了两个集群上四种不同模型尺寸的 MFU。一个集群是 128 块 A100 GPU,具有 400 Gbps 的节点间连接,另一个集群是 464 块 H100 GPU,具有 3.2 Tbps 的节点间连接。除了编译之外,我们还使用了我们在之前的博客中介绍的选择性激活检查点。我们将结果记录在下表中。

模型大小批次大小MFU 未编译MFU 已编译百分比增益 (%)
7B20.570.6820
13B20.510.6017
34B20.470.5415
70B20.500.5510

表 1:在 128 块 A100 80GB GPU 上,采用 400Gbps 节点间互连的 Llama2 模型架构的编译和未编译 MFU 结果

模型大小批次大小MFU 未编译MFU 已编译百分比增益
7B20.370.4521
13B20.350.4323
34B20.320.3819
70B20.320.3819

表 2:在 464 块 H100 80GB GPU 上,采用 3.2Tbps 节点间互连的 Llama2 模型架构的编译和未编译 MFU 结果

我们还在 448 块 GPU 上使用 Llama2 7B 架构进行了内部生产运行。通过使用编译和选择性激活检查点,以及 3.7M 的全局批量大小,我们在 13 天 10 小时内训练了 4T tokens!

训练期间,数据中心冷却系统不得不启动额外的空调,我们的训练团队也接到了警报,因为我们正在非常有效地利用 GPU ☺

从表1和表2的一个关键观察是,MFU数值并非随着模型大小线性扩展。我们正在积极调查两种可能的解释,一是FSDP随着模型大小增加时的可扩展性以及何时需要启用张量并行以更有效地使用GPU,二是批量大小,可以进一步增加以获得更好的MFU。我们计划探索FSDP v2和选择性操作符检查点以及张量并行特性,以研究FSDP与模型大小的缩放定律。

未来工作

我们计划开始测试 FSDP v2,它将作为 PyTorch 2.4 的一部分发布。FSDP2 提供了每个参数分片和选择性操作符检查点功能,这可能提供更好的内存-计算权衡。

我们还与 Meta 的 PyTorch 团队合作,评估新的异步检查点功能,该功能可以通过减少写入检查点的时间来进一步提高 GPU 利用率。

我们正在探索扩展目前用于推理的各种 Triton 内核,以执行反向操作,从而获得超越仅推理的加速。

最后,随着 fp8 使用的最新工作不断涌现,我们计划探索如何使用这种有望实现 2 倍加速的新数据类型,进一步加速模型训练。

致谢

有几个团队参与了实现这一证明点的工作,我们要感谢 Meta 和 IBM 的各个团队。特别地,我们向 Meta PyTorch 分布式和编译器团队以及 IBM 研究院表示感谢。

在实现模型与 torch.compile 数值一致性的工作中,许多人付出了巨大的努力,我们在此感谢参与此项工作的关键人员:Meta 的 Animesh Jain 和 Less Wright,以及 IBM 研究院的 Linsong Chu、Davis Wertheimer、Brian Vaughan、Antoni i Viros Martin、Mudhakar Srivatsa 和 Raghu Ganti。

特别感谢 Stas Bekman,他提供了广泛的反馈并帮助改进了本博客。他们的见解对于突出优化训练的关键方面和探索进一步增强功能具有宝贵价值。