近期研究表明,大规模模型训练有助于提升模型质量。在过去三年中,模型规模增长了 10,000 倍,从拥有 1.1 亿参数的 BERT 发展到拥有 1 万亿参数的 Megatron-2。然而,训练大型人工智能模型并非易事——除了需要海量计算资源外,软件工程的复杂性也是一项严峻挑战。PyTorch 一直致力于构建相关工具和基础设施以简化这一过程。
PyTorch 分布式数据并行因其稳健性和简洁性,已成为可扩展深度学习的支柱。然而,它要求模型必须能够完整放入单个 GPU 中。DeepSpeed ZeRO 和 FairScale 的完全分片数据并行 (FSDP) 等近期方法允许我们将模型的参数、梯度和优化器状态分片到各个数据并行工作进程中,从而打破了这一限制,同时保持了数据并行原有的简洁性。
随着 PyTorch 1.11 的发布,我们增加了对完全分片数据并行 (FSDP) 的原生支持(目前作为原型功能提供)。其实现很大程度上借鉴了 FairScale 的版本,同时带来了更精简的 API 和额外的性能改进。
在 AWS 上对 PyTorch FSDP 进行的扩展性测试显示,它可以扩展至训练 1 万亿参数的密集模型。在我们的实验中,GPT 1T 模型在 AWS 集群上达到了每个 A100 GPU 84 TFLOPS 的性能,GPT 175B 模型则达到了每个 A100 GPU 159 TFLOPS。与 FairScale 的原始版本相比,原生 FSDP 实现还在开启 CPU 卸载 (CPU offloading) 时显著缩短了模型初始化时间。
在未来的 PyTorch 版本中,我们将支持用户在 DDP、ZeRO-1、ZeRO-2 和 FSDP 等数据并行模式之间无缝切换,以便用户能够通过统一 API 中的简单配置来训练不同规模的模型。
FSDP 的工作原理
FSDP 是一种数据并行训练方式。与传统的数据并行(在每个 GPU 上维护模型参数、梯度和优化器状态的副本)不同,它将所有这些状态分片到各个数据并行工作进程中,并可选择将分片后的模型参数卸载到 CPU 上。
下图展示了 FSDP 如何针对 2 个数据并行进程进行工作

图 1. FSDP 工作流程
通常情况下,模型层是以嵌套方式用 FSDP 包裹的,这样只有单个 FSDP 实例中的层需要在前向或反向计算期间将完整参数收集到单个设备上。收集后的完整参数会在计算完成后立即释放,释放出的内存可用于下一层的计算。通过这种方式,可以节省 GPU 的峰值内存占用,从而支持训练更大规模的模型或使用更大的批次大小。为了进一步最大化内存效率,当实例未处于计算活跃状态时,FSDP 可以将参数、梯度和优化器状态卸载到 CPU 上。
在 PyTorch 中使用 FSDP
有两种方式可以使用 PyTorch FSDP 来包裹模型。自动包裹 (Auto wrapping) 是 DDP 的直接替代方案;手动包裹 (Manual wrapping) 仅需对模型定义代码进行最小幅度的更改,并具备探索复杂分片策略的能力。
自动包裹
为了节省峰值内存并实现通信与计算的重叠,模型层应以嵌套方式包裹在 FSDP 中。最简单的方法是自动包裹,它可以作为 DDP 的直接替代品,而无需更改其余代码。
fsdp_auto_wrap_policy 参数允许指定一个可调用函数来递归地包裹层。PyTorch FSDP 提供的 default_auto_wrap_policy 函数会递归地包裹参数数量大于 1 亿的层。您可以根据需要提供自己的包裹策略。编写自定义包裹策略的示例请参阅 FSDP API 文档。
此外,可以配置 cpu_offload 选项,以便在计算中未使用这些参数时将包裹的参数卸载到 CPU 上。这能以主机与设备之间数据传输开销为代价,进一步提高内存效率。
以下示例展示了如何使用自动包裹来包裹 FSDP。
from torch.distributed.fsdp import (
FullyShardedDataParallel,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
default_auto_wrap_policy,
)
import torch.nn as nn
class model(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(8, 4)
self.layer2 = nn.Linear(4, 16)
self.layer3 = nn.Linear(16, 4)
model = DistributedDataParallel(model())
fsdp_model = FullyShardedDataParallel(
model(),
fsdp_auto_wrap_policy=default_auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True),
)
手动包裹
当需要通过对模型的部分组件选择性应用 wrap 来探索复杂分片策略时,手动包裹非常有用。整体设置可以传递给 enable_wrap() 上下文管理器。
from torch.distributed.fsdp import (
FullyShardedDataParallel,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
enable_wrap,
wrap,
)
import torch.nn as nn
from typing import Dict
class model(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = wrap(nn.Linear(8, 4))
self.layer2 = nn.Linear(4, 16)
self.layer3 = wrap(nn.Linear(16, 4))
wrapper_kwargs = Dict(cpu_offload=CPUOffload(offload_params=True))
with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
fsdp_model = wrap(model())
使用上述两种方法之一将模型用 FSDP 包裹后,就可以像进行本地训练一样训练模型,如下所示:
optim = torch.optim.Adam(fsdp_model.parameters(), lr=0.0001)
for sample, label in next_batch():
out = fsdp_model(input)
loss = criterion(out, label)
loss.backward()
optim.step()
基准测试结果
我们使用 PyTorch FSDP 在 AWS 集群上对 175B 和 1T 的 GPT 模型进行了详尽的扩展性测试。每个集群节点都是一个配置有 8 个 NVIDIA A100-SXM4-40GB GPU 的实例,节点间通过 AWS EFA (Elastic Fabric Adapter) 以 400 Gbps 的网络带宽互联。
GPT 模型使用 minGPT 实现。基准测试使用了随机生成的输入数据集。所有实验均在词汇表大小为 50K、fp16 精度和 SGD 优化器配置下运行。
| 模型 | 层数 | 隐藏层大小 | 注意力头数 | 模型大小(十亿参数) |
|---|---|---|---|---|
| GPT 175B | 96 | 12288 | 96 | 175 |
| GPT 1T | 128 | 25600 | 160 | 1008 |
除了在实验中使用参数 CPU 卸载的 FSDP 外,测试中还应用了 PyTorch 的 激活检查点 (activation checkpointing) 功能。
对于 GPT 175B 模型,在 128 个 GPU 上,使用批次大小 20 和序列长度 512 时,达到了 159 teraFLOP/s 的最大单 GPU 吞吐量(占 NVIDIA A100 理论峰值性能 312 teraFLOP/s/GPU 的 51%);进一步增加 GPU 数量会导致节点间通信增加,从而导致单 GPU 吞吐量下降。
对于 GPT 1T 模型,在 128 个 GPU 上,使用批次大小 4 和序列长度 2048 时,达到了 84 teraFLOP/s 的最大单 GPU 吞吐量(占峰值性能的 27%)。然而,进一步增加 GPU 数量对单 GPU 吞吐量影响不大,因为我们观察到 1T 模型训练中的最大瓶颈并非来自通信,而是当 GPU 峰值内存达到极限时,CUDA 缓存分配器的速度过慢。使用具有更大内存容量的 A100 80G GPU 将在很大程度上解决这一问题,并有助于增加批次大小以实现更高的吞吐量。


未来工作
在下一个 Beta 版本中,我们计划添加高效的分布式模型/状态检查点 API、用于大模型初始化的元设备 (meta device) 支持,以及在 FSDP 计算和通信内部的混合精度支持。我们还将使在新的 API 中切换 DDP、ZeRO1、ZeRO2 和 FSDP 数据并行模式变得更加容易。为了进一步提升 FSDP 性能,我们还计划进行内存碎片整理和通信效率优化。
FSDP 两个版本的简史
FairScale FSDP 于 2021 年初作为 FairScale 库的一部分发布。随后,我们启动了将 FairScale FSDP 上游合入 PyTorch 1.11 的工作,使其达到生产就绪状态。我们有选择地对 FairScale FSDP 的关键功能进行了合入和重构,重新设计了用户界面,并提升了性能。
在不久的将来,FairScale FSDP 将保留在 FairScale 仓库中供研究项目使用,而通用且被广泛采用的功能将逐步合入 PyTorch 并进行相应的强化。
同时,PyTorch FSDP 将更加专注于生产就绪性和长期支持。这包括与生态系统的更好集成,以及在性能、可用性、可靠性、可调试性和可组合性方面的持续改进。
致谢
我们感谢 FairScale FSDP 的作者:Myle Ott, Sam Shleifer, Min Xu, Priya Goyal, Quentin Duval, Vittorio Caggiano, Tingting Markstrum, Anjali Sridhar。感谢微软 DeepSpeed ZeRO 团队开发并推广了分片数据并行技术。感谢 Pavel Belevich, Jessica Choi, Sisil Mehta 在不同集群上运行使用 PyTorch FSDP 的实验。感谢 Geeta Chauhan, Mahesh Yadav, Pritam Damania, Dmytro Dzhulgakov 为这项工作提供的支持及富有见地的讨论。