最近,我们展示了如何使用 FSDP 和选择性激活检查点(selective activation checkpointing)来达到在 A100 GPU 上训练 7B 模型时 57% 的 MFU(模型浮点运算利用率)。我们还展示了它如何训练出一个高质量的模型,该模型已以 Granite 7B base model 的名称在 Hugging Face Hub 上以 Apache v2.0 许可开源。
我们继续探索通过利用 torch.compile 来提高 GPU 利用率。结合 torch.compile 和我们先前工作中使用的选择性激活检查点,我们在 A100 GPU 上对 7B 模型实现了 68% 的 MFU!对于各种模型尺寸,torch.compile 将训练 MFU 提高了 10% 到 23%。
本博客分为三个部分:(1)使用 torch.compile 训练时遇到的挑战及解决方法,(2)编译模式与非编译模式的数值一致性,以及(3)MFU 报告。
我们将所有代码开源并更新到了 fms-fsdp 仓库。我们还在与 Meta 的 PyTorch 团队合作,将这些代码贡献到新发布的用于预训练的 torch titan 仓库。
使用 torch.compile 的挑战
torch.compile 是一种图编译技术,可以提高 GPU 利用率。关于 torch compile 工作原理的详细信息,请读者参考最近的PyTorch 论文和相关教程。让 torch.compile 表现良好的一个关键挑战是最小化(或消除)图断裂(graph breaks)。我们最初使用的是 Meta 提供的 Llama 实现,但编译它会导致过多的图断裂,从而降低训练吞吐量。
必须修复模型架构的几个部分,其中最重要的是位置嵌入层(RoPE)。典型的 RoPE 实现使用复数,这在测试时 torch.compile 尚不支持。我们使用 einops 实现了 RoPE,同时保持了与原始模型架构实现的数值一致性。我们必须适当地缓存频率,以便在 RoPE 实现中不会遇到图断裂。
编译 FSDP 模型确实会导致图断裂,Meta 的 PyTorch 团队正在努力消除这些断裂。然而,截至 PyTorch 2.3,这些图断裂发生在 FSDP 单元边界处,不会显著影响吞吐量。
使用自定义内核时,需要通过将其 API 暴露给 torch.compile 来包装每个内核。这包括指示哪些参数是原地修改的、如何修改它们,以及它们的返回值基于输入将具有什么形状和跨步。在我们的案例中,SDPA Flash attention 已经适当地集成,我们能够使该内核在 torch.compile 下工作,而没有图断裂。
我们还注意到,当数据量从 2T 增加到 6T token 时,数据加载器成为了瓶颈。其主要原因是之前我们在数据加载器中天真地实现了文档混洗,让每个工作进程维护一个混洗文档指针列表。
对于更大的数据集,这些指针列表对于每个工作进程来说会增长到数十万个条目。维持这种规模的指针列表变得非常昂贵,以至于 CPU 竞争限制了我们的训练吞吐量。我们使用线性同余生成器(Linear Congruential Generator)重新实现了文档混洗,而无需任何指针列表。LCG 是一种伪随机数生成器算法,通过对总体进行随机游走实现无放回采样。
我们利用相同的想法来生成从有序文档索引到混洗文档索引的隐式双射映射。这使得我们将那些令人烦恼的数十万个指针列表缩小到 LCG 的一个单个整数状态。这消除了 80% 的瓶颈,并显著提升了我们的性能。我们将专门撰写一篇单独的博客,详细介绍我们高性能的预训练数据加载器。
torch.compile 与 torch.no-compile 的数值一致性
我们之前在使用编译和非编译选项进行训练时观察到一致性问题,其中一个问题与使用 SDPA 有关。经过 Meta 和 IBM 的 PyTorch 团队之间数天的密集调试,我们成功地在 PyTorch 编译模式和非编译模式之间实现了数值一致性。为了记录和验证这种一致性,我们取了一个 1.4B 大小的迷你 Llama 模型架构,并以四种变体对其进行训练,直到达到 100B token——非编译模式、编译模式(无激活检查点)、编译模式(选择性激活检查点)和编译模式(完全激活检查点)。
我们在下面绘制了这些选项的损失曲线和梯度范数
图 1:各种编译选项的损失曲线和梯度范数
此外,我们运行了 lm-evaluation-harness,并比较了各种模型在不同基准测试上的分数,观察到编译模式和非编译模式之间没有主要差异,如下所示。
图 2:lm-evaluation-harness 比较编译模式和非编译模式在各种基准测试上的表现
从所有这些结果中,我们观察到编译模式及其所有变体都与非编译选项相等,从而证明了编译模式和非编译模式之间的数值一致性。
MFU 报告
最后,像我们之前的博客一样,我们计算了两种集群上四种不同模型尺寸的 MFU。一个集群是 128 个 A100 GPU,节点间连接速率为 400 Gbps;另一个集群是 464 个 H100 GPU,节点间连接速率为 3.2 Tbps。除了编译,我们还使用了先前博客中介绍的选择性激活检查点。结果汇总在下表中。
模型尺寸 | 批处理大小 | MFU 非编译 | MFU 编译 | 性能提升 (%) |
7B | 2 | 0.57 | 0.68 | 20 |
13B | 2 | 0.51 | 0.60 | 17 |
34B | 2 | 0.47 | 0.54 | 15 |
70B | 2 | 0.50 | 0.55 | 10 |
表 1:在具有 400Gbps 节点间互连的 128 个 A100 80GB GPU 上,Llama2 模型架构在编译和非编译模式下的 MFU 结果
模型尺寸 | 批处理大小 | MFU 非编译 | MFU 编译 | 性能提升 |
7B | 2 | 0.37 | 0.45 | 21 |
13B | 2 | 0.35 | 0.43 | 23 |
34B | 2 | 0.32 | 0.38 | 19 |
70B | 2 | 0.32 | 0.38 | 19 |
表 2:在具有 3.2Tbps 节点间互连的 464 个 H100 80GB GPU 上,Llama2 模型架构在编译和非编译模式下的 MFU 结果
我们还在 448 个 GPU 上使用 Llama2 7B 架构进行了一次内部生产运行。使用编译和选择性激活检查点,全局批处理大小为 3.7M,我们在 13 天 10 小时内训练了 4T token!
训练期间,数据中心不得不启动额外的空调进行冷却,我们的训练团队收到了警报,因为我们非常有效地利用了 GPU ☺
从表 1 和表 2 可以得出的一个关键观察是,MFU 数字不会随着模型尺寸线性增长。我们正在积极调查两种可能的解释,一种是 FSDP 随着模型尺寸增加的可扩展性,以及何时需要启用张量并行(tensor parallel)以更有效地利用 GPU;另一种是批处理大小,可以进一步增加以获得更好的 MFU。我们计划探索 FSDP v2 和选择性操作符检查点(selective operator checkpointing)以及张量并行特性,以研究 FSDP 随模型尺寸变化的缩放定律。
未来工作
我们计划开始测试将作为 PyTorch 2.4 一部分发布的 FSDP v2。FSDP2 提供了按参数分片和选择性操作符检查点功能,可能提供更好的内存-计算权衡。
我们还与 Meta 的 PyTorch 团队合作,评估新的异步检查点功能,该功能可以通过减少写入检查点的时间来进一步提高 GPU 利用率。
我们正在探索扩展目前用于推理的各种 Triton 内核以执行反向操作,从而在推理之外获得加速。
最后,随着 fp8 使用的最新研究出现,我们计划探索如何使用这种承诺 2 倍加速的新数据类型进一步加速模型训练。
致谢
有几个团队参与了达到这一验证点的工作,我们要感谢 Meta 和 IBM 的各个团队。特别地,我们向 Meta 的 PyTorch 分布式和编译器团队以及 IBM 研究院表示感谢。
多人为实现 torch.compile 与我们的模型数值一致性付出了巨大努力,我们希望感谢参与此项工作的关键人员:Meta 的 Animesh Jain 和 Less Wright,以及 IBM 研究院的 Linsong Chu、Davis Wertheimer、Brian Vaughan、Antoni i Viros Martin、Mudhakar Srivatsa 和 Raghu Ganti。
特别感谢 Stas Bekman,他提供了广泛的反馈并帮助改进了本博客。他在突出优化训练的关键方面和探索进一步增强方面的见解是无价的。