在本篇博客中,我们以一个预训练示例(一个训练了 2T token 的 7B 模型)展示了 FSDP 的可扩展性,并分享了我们为实现 3,700 token/秒/GPU 的快速训练速度(即在 128 个 A100 GPU 上每天 40B token)而使用的各种技术。这相当于模型 FLOPS 利用率 (MFU) 和硬件 FLOPS 利用率 (HFU) 达到 57%。此外,我们观察到 FSDP 近乎线性地扩展到 512 个 GPU,这意味着使用这种方法在 512 个 GPU 上训练 7B 模型到 2T token 仅需不到两周时间。
IBM 研究人员将 Meta Llama 2 7B 架构训练到 2T token,我们将其称为 LlamaT(est)。该模型在各种学术基准测试中表现出与 Llama 2 相当的模型质量。所有训练代码,以及我们实现此吞吐量的方法,都可以在本博客中找到。我们还分享了适用于 A100 和 H100 的 Llama 2 模型(7B、13B、34B 和 70B)的良好配置。
在此过程中,我们还提出了一种应用于 FSDP 的_新_选择性激活检查点机制,它在开箱即用的 FSDP 基础上提供了 10% 的提升。我们已将训练代码库和相关的可扩展数据加载器开源,作为实现此吞吐量的方法。
PyTorch 原生训练路径的一个主要优势是能够无缝地在多个硬件后端上进行训练。例如,AllenAI 通过 OLMo 发布的新近端到端训练栈也利用 PyTorch FSDP 在 AMD 和 NVIDIA GPU 上进行训练。我们利用 FSDP 的三个主要组件来实现我们的吞吐量:
- SDPA Flash Attention,实现融合的注意力内核和高效的注意力计算
- 计算与通信的重叠,允许更好地利用 GPU
- 选择性激活检查点,使我们能够在 GPU 内存和计算之间进行权衡
IBM 近两年来一直与 Meta 的 PyTorch 团队密切合作PyTorch FSDP:引入速率限制器以在以太网互连上实现更好的吞吐量,分布式检查点以将检查点时间提高一个数量级,以及为 FSDP 的混合分片模式实现早期版本的检查点。去年年底,我们使用 FSDP 端到端地训练了一个模型。
训练详情
7B 模型在 128 个 A100 GPU 上训练,具有 400Gbps 网络连接和 GPU Direct RDMA。我们使用 SDPA FlashAttention v2 进行注意力计算,对于这个模型,我们关闭了限制批量大小但提供最高吞吐量的激活检查点——对于 128 个 GPU,每批次批量大小为 100 万 token,与激活检查点相比,吞吐量提高了约 10%。在这些参数下,我们的计算和通信几乎完全重叠。我们使用 32 位 AdamW 优化器,beta1 为 0.9,beta2 为 0.95,权重衰减为 0.1,学习率最终为 3e-5,预热至最大学习率 3e-4,并采用余弦调度在 2T token 内降至 3e-5。训练使用混合精度 bf16 在内部数据集上进行。训练栈使用 IBM 的基础模型栈进行模型架构,并使用 PyTorch nightlies 2.2 版本后用于 FSDP 和 SDPA。我们在 2023 年 11 月到 2024 年 2 月期间尝试了几种不同的 nightlies,并观察到吞吐量有所改善。
选择性激活检查点
我们共同实现了一个简单有效的选择性激活检查点 (AC) 机制。在 FSDP 中,常见的做法是检查点每个 Transformer 块。一个简单的扩展是每 n 个块检查点一次,以减少重新计算量,同时增加所需的内存。这对于 13B 模型大小非常有效,吞吐量增加了 10%。对于 7B 模型大小,我们根本不需要激活检查点。FSDP 的未来版本将提供操作符级别的选择性激活检查点,从而实现最佳的计算-内存权衡。上述代码已在此处实现。
吞吐量和 MFU、HFU 计算
虽然我们只将 7B 模型训练到 2T tokens,但我们在其他模型大小上进行了大量实验,以提供最佳配置选项。下表总结了两种类型的基础设施的配置——一个拥有 128 个 GPU 和 400Gbps 节点间互连的 A100 集群,以及一个拥有 96 个 GPU 和 800Gbps 节点间互连的 H100 集群。
模型大小 | 批次大小 | 激活检查点 | 吞吐量 tokens/秒/GPU (A100 80GB 和 400Gbps 互连) | MFU % (A100 80GB) | HFU % (A100 80GB) | 吞吐量 tokens/秒/GPU (H100 80GB 和 800Gbps 互连) | MFU % (H100 80GB) | HFU % (H100 80GB) |
7B | 2 | 否 | 3700 | 0.57 | 0.57 | 7500 | 0.37 | 0.37 |
13B | 2 | 选择性 | 1800 | 0.51 | 0.59 | 3800 | 0.35 | 0.40 |
34B | 2 | 是 | 700 | 0.47 | 0.64 | 1550 | 0.32 | 0.44 |
70B | 2 | 是 | 370 | 0.50 | 0.67 | 800 | 0.34 | 0.45 |
表 1:各种模型大小在 A100 和 H100 GPU 上的模型和硬件 FLOPS 利用率
HFU 数字使用PyTorch FLOP 计数器和 A100 和 H100 GPU 的理论 bf16 性能计算,而 MFU 数字使用NanoGPT和PaLM 论文中概述的方法计算。我们还注意到,我们用于较大模型的批量大小有意保持为每个 GPU 2,以模仿训练 4k 序列长度模型时所做的选择,并在不超过 4M tokens 的常用批量大小的情况下,将其实现到 512 个 GPU。除此之外,我们将需要张量并行或序列并行。
我们注意到上表中,对于 A100,激活重新计算导致 MFU 降低,而 HFU 增加!随着更好的激活检查点方案的引入,我们期望 MFU 能够增加并赶上 HFU。然而,我们观察到对于 H100,MFU 和 HFU 都相对较低。我们分析了 H100 上的 PyTorch 配置文件跟踪,并观察到由于网络“窥探”而导致 10% 的差距。此外,我们假设 H100 的 HBM 带宽是 H100 上 HFU/MFU 降低的原因,并且未能获得 3 倍的改进(H100 在理论上比 A100 快 3 倍——312 vs 989 TFLOPS,但 HBM 带宽仅比 A100 少于 2 倍——2.0 vs 3.35 TBps)。我们计划尝试其他配置选项,例如张量并行,以改善 H100 上 70B 模型的参数。
模型详情
训练的损失曲线如下图所示。

图 1:LlamaT 训练损失曲线
2T 检查点通过存储库中提供的脚本转换为 Hugging Face 格式,然后我们使用lm-evaluation-harness计算关键学术基准,并将其与 Llama2-7B 的运行结果进行比较。这些结果已捕获在下表中。
评估指标 | Llama2-7B (基线) | LlamaT-7B |
MMLU (零样本) | 0.41 | 0.43 |
MMLU (5样本加权平均) | 0.47 | 0.50 |
Arc 挑战 | 0.46 | 0.44 |
Arc 简单 | 0.74 | 0.71 |
Boolq | 0.78 | 0.76 |
Copa | 0.87 | 0.83 |
Hellaswag | 0.76 | 0.74 |
Openbookqa | 0.44 | 0.42 |
Piqa | 0.79 | 0.79 |
Sciq | 0.91 | 0.91 |
Winogrande | 0.69 | 0.67 |
Truthfulqa | 0.39 | 0.39 |
GSM8k (8样本) | 0.13 | 0.11 |
表 1:LM 评估基准分数
我们观察到该模型与 Llama2 相比具有竞争力(越粗体越好)。
训练纪实
训练过程稳定,没有崩溃,尽管我们确实观察到一些小插曲
0-200B token:我们观察到迭代时间(执行一个训练步骤所需的时间)有所减慢。我们停止了作业,以确保数据加载器没有造成任何减速,并且检查点性能良好且准确。我们没有发现任何问题。此时,HSDP 检查点代码已在 PyTorch 中可用,我们借此机会切换到 PyTorch 检查点代码。
200B token-1.9T:我们在 12 月下旬没有对作业进行任何手动干预。1 月初回来时,磁盘空间已超出,检查点写入失败,尽管训练作业仍在继续。最后一个已知检查点是 1.5T。
1.5T-1.7T:我们使用 lm-evaluation-harness 评估了 1.5T 检查点,发现由于 Hugging Face 分词器引入了分隔符 token,并且我们的数据加载器也附加了自己的文档分隔符,模型在两个文档之间训练时多了一个特殊 token。我们修改了数据加载器以消除多余的特殊 token,并从 1.7T token 开始使用修改后的数据加载器继续训练。
1.7T-2T:由于特殊 token 的变化,损失最初飙升,但在数十亿 token 内迅速恢复。训练在没有任何其他人工干预的情况下完成了!
主要收获和更快的速度
我们演示了如何使用 FSDP 将模型训练到 2T token,并实现了 3700 token/秒/GPU 的出色性能,生成了高质量的模型。作为此项工作的一部分,我们开源了所有训练代码以及实现此吞吐量的参数。这些参数不仅可以用于大规模运行,还可以用于小规模微调运行。您可以在此处找到代码。
FSDP API 以 PyTorch 原生方式实现了ZeRO算法,并允许对大型模型进行调优和训练。过去,我们已经看到 FSDP 的验证点(Stanford Alpaca、Hugging Face、Llama 2 recipes)在各种 LLM(例如 Meta Llama 2 7B 到 70B Llama)的调优中使用简单的训练循环,并实现了良好的吞吐量和训练时间。
最后,我们注意到有几个可以加快训练速度的杠杆
- 节点优化,可以加速特定操作(例如,使用 Flash Attention V2 进行注意力计算)
- 图优化(例如,融合内核,torch.compile)
- 计算-通信重叠
- 激活重计算
我们在这篇博客中利用了第 1、3 和第 4 种优化,并正在与 Meta 的 PyTorch 团队密切合作,以实现 torch.compile (2) 和更高级的第 4 种优化,即按操作符的选择性激活重计算。我们计划分享一个简单的格式化代码和示例数据,以便其他人可以使用我们的代码库进行模型训练。
致谢
有几个团队参与了实现这一验证点,我们谨此感谢 Meta 和 IBM 的各个团队。特别地,我们感谢 PyTorch 分布式团队、Facebook 研究和应用 AI 团队,他们构建了FSDP API并根据我们的反馈进行了增强。我们还要感谢 IBM 研究院的数据团队,他们策划了本次练习中使用的数据语料库,以及 IBM 研究院的基础设施团队(特别是 Claudia Misale、Shweta Salaria 和 Seetharami Seelam),他们优化了 NCCL 和网络配置。通过构建和利用所有这些组件,我们成功地展示了 LlamaT 的验证点。
选择性激活检查点由 IBM 的 Linsong Chu、Davis Wertheimer、Mudhakar Srivatsa 和 Raghu Ganti 构思,并由 Meta 的 Less Wright 实现。
特别感谢Stas Bekman和Minjia Zhang,他们提供了大量的反馈并帮助改进了这篇博客。他们的见解对于突出优化训练的关键方面和探索进一步的增强至关重要。
附录
通信计算重叠
多节点训练的另一个关键方面是通信和计算的重叠能力。在 FSDP 中,有多种重叠机会——在前向传播中的 FSDP 单元收集阶段以及反向传播计算阶段。在前向传播期间将收集与上一个单元的计算重叠,并将反向计算与下一个单元的收集和梯度散射重叠,有助于将 GPU 利用率提高近 2 倍。我们将在 400Gbps 网络互连和 A100 80GB GPU 上演示这一点。在 HSDP 的情况下,前向传播的预取阶段没有节点间流量,重叠仅限于反向梯度计算阶段。当然,HSDP 只有当模型可以在单个节点内分片时才可行,这限制了模型的大小约为 30B 参数。
下图显示了 FSDP 的三个步骤,底部是节点之间的通信,图片后半部分的顶部是计算流。对于没有激活重新计算的 7B 模型,我们观察到重叠是完整的。实际上,可能的重叠百分比是 90%,因为前向传播的第一个块和反向传播的最后一个块无法重叠。

下面显示了上述三步过程的单步放大视图。我们可以清楚地看到计算和通信的粒度以及它们如何以交错的方式重叠。
