作者:IBM 的 PyTorch 团队和 Meta 的 PyTorch 团队

在本博客中,我们展示了 FSDP 的可扩展性,并以一个预训练示例、一个针对 2T tokens 训练的 7B 模型为例,分享了我们用于实现 3,700 tokens/秒/GPU 的快速训练速度或在 128 个 A100 GPU 上每天 40B tokens 的各种技术。这转化为 57% 的模型 FLOPS 利用率 (MFU) 和硬件 FLOPS 利用率 (HFU)。此外,我们还观察到 FSDP 近似线性扩展到 512 个 GPU,这意味着使用此方法在 512 个 GPU 上将 7B 模型训练到 2T tokens 仅需不到两周的时间。

IBM 研究人员将 Meta Llama 2 7B 架构训练到 2T tokens,我们将其称为 LlamaT(est)。该模型在各种学术基准测试中表现出与 Llama 2 相当的模型质量。所有训练代码,以及我们实现此吞吐量的方法,都可以在本博客中找到。我们还分享了适用于 Llama 2 模型(7B、13B、34B 和 70B,用于 A100 和 H100)的配置旋钮。

在此过程中,我们还提出了一种 _新的_ 选择性激活检查点机制,该机制适用于 FSDP,使我们在开箱即用的 FSDP 基础上提高了 10% 的性能。我们已开源了训练代码库和一个相关的可扩展数据加载器,作为实现此吞吐量的方法。

PyTorch 原生训练路径的一个主要优势是能够在多个硬件后端上无缝训练。例如,AllenAI 通过 OLMo 发布的最新端到端训练堆栈也利用 PyTorch FSDP 在 AMD 和 NVIDIA GPU 上进行训练。我们利用 FSDP 实现吞吐量主要有三个组成部分

  1. SDPA Flash attention,它支持融合的注意力内核和高效的注意力计算
  2. 重叠计算和通信,从而更好地利用 GPU
  3. 选择性激活检查点使我们能够在 GPU 内存和计算之间进行权衡

IBM 与 Meta 的 PyTorch 团队在 PyTorch FSDP 方面密切合作了近两年:推出了速率限制器,以在以太网互连上实现更高的吞吐量;分布式检查点,将检查点时间缩短一个数量级;并为 FSDP 的混合分片模式实现了早期版本的检查点。去年年底,我们使用 FSDP 端到端地训练了一个模型。

训练详情

7B 模型在具有 400Gbps 网络连接和 GPU Direct RDMA 的 128 个 A100 GPU 上进行训练。我们使用 SDPA FlashAttention v2 进行注意力计算,对于此模型,我们关闭了限制批量大小的激活检查点,但提供了最高的吞吐量 – 对于 128 个 GPU,批量大小为每批 100 万个 tokens,与激活检查点相比,吞吐量提高了约 10%。使用这些参数,我们在计算和通信方面几乎完全重叠。我们使用 32 位 AdamW 优化器,beta1 为 0.9,beta2 为 0.95,权重衰减为 0.1,学习率最终为 3e-5,预热至最大学习率为 3e-4,并使用余弦调度在 2T tokens 上降至 3e-5。训练使用混合精度 bf16 在内部数据集上执行。训练堆栈使用 IBM 的 Foundation Model Stack 作为模型架构,并使用 PyTorch nightly 版本(2.2 发布后)作为 FSDP 和 SDPA。我们在 2023 年 11 月至 2024 年 2 月期间尝试了几个不同的 nightly 版本,并观察到吞吐量有所提高。

选择性激活检查点

我们共同实施了一种简单有效的选择性激活检查点 (AC) 机制。在 FSDP 中,常见的做法是检查每个 Transformer 块。一个简单的扩展是检查每 _n_ 个块,并减少重新计算量,同时增加所需的内存。这对于 13B 模型大小非常有效,吞吐量提高了 10%。对于 7B 模型大小,我们根本不需要激活检查点。未来版本的 FSDP 将在运算符级别提供选择性激活检查点,从而实现最佳的计算-内存权衡。以上代码在此处实现:此处

吞吐量和 MFU、HFU 计算

虽然我们只将 7B 模型训练到 2T tokens,但我们对其他模型大小进行了大量实验,以提供最佳配置选项。下表总结了两种类型的基础架构的配置选项:一个具有 128 个 GPU 和 400Gbps 节点间互连的 A100 集群,以及一个具有 96 个 GPU 和 800Gbps 节点间互连的 H100 集群。

模型大小 批量大小 激活检查点 吞吐量 tokens/秒/GPU(A100 80GB 和 400Gbps 互连) MFU %(A100 80GB) HFU %(A100 80GB) 吞吐量 tokens/秒/GPU(H100 80GB 和 800Gbps 互连) MFU %(H100 80GB) HFU %(H100 80GB)
7B 2 3700 0.57 0.57 7500 0.37 0.37
13B 2 选择性 1800 0.51 0.59 3800 0.35 0.40
34B 2 700 0.47 0.64 1550 0.32 0.44
70B 2 370 0.50 0.67 800 0.34 0.45

表 1:A100 和 H100 GPU 上各种模型大小的模型和硬件 FLOPS 利用率

HFU 数字是使用 PyTorch FLOP 计数器以及 A100 和 H100 GPU 的理论 bf16 性能计算得出的,而 MFU 数字是使用 NanoGPTPaLM 论文中概述的方法计算得出的。我们还注意到,对于较大的模型,我们有意将批量大小保持在每个 GPU 2 个,以模拟训练 4k 序列长度模型时所做的选择,并在高达 512 个 GPU 的情况下实现此目标,而不会超过 4M tokens 的常用批量大小。除此之外,我们将需要张量并行或序列并行。

我们在上表中注意到,对于 A100,激活重新计算会导致 MFU 降低,而 HFU 升高!随着更好的激活检查点方案的引入,我们预计 MFU 会增加并赶上 HFU。但是,我们观察到,对于 H100,MFU 和 HFU 都相对较低。我们分析了 H100 上的 PyTorch 配置文件跟踪,并观察到由于网络“窥视”而存在 10% 的差距。此外,我们假设 H100 的 HBM 带宽是 H100 上 HFU/MFU 降低的原因,并且无法获得 3 倍的改进(H100 在理论上比 A100 快 3 倍 - 312 与 989TFLOPS,但 HBM 带宽仅为 A100 的 <2 倍 - 2.0 与 3.35TBps)。我们计划尝试其他配置选项,如张量并行,以改进 H100 上 70B 模型的旋钮。

模型详情

下图显示了训练的损失曲线。

loss curve for training

图 1:LlamaT 训练损失曲线

2T 检查点由存储库中提供的脚本转换为 Hugging Face 格式,然后我们使用 lm-evaluation-harness 计算关键学术基准测试,并通过在 Llama2-7B 上运行来比较结果。这些结果在下表中捕获。

评估指标 Llama2-7B(基线) LlamaT-7B
MMLU(零样本) 0.41 0.43
MMLU(5 次样本加权平均) 0.47 0.50
Arc challenge 0.46 0.44
Arc easy 0.74 0.71
Boolq 0.78 0.76
Copa 0.87 0.83
Hellaswag 0.76 0.74
Openbookqa 0.44 0.42
Piqa 0.79 0.79
Sciq 0.91 0.91
Winogrande 0.69 0.67
Truthfulqa 0.39 0.39
GSM8k(8 次样本) 0.13 0.11

表 1:LM eval harness 分数

我们观察到该模型的性能与 Llama2 相当(粗体字更好)。

训练编年史

训练稳定,没有崩溃,但我们确实观察到一些小问题

0-200B tokens:我们观察到迭代时间(执行一个训练步骤所花费的时间)变慢。我们停止了作业,以确保数据加载器没有造成任何减速,并且检查点是高性能且准确的。我们没有发现任何问题。此时,HSDP 检查点代码已在 PyTorch 中可用,我们借此机会切换到 PyTorch 检查点代码。

200B tokens-1.9T:我们在 12 月下旬没有对作业进行任何手动干预。当我们在 1 月初回来时,磁盘空间已超出,检查点写入失败,尽管训练作业仍在继续。上次已知的检查点是 1.5T。

1.5T-1.7T:我们使用 lm-evaluation-harness 评估了 1.5T 检查点,并发现由于 Hugging Face 分词器引入了分隔符 tokens,并且我们的数据加载器也附加了自己的文档分隔符,因此模型已使用两个文档之间的额外特殊 tokens 进行了训练。我们修改了数据加载器以消除额外的特殊 tokens,并从 1.7T tokens 开始继续使用修改后的数据加载器进行训练。

1.7T-2T:由于特殊 tokens 的更改,损失最初飙升,但在数十亿 tokens 中迅速恢复。训练在没有任何其他手动干预的情况下完成!

主要收获和更快的速度

我们演示了如何使用 FSDP 将模型训练到 2T tokens,并具有 3700 tokens/秒/GPU 的出色性能,并生成高质量的模型。作为此练习的一部分,我们开源了所有用于训练的代码以及实现此吞吐量的旋钮。这些旋钮不仅可以用于大规模运行,还可以用于小规模调整运行。您可以在此处找到代码。

FSDP API 以 PyTorch 原生方式实现了 ZeRO 算法,并允许调整和训练大型模型。过去,我们已经看到 FSDP 的证明点(Stanford AlpacaHugging FaceLlama 2 食谱)关于使用简单的训练循环调整各种 LLM(例如 Meta Llama 2 7B 到 70B Llama),并实现良好的吞吐量和训练时间。

最后,我们注意到有几个加速训练的杠杆

  1. 可以加速特定操作的节点优化(例如,使用 Flash Attention V2 进行注意力计算)
  2. 图形优化(例如,融合内核,torch.compile)
  3. 计算-通信重叠
  4. 激活重新计算

我们在本博客中利用了 1、3 和 4 的变体,并正在与 Meta 的 PyTorch 团队密切合作,以获得 torch.compile (2) 以及更高级版本的 4,其中包含每个运算符的选择性激活重新计算。我们计划分享一个简单的格式化代码和示例数据,以便摄取到我们的数据加载器中,使其他人能够使用该代码库来训练模型。

致谢

有几个团队参与了实现这一证明点,我们要感谢 Meta 和 IBM 的团队。特别是,我们向 PyTorch 分布式团队、Facebook 研究团队和应用 AI 团队表示感谢,他们构建了 FSDP API 并根据我们的反馈进行了增强。我们还要感谢 IBM 研究院的数据团队,他们策划了本练习中使用的数据语料库,以及 IBM 研究院的基础架构团队(特别是 Claudia Misale、Shweta Salaria 和 Seetharami Seelam),他们优化了 NCCL 和网络配置。通过构建和利用所有这些组件,我们成功地演示了 LlamaT 证明点。

选择性激活检查点是由 IBM 的 Linsong Chu、Davis Wertheimer、Mudhakar Srivatsa 和 Raghu Ganti 概念化的,并由 Meta 的 Less Wright 实现。

特别感谢 Stas BekmanMinjia Zhang,他们提供了广泛的反馈并帮助改进了博客。他们的见解对于突出显示优化训练的关键方面和探索进一步增强功能非常宝贵。

附录

通信计算重叠

在多节点设置中进行训练的另一个关键方面是重叠通信和计算的能力。在 FSDP 中,有多个重叠机会 – 在前向传递期间的 FSDP 单元收集阶段以及后向传递计算期间。在前向传递期间重叠收集,同时计算前一个单元,并在后向计算与下一个单元收集和梯度分散重叠,这有助于将 GPU 利用率提高近 2 倍。我们在具有 A100 80GB GPU 的 400Gbps 网络互连上对此进行了说明。在 HSDP 的情况下,在前向传递的预取阶段没有节点间流量,并且重叠仅用于后向梯度计算阶段。当然,HSDP 仅在模型可以在单个节点内分片时才可行,从而将模型大小限制在大约 30B 参数。

下图显示了 FSDP 中的三个步骤,底部是节点之间的通信,图像后半部分的顶部是计算流。对于没有激活重新计算的 7B 模型,我们观察到重叠是完全的。实际上,可能的重叠百分比为 90%,因为前向传递期间的第一个块和后向传递期间的最后一个块无法重叠。

three steps in FSDP with the communication between nodes at the bottom and the compute stream at the top of the second half

下图显示了单个步骤的上述三步过程的放大视图。我们可以清楚地看到计算和通信的粒度,以及它们如何以交错的方式重叠。

zoomed in view of the above three-step process