最近,我们展示了如何使用 FSDP 和选择性激活检查点机制,在 A100 GPU 上训练 7B 模型时实现 57% 的 MFU(模型浮点运算利用率)。我们还展示了如何使用它训练高质量模型,并以 Apache v2.0 许可在 Hugging Face Hub 上开源了 Granite 7B 基础模型。
我们继续致力于通过利用 torch.compile 来提高 GPU 的利用率。通过使用 torch.compile 和我们之前工作中使用的选择性激活检查点机制,我们在 A100 GPU 上针对 7B 模型实现了 68% 的 MFU! torch.compile 将各种模型尺寸的训练 MFU 提高了 10% 到 23%。
本博客分为三个部分:(1)为了使用 torch.compile 进行训练而解决的挑战,(2)编译与非编译的数值对等性,以及(3)MFU 报告。
我们开源了所有代码,并在 fms-fsdp 仓库中进行了更新。我们还与 Meta 的 PyTorch 团队合作,将这些代码贡献到新发布的 torch titan 仓库中,用于预训练。
使用 torch.compile 的挑战
torch.compile 是一种图编译技术,可以提高 GPU 利用率。有关 torch.compile 工作原理的详细信息,请读者参考最近的 PyTorch 论文和相关教程。使 torch.compile 良好运行的关键挑战是最大限度地减少(或消除)图中断。我们最初从 Meta 提供的 Llama 实现开始,但编译它会导致过多的图中断,从而降低训练吞吐量。
模型架构的几个部分必须进行修复,其中最重要的是位置嵌入层(RoPE)。典型的 RoPE 实现使用复数,这在测试时 torch.compile 不支持。我们使用 einops 实现了 RoPE,同时保持了与原始模型架构实现的一致性。我们必须正确缓存频率,以避免在 RoPE 实现中遇到图中断。
编译 FSDP 模型确实会导致图中断,Meta 的 PyTorch 团队正在努力消除这些中断。然而,截至 PyTorch 2.3,这些图中断位于 FSDP 单元边界,并且不会显着影响吞吐量。
当使用自定义内核时,我们需要通过将其 API 暴露给 torch.compile 来包装每个内核。这包括指示哪些参数被就地修改、如何修改以及它们的返回值将根据输入具有哪些形状和步幅。在我们的例子中,SDPA Flash attention 已经适当集成,我们能够使该内核在 torch.compile 中工作,而没有图中断。
我们还注意到,当将数据量从 2T 增加到 6T tokens 时,数据加载器成为瓶颈。一个关键原因是,以前我们在数据加载器中天真地实现了文档打乱,方法是让每个 worker 维护一个打乱的文档指针列表。
随着数据集的增大,这些指针列表增长到每个 worker 数十万个条目。在这种规模下维护指针列表变得非常昂贵,以至于 CPU 争用限制了我们的训练吞吐量。我们使用线性同余生成器重新实现了文档打乱,而没有使用任何指针列表。LCG 是一种伪随机数生成器算法,它实现了对总体的随机游走,提供无放回抽样。
我们利用相同的思想来生成从有序文档索引到打乱文档索引的隐式双射映射。这使我们能够将那些烦人的数十万个指针列表缩小为 LCG 的单个整数状态。这消除了 80% 的瓶颈,并显着提高了我们的性能。我们将专门写一篇博客来详细介绍我们高性能的预训练数据加载器。
torch.compile 和 torch.no-compile 的数值对等性
我们之前观察到在使用编译和非编译选项进行训练时存在对等性问题,其中一个问题与 SDPA 的使用有关。经过 Meta 和 IBM 的 PyTorch 团队之间几天的密集调试会话,我们成功实现了 PyTorch 编译和非编译模式之间的对等性。为了记录和验证这种对等性,我们采用了一个 1.4B 大小的 mini-Llama 模型架构,并以四种变体对其进行 100B tokens 的训练——非编译、无激活检查点编译、选择性激活检查点编译和完全激活检查点编译。
我们在下面绘制了这些选项的损失曲线和梯度范数
图 1:各种编译选项的损失曲线和梯度范数
此外,我们运行了 lm-evaluation-harness,比较了不同基准上各种模型的得分,并观察到编译和非编译之间没有重大差异,如下所示。
图 2:lm-evaluation-harness 比较编译和非编译之间各种基准
从所有这些结果中我们观察到,编译及其所有变体都与非编译选项相等,从而证明了编译和非编译之间的对等性。
MFU 报告
最后,与我们之前的博客一样,我们在两个集群上计算了四种不同模型尺寸的 MFU。一个集群是 128 个 A100 GPU,具有 400 Gbps 的节点间连接,另一个集群是 464 个 H100 GPU,具有 3.2 Tbps 的节点间连接。除了编译之外,我们还使用了我们在之前的博客中介绍的选择性激活检查点机制。我们将结果记录在下表中。
模型尺寸 | 批大小 | MFU (非编译) | MFU (编译) | 百分比增益 (%) |
7B | 2 | 0.57 | 0.68 | 20 |
13B | 2 | 0.51 | 0.60 | 17 |
34B | 2 | 0.47 | 0.54 | 15 |
70B | 2 | 0.50 | 0.55 | 10 |
表 1:在具有 400Gbps 节点间互连的 128 个 A100 80GB GPU 上,Llama2 模型架构的编译和非编译 MFU 结果
模型尺寸 | 批大小 | MFU (非编译) | MFU (编译) | 百分比增益 |
7B | 2 | 0.37 | 0.45 | 21 |
13B | 2 | 0.35 | 0.43 | 23 |
34B | 2 | 0.32 | 0.38 | 19 |
70B | 2 | 0.32 | 0.38 | 19 |
表 2:在具有 3.2Tbps 节点间互连的 464 个 H100 80GB GPU 上,Llama2 模型架构的编译和非编译 MFU 结果
我们还在 448 个 GPU 上使用 Llama2 7B 架构进行了内部生产运行。通过使用编译和选择性激活检查点机制,以及 370 万的全局批大小,我们在 13 天 10 小时内训练了 4T tokens!
在训练期间,数据中心冷却系统不得不启动额外的空调,我们的训练团队也收到了警报,因为我们非常有效地使用了 GPU ☺
从表 1 和表 2 中的一个关键观察是,MFU 数字并没有随着模型尺寸线性扩展。我们正在积极调查两种可能的解释,一种是 FSDP 随着模型尺寸增加的可扩展性,以及何时需要启用张量并行以更有效地使用 GPU;另一种是批大小,可以进一步增加以获得更好的 MFU。我们计划探索 FSDP v2 和选择性算子检查点机制以及张量并行功能,以研究 FSDP 随模型尺寸的扩展规律。
未来工作
我们计划开始测试 FSDP v2,它将作为 PyTorch 2.4 的一部分发布。FSDP v2 提供了每个参数的分片和选择性算子检查点机制,这可能提供更好的内存-计算权衡。
我们还与 Meta 的 PyTorch 团队合作,评估新的异步检查点机制功能,该功能可以通过减少写入检查点的时间来进一步提高 GPU 利用率。
我们正在探索扩展当前用于推理的各种 Triton 内核,以执行反向操作,从而获得超出仅推理的加速。
最后,随着最近关于使用 fp8 的研究涌现,我们计划探索如何使用这种承诺 2 倍加速的新数据类型来进一步加速模型训练。
致谢
有几个团队参与了达到这个验证点的过程,我们想感谢 Meta 和 IBM 的各个团队。特别是,我们向 Meta PyTorch 分布式和编译器团队以及 IBM 研究院表示感谢。
许多人广泛参与了实现 torch.compile 与我们模型的数值对等性的工作,我们希望感谢参与这项工作的关键人员;Meta 的 Animesh Jain 和 Less Wright,以及 IBM 研究院的 Linsong Chu、Davis Wertheimer、Brian Vaughan、Antoni i Viros Martin、Mudhakar Srivatsa 和 Raghu Ganti。
特别感谢 Stas Bekman,他提供了广泛的反馈并帮助改进了这篇博客。他们的见解对于突出优化训练的关键方面和探索进一步增强功能非常有价值。