作者:Yanli Zhao, Rohan Varma, Chien-Chin Huang, Shen Li, Min Xu, Alban Desmaison

最近的研究表明,大型模型训练有助于提高模型质量。在过去 3 年中,模型规模增长了 10,000 倍,从拥有 1.1 亿参数的 BERT 增长到拥有万亿参数的 Megatron-2。然而,训练大型 AI 模型并不容易——除了需要大量的计算资源外,软件工程的复杂性也极具挑战性。PyTorch 一直致力于构建工具和基础设施,以使其变得更容易。

PyTorch 分布式数据并行性因其鲁棒性和简单性而成为可扩展深度学习的基石。然而,它要求模型必须适合单个 GPU。最近的方法,如 DeepSpeed ZeRO 和 FairScale 的完全分片数据并行,通过在数据并行工作器之间分片模型的参数、梯度和优化器状态,同时仍保持数据并行性的简单性,从而打破了这一限制。

在 PyTorch 1.11 中,我们增加了对完全分片数据并行 (FSDP) 的原生支持,目前作为原型功能提供。其实现大量借鉴了 FairScale 的版本,同时带来了更精简的 API 和额外的性能改进。

PyTorch FSDP 在 AWS 上的扩展性测试表明,它可以扩展到训练具有 1 万亿参数的密集模型。在我们的实验中,在 AWS 集群上,对于 GPT 1T 模型,每个 A100 GPU 的实际性能达到 84 TFLOPS;对于 GPT 175B 模型,每个 A100 GPU 的实际性能达到 159 TFLOPS。与 FairScale 原始版本相比,原生 FSDP 实现还显著改善了模型初始化时间(当启用 CPU 卸载时)。

在未来的 PyTorch 版本中,我们将使用户能够无缝地在 DDP、ZeRO-1、ZeRO-2 和 FSDP 数据并行模式之间切换,以便用户可以在统一的 API 中通过简单的配置训练不同规模的模型。

FSDP 工作原理

FSDP 是一种数据并行训练,但与传统的保持模型参数、梯度和优化器状态的每个 GPU 副本的数据并行不同,它将所有这些状态分片到数据并行工作器上,并可以选择将分片后的模型参数卸载到 CPU。

下图展示了 FSDP 在 2 个数据并行进程中的工作方式

图 1. FSDP 工作流程

通常,模型层以嵌套方式用 FSDP 包装,这样只有单个 FSDP 实例中的层需要在前向或后向计算期间将完整参数收集到单个设备上。收集到的完整参数将在计算后立即释放,释放的内存可用于下一层的计算。通过这种方式,可以节省 GPU 峰值内存,从而可以将训练扩展到使用更大的模型规模或更大的批处理大小。为了进一步最大化内存效率,当 FSDP 实例在计算中不活跃时,可以将参数、梯度和优化器状态卸载到 CPU。

在 PyTorch 中使用 FSDP

有两种方法使用 PyTorch FSDP 包装模型。自动包装可以替代 DDP 直接使用;手动包装需要对模型定义代码进行最少的更改,并能够探索复杂的分片策略。

自动包装

模型层应以嵌套方式用 FSDP 包装,以节省峰值内存并启用通信与计算重叠。最简单的方法是自动包装,它可以直接替代 DDP,无需更改其余代码。

fsdp_auto_wrap_policy 参数允许指定一个可调用函数来递归地用 FSDP 包装层。PyTorch FSDP 提供的 default_auto_wrap_policy 函数递归地包装参数数量大于 100M 的层。您可以根据需要提供自己的包装策略。编写自定义包装策略的示例显示在FSDP API 文档中。

此外,可以选择配置 cpu_offload 以在这些参数未用于计算时将包装的参数卸载到 CPU。这可以进一步提高内存效率,但会增加主机和设备之间的数据传输开销。

下面的示例展示了如何使用自动包装来包装 FSDP。

from torch.distributed.fsdp import (
   FullyShardedDataParallel,
   CPUOffload,
)
from torch.distributed.fsdp.wrap import (
   default_auto_wrap_policy,
)
import torch.nn as nn
 
class model(nn.Module):
   def __init__(self):
       super().__init__()
       self.layer1 = nn.Linear(8, 4)
       self.layer2 = nn.Linear(4, 16)
       self.layer3 = nn.Linear(16, 4)
 
model = DistributedDataParallel(model())
fsdp_model = FullyShardedDataParallel(
   model(),
   fsdp_auto_wrap_policy=default_auto_wrap_policy,
   cpu_offload=CPUOffload(offload_params=True),
)

手动包装

手动包装可用于探索复杂的分片策略,方法是选择性地将 wrap 应用于模型的某些部分。整体设置可以传递给 enable_wrap() 上下文管理器。

from torch.distributed.fsdp import (
   FullyShardedDataParallel,
   CPUOffload,
)
from torch.distributed.fsdp.wrap import (
   enable_wrap,
   wrap,
)
import torch.nn as nn
from typing import Dict
 
 
class model(nn.Module):
   def __init__(self):
       super().__init__()
       self.layer1 = wrap(nn.Linear(8, 4))
       self.layer2 = nn.Linear(4, 16)
       self.layer3 = wrap(nn.Linear(16, 4))
 
wrapper_kwargs = Dict(cpu_offload=CPUOffload(offload_params=True))
with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
   fsdp_model = wrap(model())

使用上述两种方法之一用 FSDP 包装模型后,模型可以像本地训练一样进行训练,如下所示

optim = torch.optim.Adam(fsdp_model.parameters(), lr=0.0001)
for sample, label in next_batch():
  out = fsdp_model(input)
  loss = criterion(out, label)
  loss.backward()
  optim.step()

基准测试结果

我们使用 PyTorch FSDP 在 AWS 集群上对 175B 和 1T GPT 模型进行了广泛的扩展性测试。每个集群节点是配备 8 个 NVIDIA A100-SXM4-40GB GPU 的实例,节点之间通过 AWS Elastic Fabric Adapter (EFA) 连接,网络带宽为 400 Gbps。

GPT 模型使用 minGPT 实现。使用随机生成的输入数据集进行基准测试。所有实验均使用 50K 词汇量、fp16 精度和 SGD 优化器运行。

模型 层数 隐藏层大小 注意力头数量 模型大小,十亿参数
GPT 175B 96 12288 96 175
GPT 1T 128 25600 160 1008

除了在实验中使用带有参数 CPU 卸载的 FSDP 外,测试中还应用了 PyTorch 中的激活检查点功能

对于 GPT 175B 模型,在使用 128 个 GPU、批量大小 20、序列长度 512 的配置下,实现了每个 GPU 的最大吞吐量 159 teraFLOP/s(占 NVIDIA A100 峰值理论性能 312 teraFLOP/s/GPU 的 51%);进一步增加 GPU 数量会导致每个 GPU 的吞吐量下降,原因是节点间通信增加。

对于 GPT 1T 模型,在使用 128 个 GPU、批量大小 4、序列长度 2048 的配置下,实现了每个 GPU 的最大吞吐量 84 teraFLOP/s(占峰值 teraFLOP/s 的 27%)。然而,进一步增加 GPU 数量对每个 GPU 的吞吐量影响不大,因为我们观察到 1T 模型训练的最大瓶颈并非来自通信,而是来自当 GPU 峰值内存达到限制时缓慢的 CUDA 缓存分配器。使用具有更大内存容量的 A100 80G GPU 将主要解决这个问题,并有助于扩展批量大小以实现更大的吞吐量。

未来工作

在下一个 Beta 版本中,我们计划添加高效的分布式模型/状态检查点 API、用于大型模型具体化的元设备支持,以及 FSDP 计算和通信中的混合精度支持。我们还将在新 API 中使得在 DDPZeRO1、ZeRO2 和 FSDP 数据并行模式之间切换更加容易。为了进一步提高 FSDP 性能,还计划减少内存碎片和改进通信效率。

FSDP 两个版本的小历史

FairScale FSDP 于 2021 年初作为 FairScale 库的一部分发布。然后我们开始将 FairScale FSDP 上游到 PyTorch 1.11 的工作,使其具备生产就绪性。我们有选择地从 FairScale FSDP 上游并重构了关键特性,重新设计了用户界面并进行了性能改进。

在不久的将来,FairScale FSDP 将保留在 FairScale 仓库中用于研究项目,而通用且广泛采用的特性将逐步上游到 PyTorch 并相应地进行强化。

同时,PyTorch FSDP 将更侧重于生产就绪性和长期支持。这包括与生态系统的更好集成以及性能、可用性、可靠性、可调试性和可组合性的改进。

致谢

我们要感谢 FairScale FSDP 的作者:Myle Ott, Sam Shleifer, Min Xu, Priya Goyal, Quentin Duval, Vittorio Caggiano, Tingting Markstrum, Anjali Sridhar。感谢 Microsoft DeepSpeed ZeRO 团队开发和推广了分片数据并行技术。感谢 Pavel Belevich, Jessica Choi, Sisil Mehta 在不同集群上使用 PyTorch FSDP 运行实验。感谢 Geeta Chauhan, Mahesh Yadav, Pritam Damania, Dmytro Dzhulgakov 对这项工作的支持和富有见地的讨论。