跳转到主要内容
博客

使用 FSDP 在 Cloud TPU 上扩展 PyTorch 模型

简介

近年来,研究社区在自然语言处理、计算机视觉和其他领域的大型模型方面取得了许多成功。其中许多成功都得益于 Cloud TPU,它们是用于分布式训练的强大硬件。为了支持 PyTorch 中的 TPU,PyTorch/XLA 库为 XLA 设备(最著名的是 TPU)提供了后端,并为在 TPU 上扩展大型 PyTorch 模型奠定了基础。

然而,PyTorch 生态系统中现有的大多数模型扩展工具都假定使用 GPU(或 CPU)设备,通常依赖于 CUDA 中的特定功能,并且不能直接在 TPU 上工作。缺乏扩展工具使得构建无法适应单个 TPU 芯片内存的大型模型变得具有挑战性。

为了支持在 TPU 上进行模型扩展,我们作为 PyTorch/XLA 1.12 发布的一部分,为 XLA 设备实现了广泛采用的 完全分片数据并行 (FSDP) 算法。我们提供了一个 FSDP 接口,其高级设计与基于 CUDA 的 PyTorch FSDP 类相似,同时还处理了 XLA 中的一些限制(有关更多详细信息,请参阅下面的设计说明)。这个 FSDP 接口使我们能够轻松地在 TPU 上构建具有 10B+ 参数的模型,并支持了许多研究探索。

在 PyTorch/XLA 中使用完全分片数据并行 (FSDP)

我们提供了一个包装类 XlaFullyShardedDataParallel,用于给定 PyTorch 模型,以将其参数分片到数据并行工作器中。示例如下:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

使用 XlaFullyShardedDataParallel 包装 nn.Module 实例可在其上启用 ZeRO-2 算法,在该算法中,其梯度和优化器状态在整个训练过程中都被分片。在其前向和后向传播过程中,包装模块的完整参数首先从其相应的分片中重建以进行计算。

可以使用嵌套 FSDP 包装来进一步节省内存。这允许模型在任何给定时间仅存储一个单独层的完整参数。对于嵌套 FSDP,应首先使用内部 FSDP 包装其单独的子模块,然后再使用外部 FSDP 包装基础模型。这允许模型在任何给定时间仅存储一个单独层的完整参数。并且拥有一个外部包装器可确保处理任何剩余参数,对应于 ZeRO-3 算法。嵌套 FSDP 包装可以应用于子模块的任何深度,并且可以有超过 2 层的嵌套。

模型和优化器的模型检查点保存和加载可以像以前一样通过保存和加载它们的 .state_dict() 来完成。同时,每个训练过程应保存其自己的分片模型参数和优化器状态的检查点文件,并在恢复时加载相应等级的检查点文件(无论 ZeRO-2 还是 ZeRO-3,即嵌套包装与否)。提供了一个命令行工具和一个 Python 接口,用于将分片模型检查点文件合并为一个完整/未分片的模型检查点文件。

梯度检查点(也称为“激活检查点”或“重新实例化”)是另一种用于模型扩展的常见技术,可以与 FSDP 结合使用。我们提供了 checkpoint_module,这是一个包装给定 nn.Module 实例以进行梯度检查点(基于 torch_xla.utils.checkpoint.checkpoint)的函数。

下面的 MNIST 和 ImageNet 示例提供了(普通或嵌套)FSDP、模型检查点的保存和合并以及梯度检查点的说明性用法。

PyTorch/XLA 中 FSDP 的起始示例

使用 FSDP 训练 MNIST 和 ImageNet

MNIST 和 ImageNet 分类通常可以作为构建更复杂的深度学习模型的起点。我们提供了以下关于这两个数据集的 FSDP 示例:

将它们与 MNISTImageNet 的普通数据并行示例进行比较,说明了如何调整训练脚本以使用 FSDP。需要记住的一个主要区别是,当在 FSDP 包装模型上优化器进行步进时,应直接调用 optimizer.step() 而不是 xm.optimizer_step(optimizer)。后者会减少跨等级的梯度,这在 FSDP 中不是我们需要的,因为梯度已经减少并分片(来自其后向传播中的 reduce-scatter 操作)。

安装

FSDP 可在 PyTorch/XLA 1.12 及更新的每夜版本中获得。请参阅 https://github.com/pytorch/xla#-available-images-and-wheels 以获取安装指南以及 Cloud TPU 分配。然后,在 TPU VM 上克隆 PyTorch/XLA 存储库,如下所示:

mkdir -p ~/pytorch && cd ~/pytorch
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

在 v3-8 TPU 上训练 MNIST

它在 2 个 epoch 中获得了大约 98.9 的准确率

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp

上面的脚本会在最后自动测试分片模型检查点的合并。您也可以通过以下方式手动合并分片检查点文件:

python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet

它在 100 个 epoch 中获得了大约 75.9 的准确率,与不使用 FSDP 获得的结果相同;将 ImageNet-1k 数据集下载并预处理到 /datasets/imagenet-1k

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
  --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 \
  --num_epochs 100 \
  --use_nested_fsdp

您还可以探索这两个示例中的其他选项,例如 --use_gradient_checkpointing 以在 ResNet 块上应用梯度检查点(即激活检查点),或 --compute_dtype bfloat16 以 bfloat16 精度执行前向和后向传播。

大型模型示例

在 TPU 上构建大型模型时,我们通常需要注意内存限制(例如,TPU v3 中每个核心 16 GB,TPU v4 中每个芯片 32 GB)。对于无法适应单个 TPU 内存或主机 CPU 内存的大型模型,应使用嵌套 FSDP 来实现 ZeRO-3 算法,将子模块构建与内部 FSDP 包装交错,以便在构建过程中永远不需要将完整模型存储在内存中。

我们在 https://github.com/ronghanghu/ptxla_scaling_examples 中说明了这些情况,其中提供了在 TPU v3 Pod(带 128 个核心)上训练具有 10B+ 参数的 Vision Transformer (ViT) 模型以及其他情况的示例。

设计说明

有人可能想知道为什么我们需要在 PyTorch/XLA 中开发一个单独的 FSDP 类,而不是直接重用 PyTorch 的 FSDP 类 或将其扩展到 XLA 后端。在 PyTorch/XLA 中使用单独 FSDP 类的主要动机是,原生 PyTorch 的 FSDP 类严重依赖 XLA 设备不支持的 CUDA 功能,而 XLA 也具有一些需要特殊处理的独特特性。这些差异需要 FSDP 的不同实现,在单独的类中构建会容易得多。

API 调用更改

一个显著的区别是,原生 PyTorch FSDP 基于独立的 CUDA 流在 eager 模式下进行异步执行,而 PyTorch/XLA 在 lazy 模式下运行,并且不支持流。此外,TPU 要求所有设备同构地运行相同的程序。因此,在 PyTorch/XLA FSDP 实现中,CUDA 调用和每进程异构性需要被 XLA API 和替代的同构实现所取代。

张量存储处理

另一个显著的区别是如何释放张量的存储,这在 XLA 中比在 CUDA 中要困难得多。为了实现 ZeRO-3,需要在模块的前向传播之后释放完整参数的存储,以便下一个模块可以重用此内存缓冲区进行后续计算。PyTorch 的 FSPD 通过 p.data.storage().resize_(0) 释放参数 p 的实际存储来实现这一点。然而,XLA 张量没有这个 .storage() 句柄,因为 XLA HLO IR 是完全函数式的,不提供任何操作来释放张量或调整其存储大小。在 PyTorch 接口之下,只有 XLA 编译器可以决定何时释放与 XLA 张量对应的 TPU 设备内存,并且一个先决条件是只有当张量对象在 Python 中被释放时才能释放内存——这在 FSDP 中不可能发生,因为这些参数张量被引用为模块属性,并且也被 PyTorch autograd 保存用于反向传播。

我们解决此问题的方法是将张量的值属性与其自动梯度变量属性分离,并通过将其 .data 属性设置为大小为 1 的虚拟标量来释放 nn.Parameter 张量。这样,完整参数的实际数据张量在 Python 中被解除引用,以便 XLA 可以回收其内存用于其他计算,而自动梯度仍然可以将基础 nn.Parameter 跟踪为参数数据的弱引用。为了使其工作,还需要处理参数的视图,因为 PyTorch 中的视图也保留对其实际数据的引用(这需要在 PyTorch/XLA 中修复与视图相关的形状问题)。

与 XLA 编译器协作

如果 XLA 编译器忠实地保留了我们 PyTorch 程序中的操作及其执行顺序,上述解决方案应该足以释放完整参数。但还有一个问题——XLA 试图通过对 HLO IR 应用公共子表达式消除 (CSE) 来优化程序以加速其执行。在 FSDP 的朴素实现中,XLA 编译器通常会在反向传播中消除第二个 all-gather 以重建完整参数,当它看到这是来自前向传播的重复计算时,并直接保留和重用我们想要在前向传播之后释放的完整参数。为了防止这种不希望的编译器行为,我们在 PyTorch/XLA 中引入了 优化屏障操作,并用它来阻止消除第二个 all-gather。此优化屏障也应用于梯度检查点的类似情况,以防止前向和反向传播之间的 CSE 消除重新实例化。

未来,如果 CUDA 和 XLA 之间的区别不再像上面提到的那样突出,那么将 PyTorch/XLA FSDP 与原生 PyTorch FSDP 合并以实现统一接口可能是值得考虑的。

致谢

感谢 AWS 的 Junmin Hao 审阅 PyTorch/XLA FSDP 拉取请求。感谢 Meta PyTorch 团队的 Brian Hirsh 对 PyTorch 核心问题的支持。感谢 Google 的 Isaack Karanja、Will Cromar 和 Blake Hechtman 对 GCP、XLA 和 TPU 问题的支持。

感谢 Meta FAIR 的 Piotr Dollar、Wan-Yen Lo、Alex Berg、Ryan Mark、Kaiming He、Xinlei Chen、Saining Xie、Shoubhik Debnath、Min Xu 和 Vaibhav Aggarwal 进行的各种 TPU 相关讨论。