由 IBM 的 PyTorch 团队和 Meta 的 PyTorch 团队联合撰写

在这篇博客中,我们演示了 FSDP 的可伸缩性,以一个预训练示例为例,这是一个训练了 2T token 的 7B 模型,并分享了我们用于实现 3,700 token/秒/GPU 的快速训练速度(或在 128 个 A100 GPU 上每天 40B token)的各种技术。这转化为 57% 的模型 FLOPS 利用率 (MFU) 和硬件 FLOPS 利用率 (HFU)。此外,我们观察到 FSDP 在扩展到 512 个 GPU 时接近线性,这意味着使用此方法在 512 个 GPU 上训练一个 7B 模型达到 2T token 仅需不到两周的时间。

IBM 研究人员将 Meta Llama 2 7B 架构训练到 2T token,我们将其称为 LlamaT(est)。该模型在各种学术基准测试上表现出与 Llama 2 相当的模型质量。所有训练代码以及实现此吞吐量的方法都可以在这篇博客中找到。我们还分享了适用于 Llama 2 模型(7B、13B、34B 和 70B)在 A100 和 H100 上的良好配置参数。

在此过程中,我们还提出了一种新的选择性激活检查点机制,该机制适用于 FSDP,可以在开箱即用的 FSDP 基础上提升 10%。我们已经开源了训练代码库和相关的可伸缩数据加载器,作为实现此吞吐量的方法。

PyTorch 原生训练路径的一个关键优势是能够在多种硬件后端上无缝训练。例如,AllenAI 最近通过 OLMo 发布的端到端训练堆栈也利用了 PyTorch FSDP 在 AMD 和 NVIDIA GPU 上进行训练。我们利用 FSDP 的三个主要组件来实现我们的吞吐量

  1. SDPA Flash attention,实现融合的注意力内核和高效的注意力计算
  2. 计算和通信的重叠允许更好地利用 GPU
  3. 选择性激活检查点使我们能够在 GPU 内存和计算之间进行权衡

IBM 近两年来一直与 Meta 的 PyTorch 团队就PyTorch FSDP紧密合作:引入速率限制器以在以太网互连上实现更好的吞吐量,分布式检查点以将检查点时间缩短一个数量级,并为 FSDP 的混合分片模式实现了早期版本的检查点。去年年底,我们使用 FSDP 端到端地训练了一个模型。

训练细节

7B 模型在 128 个 A100 GPU 上进行训练,具有 400Gbps 网络连接和 GPU Direct RDMA。我们使用 SDPA FlashAttention v2 进行注意力计算,对于此模型,我们关闭了限制批量大小但提供最高吞吐量的激活检查点——对于 128 个 GPU,每批次的批量大小为 100 万 token,与激活检查点相比,吞吐量提高了约 10%。使用这些参数,我们实现了计算和通信的几乎完全重叠。我们使用 32 位 AdamW 优化器,beta1 为 0.9,beta2 为 0.95,权重衰减为 0.1,学习率从 3e-4 热身到最大值,然后按照余弦调度在 2T token 内衰减到 3e-5。训练使用混合精度 bf16 在内部数据集上执行。训练堆栈使用了 IBM 的Foundation Model Stack作为模型架构,并使用了 PyTorch nightly 版本(2.2 发布后)来支持 FSDP 和 SDPA。我们在 2023 年 11 月至 2024 年 2 月期间尝试了几个不同的 nightly 版本,并观察到吞吐量的提升。

选择性激活检查点

我们共同实现了一种简单有效的选择性激活检查点 (AC) 机制。在 FSDP 中,通常的做法是对每个 Transformer 块进行检查点。一个简单的扩展是对每隔n个块进行检查点,减少重新计算量,同时增加所需的内存。这对于 13B 模型尺寸非常有效,可以将吞吐量提高 10%。对于 7B 模型尺寸,我们根本不需要激活检查点。未来版本的 FSDP 将在算子级别提供选择性激活检查点,从而实现计算-内存的最优权衡。上述代码在此处实现。

吞吐量和 MFU、HFU 计算

虽然我们只将 7B 模型训练到 2T token,但我们在其他模型尺寸上进行了大量实验,以提供最佳配置选项。下表总结了两种基础设施的配置选项——一个包含 128 个 GPU 和 400Gbps 节点间互连的 A100 集群,以及一个包含 96 个 GPU 和 800Gbps 节点间互连的 H100 集群。

模型尺寸 批量大小 激活检查点 吞吐量 tokens/秒/GPU (A100 80GB 和 400Gbps 互连) MFU % (A100 80GB) HFU % (A100 80GB) 吞吐量 tokens/秒/GPU (H100 80GB 和 800Gbps 互连) MFU % (H100 80GB) HFU % (H100 80GB)
7B 2 3700 0.57 0.57 7500 0.37 0.37
13B 2 选择性 1800 0.51 0.59 3800 0.35 0.40
34B 2 700 0.47 0.64 1550 0.32 0.44
70B 2 370 0.50 0.67 800 0.34 0.45

表 1:各种模型尺寸在 A100 和 H100 GPU 上的模型和硬件 FLOPS 利用率

HFU 数字是使用PyTorch FLOP 计数器以及 A100 和 H100 GPU 的理论 bf16 性能计算得出的,而 MFU 数字是使用NanoGPT中概述的方法和PaLM 论文计算得出的。我们还注意到,我们为较大模型使用的批量大小有意保持为每个 GPU 2,以模仿训练 4k 序列长度模型时的选择,并在不超过流行的 4M token 批量大小的情况下,将此实现到 512 个 GPU。在此之上,我们将需要张量并行或序列并行。

我们在上表中注意到,对于 A100,激活重新计算会导致 MFU 降低,而 HFU 增加!随着更好的激活检查点方案的引入,我们期望 MFU 增加并赶上 HFU。然而,我们观察到对于 H100,MFU 和 HFU 都相对较低。我们分析了 H100 上的 PyTorch profile 跟踪,并观察到由于网络“偷看”出去而导致 10% 的差距。此外,我们假设 H100 的 HBM 带宽是导致 H100 上 HFU/MFU 降低的原因,并且无法获得 3 倍的改进(H100 理论上比 A100 快 3 倍 - 312 vs 989 TFLOPS,但 HBM 带宽仅为 A100 的<2 倍 - 2.0 vs 3.35 TBps)。我们计划尝试其他配置选项,例如张量并行,以改进 H100 上 70B 模型的参数。

模型细节

训练的损失曲线如下图所示。

loss curve for training

图 1:LlamaT 训练损失曲线

2T 检查点通过仓库中提供的脚本转换为 Hugging Face 格式,然后我们使用lm-evaluation-harness计算关键学术基准,并将其与在 Llama2-7B 上运行的结果进行比较。这些结果记录在下表中。

评估指标 Llama2-7B (基线) LlamaT-7B
MMLU (零样本) 0.41 0.43
MMLU (5 样本加权平均) 0.47 0.50
Arc challenge 0.46 0.44
Arc easy 0.74 0.71
Boolq 0.78 0.76
Copa 0.87 0.83
Hellaswag 0.76 0.74
Openbookqa 0.44 0.42
Piqa 0.79 0.79
Sciq 0.91 0.91
Winogrande 0.69 0.67
Truthfulqa 0.39 0.39
GSM8k (8 样本) 0.13 0.11

表 1:LM 评估 harness 得分

我们观察到该模型与 Llama2 相比具有竞争力(加粗者更好)。

训练记事

训练稳定,没有崩溃,尽管我们确实遇到了一些小问题

0-200B token:我们观察到迭代时间(执行一个训练步骤所需的时间)变慢。我们停止了作业,以确保数据加载器没有导致任何速度下降,并且检查点性能良好且准确。我们没有发现任何问题。此时,PyTorch 中提供了 HSDP 检查点代码,我们借此机会切换到 PyTorch 检查点代码。

200B token-1.9T:我们在 12 月底没有对作业进行任何手动干预。1 月初回来时,磁盘空间已满,检查点写入失败,尽管训练作业仍在继续。最后一个已知检查点是 1.5T。

1.5T-1.7T:我们使用 lm-evaluation-harness 评估了 1.5T 检查点,发现由于 Hugging Face 分词器引入了一个分隔符 token 并且我们的数据加载器也添加了自己的文档分隔符,模型在两个文档之间用了一个额外的特殊 token 进行了训练。我们修改了数据加载器以消除额外的特殊 token,并从 1.7T token 开始使用修改后的数据加载器继续训练。

1.7T-2T:由于特殊 token 的变化,损失最初有所上升,但在几十亿 token 后很快恢复。训练顺利完成,没有其他手动干预!

主要收获和进一步提速

我们展示了如何使用 FSDP 将模型训练到 2T token,并实现了 3700 tokens/秒/GPU 的出色性能,生成了一个高质量模型。作为这项工作的一部分,我们开源了所有训练代码以及实现此吞吐量的参数。这些参数不仅可以用于大规模运行,还可以用于小规模微调运行。您可以在此处找到代码。

FSDP API 以 PyTorch 原生方式实现了ZeRO算法,并允许对大型模型进行微调和训练。过去,我们看到 FSDP 在微调各种 LLM(如 Meta Llama 2 7B 到 70B Llama)方面取得了验证(Stanford AlpacaHugging FaceLlama 2 recipe),通过简单的训练循环实现了良好的吞吐量和训练时间。

最后,我们注意到有几种方法可以加快训练速度

  1. 节点优化,可以加速特定操作(例如,使用 Flash Attention V2 进行注意力计算)
  2. 图优化(例如,融合内核,torch.compile)
  3. 计算-通信重叠
  4. 激活重新计算

我们在本博客中利用了第 1、3 项以及第 4 项的一种变体,并正与 Meta 的 PyTorch 团队密切合作,以引入 torch.compile(第 2 项)以及具有逐算子选择性激活重新计算功能的更高级的第 4 项。我们计划分享一个简单的格式化代码和示例数据,以便将其载入我们的数据加载器,使其他人能够使用该代码库进行模型训练。

致谢

有多个团队参与实现了这一验证点,我们在此感谢 Meta 和 IBM 的各团队。特别地,我们感谢构建FSDP API并根据我们的反馈进行了增强的 PyTorch 分布式团队、Facebook 研究团队和应用 AI 团队。我们还要感谢 IBM 研究院的数据团队,他们整理了本次工作中使用的数据集,以及 IBM 研究院的基础设施团队(特别是 Claudia Misale、Shweta Salaria 和 Seetharami Seelam),他们优化了 NCCL 和网络配置。通过构建和利用所有这些组件,我们成功地演示了 LlamaT 的验证点。

选择性激活检查点由 IBM 的 Linsong Chu、Davis Wertheimer、Mudhakar Srivatsa 和 Raghu Ganti 构思,并由 Meta 的 Less Wright 实现。

特别感谢Stas BekmanMinjia Zhang,他们提供了大量反馈并帮助改进了这篇博客。他们的见解对于突出训练优化的关键方面和探索进一步增强功能具有宝贵的价值。

附录

通信计算重叠

在多节点设置中进行训练的另一个关键方面是能够重叠通信和计算。在 FSDP 中,有多种重叠机会——在前向传播中的 FSDP 单元收集阶段以及后向传播计算阶段。在前向传播期间重叠收集与前一个单元的计算,以及在后向计算期间重叠下一个单元的收集和梯度分散,有助于将 GPU 利用率提高近 2 倍。我们以 400Gbps 网络互连和 A100 80GB GPU 为例进行说明。对于 HSDP,前向传播的预取阶段没有节点间流量,重叠仅限于后向梯度计算阶段。当然,HSDP 仅在模型可以在单个节点内进行分片时可行,这限制了模型的规模约为 30B 参数。

下图显示了 FSDP 的三个步骤,其中图像下半部分是节点之间的通信,上半部分是计算流。对于没有激活重新计算的 7B 模型,我们观察到重叠是完整的。实际上,可能实现的重叠百分比是 90%,因为在前向传播中的第一个块和后向传播中的最后一个块无法重叠。

three steps in FSDP with the communication between nodes at the bottom and the compute stream at the top of the second half

下图显示了上述三步过程的放大视图,仅针对一个步骤。我们可以清楚地看到计算和通信的粒度以及它们如何以交错的方式重叠。

zoomed in view of the above three-step process