跳转到主要内容
博客

PyTorch 全分片数据并行 (FSDP) API 简介

最近的研究表明,大型模型训练将有利于提高模型质量。在过去的 3 年中,模型规模增长了 10,000 倍,从参数量 1.1 亿的 BERT 到参数量万亿的 Megatron-2。然而,训练大型人工智能模型并非易事——除了需要大量计算资源外,软件工程复杂性也极具挑战性。PyTorch 一直致力于构建工具和基础设施,以简化这一过程。

PyTorch 分布式数据并行因其鲁棒性和简单性而成为可扩展深度学习的基石。但它要求模型能够适应单个 GPU。最近的方法,如 DeepSpeed ZeRO 和 FairScale 的完全分片数据并行,允许我们通过在数据并行工作器之间分片模型的参数、梯度和优化器状态来突破这一限制,同时仍然保持数据并行的简单性。

在 PyTorch 1.11 中,我们正在添加对完全分片数据并行 (FSDP) 的原生支持,目前作为一个原型功能提供。它的实现大量借鉴了 FairScale 的版本,同时带来了更精简的 API 和额外的性能改进。

PyTorch FSDP 在 AWS 上的扩展测试表明,它可以扩展到训练具有 1 万亿参数的密集模型。在我们的实验中,在 AWS 集群上,对于 GPT 1T 模型,每 A100 GPU 的实际性能达到 84 TFLOPS;对于 GPT 175B 模型,每 A100 GPU 的实际性能达到 159 TFLOPS。与 FairScale 的原始版本相比,当启用 CPU 卸载时,原生 FSDP 实现还显著缩短了模型初始化时间。

在未来的 PyTorch 版本中,我们将使用户能够在 DDP、ZeRO-1、ZeRO-2 和 FSDP 各种数据并行方式之间无缝切换,以便用户可以在统一的 API 中通过简单的配置训练不同规模的模型。

FSDP 工作原理

FSDP 是一种数据并行训练,但与传统的在每个 GPU 上维护模型参数、梯度和优化器状态副本的数据并行不同,它将所有这些状态分片到数据并行工作器中,并且可以选择将分片的模型参数卸载到 CPU。

下图显示了 FSDP 如何为 2 个数据并行进程工作

图 1. FSDP 工作流程

通常,模型层以嵌套方式用 FSDP 封装,这样只有单个 FSDP 实例中的层才需要在前向或后向计算期间将完整参数收集到单个设备。收集的完整参数将在计算后立即释放,释放的内存可用于下一层的计算。通过这种方式,可以节省峰值 GPU 内存,从而可以将训练扩展到使用更大的模型尺寸或更大的批处理尺寸。为了进一步最大限度地提高内存效率,当实例在计算中不活动时,FSDP 可以将参数、梯度和优化器状态卸载到 CPU。

在 PyTorch 中使用 FSDP

有两种方法可以使用 PyTorch FSDP 封装模型。自动封装是 DDP 的直接替代品;手动封装只需要对模型定义代码进行少量更改,并能够探索复杂的切分策略。

自动封装

模型层应以嵌套方式封装在 FSDP 中,以节省峰值内存并实现通信和计算的重叠。最简单的方法是自动封装,它可以作为 DDP 的直接替代品,而无需更改其余代码。

fsdp_auto_wrap_policy 参数允许指定一个可调用函数,以递归地使用 FSDP 封装层。PyTorch FSDP 提供的 default_auto_wrap_policy 函数递归地封装参数数量大于 1 亿的层。您可以根据需要提供自己的封装策略。编写自定义封装策略的示例显示在 FSDP API 文档中。

此外,可以选择配置 cpu_offload 以在这些参数未用于计算时将其卸载到 CPU。这可以进一步提高内存效率,但会增加主机和设备之间的数据传输开销。

下面的示例展示了如何使用自动封装来封装 FSDP。

from torch.distributed.fsdp import (
   FullyShardedDataParallel,
   CPUOffload,
)
from torch.distributed.fsdp.wrap import (
   default_auto_wrap_policy,
)
import torch.nn as nn
 
class model(nn.Module):
   def __init__(self):
       super().__init__()
       self.layer1 = nn.Linear(8, 4)
       self.layer2 = nn.Linear(4, 16)
       self.layer3 = nn.Linear(16, 4)
 
model = DistributedDataParallel(model())
fsdp_model = FullyShardedDataParallel(
   model(),
   fsdp_auto_wrap_policy=default_auto_wrap_policy,
   cpu_offload=CPUOffload(offload_params=True),
)

手动封装

手动封装对于通过有选择地将 `wrap` 应用于模型的某些部分来探索复杂的切分策略非常有用。整体设置可以传递给 enable_wrap() 上下文管理器。

from torch.distributed.fsdp import (
   FullyShardedDataParallel,
   CPUOffload,
)
from torch.distributed.fsdp.wrap import (
   enable_wrap,
   wrap,
)
import torch.nn as nn
from typing import Dict
 
 
class model(nn.Module):
   def __init__(self):
       super().__init__()
       self.layer1 = wrap(nn.Linear(8, 4))
       self.layer2 = nn.Linear(4, 16)
       self.layer3 = wrap(nn.Linear(16, 4))
 
wrapper_kwargs = Dict(cpu_offload=CPUOffload(offload_params=True))
with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
   fsdp_model = wrap(model())

使用上述两种方法之一通过 FSDP 封装模型后,模型可以以与本地训练类似的方式进行训练,如下所示:

optim = torch.optim.Adam(fsdp_model.parameters(), lr=0.0001)
for sample, label in next_batch():
  out = fsdp_model(input)
  loss = criterion(out, label)
  loss.backward()
  optim.step()

基准测试结果

我们在 AWS 集群上使用 PyTorch FSDP 对 175B 和 1T GPT 模型进行了广泛的扩展测试。每个集群节点都是一个实例,配备 8 个 NVIDIA A100-SXM4-40GB GPU,节点间通过 AWS Elastic Fabric Adapter (EFA) 连接,网络带宽为 400 Gbps。

GPT 模型使用 minGPT 实现。使用随机生成输入数据集进行基准测试。所有实验均使用 50K 词汇量、fp16 精度和 SGD 优化器。

模型层数隐藏层大小注意力头模型大小,数十亿参数
GPT 175B961228896175
GPT 1T128256001601008

除了在实验中使用带参数 CPU 卸载的 FSDP 外,测试中还应用了 PyTorch 中的 激活检查点功能

对于 GPT 175B 模型,在使用 128 个 GPU、批量大小为 20、序列长度为 512 的情况下,实现了每 GPU 159 teraFLOP/s 的最大吞吐量(占 NVIDIA A100 峰值理论性能 312 teraFLOP/s/GPU 的 51%);进一步增加 GPU 数量会导致每 GPU 吞吐量下降,因为节点间通信量增加。

对于 GPT 1T 模型,在使用 128 个 GPU、批量大小为 4、序列长度为 2048 的情况下,实现了每 GPU 84 teraFLOP/s 的最大吞吐量(占峰值 teraFLOP/s 的 27%)。然而,进一步增加 GPU 数量对每 GPU 吞吐量影响不大,因为我们观察到 1T 模型训练的最大瓶颈不是来自通信,而是来自慢速 CUDA 缓存分配器,当峰值 GPU 内存达到极限时。使用具有更大内存容量的 A100 80G GPU 将在很大程度上解决此问题,并有助于扩展批量大小以实现更大的吞吐量。

未来工作

在下一个 Beta 版本中,我们计划添加高效的分布式模型/状态检查点 API、大型模型实例化元设备支持以及 FSDP 计算和通信中的混合精度支持。我们还将使在新 API 中更轻松地在 DDPZeRO1、ZeRO2 和 FSDP 各种数据并行方式之间切换。为了进一步提高 FSDP 性能,还计划减少内存碎片并改进通信效率。

FSDP 两个版本的历史

FairScale FSDP 于 2021 年初作为 FairScale 库的一部分发布。然后我们开始努力将 FairScale FSDP 上游到 PT 1.11 中的 PyTorch,使其达到生产就绪状态。我们有选择地从 FairScale FSDP 上游和重构了关键功能,重新设计了用户界面并改进了性能。

在不久的将来,FairScale FSDP 将保留在 FairScale 仓库中用于研究项目,而通用和广泛采用的功能将逐步上游到 PyTorch 并相应地进行强化。

同时,PyTorch FSDP 将更专注于生产就绪性和长期支持。这包括更好地与生态系统集成,以及在性能、可用性、可靠性、可调试性和可组合性方面的改进。

致谢

我们要感谢 FairScale FSDP 的作者:Myle Ott、Sam Shleifer、Min Xu、Priya Goyal、Quentin Duval、Vittorio Caggiano、Tingting Markstrum、Anjali Sridhar。感谢 Microsoft DeepSpeed ZeRO 团队开发并推广了分片数据并行技术。感谢 Pavel Belevich、Jessica Choi、Sisil Mehta 在不同集群上使用 PyTorch FSDP 进行实验。感谢 Geeta Chauhan、Mahesh Yadav、Pritam Damania、Dmytro Dzhulgakov 支持这项工作并提供了富有洞察力的讨论。