引言
近年来,研究社区在 NLP、计算机视觉和其他领域的大型模型方面取得了许多成功。其中许多成功都得益于 Cloud TPU——这是一种用于分布式训练的强大硬件。为了支持 PyTorch 中的 TPU,PyTorch/XLA 库为 XLA 设备(最著名的是 TPU)提供了后端,并为在 TPU 上扩展大型 PyTorch 模型奠定了基础。
然而,PyTorch 生态系统中的大多数现有模型扩展工具都假定使用 GPU(或 CPU)设备,通常依赖于 CUDA 中的特定功能,并且不能直接在 TPU 上工作。缺乏扩展工具使得构建无法适应单个 TPU 芯片内存的大型模型变得具有挑战性。
为了支持 TPU 上的模型扩展,我们实现了广泛采用的 完全分片数据并行 (FSDP) 算法,作为 PyTorch/XLA 1.12 版本的一部分,用于 XLA 设备。我们提供了一个 FSDP 接口,其高级设计与基于 CUDA 的 PyTorch FSDP 类类似,同时还处理了 XLA 中的一些限制(有关详细信息,请参阅下面的设计说明)。这个 FSDP 接口使我们能够轻松地在 TPU 上构建拥有例如 100 亿以上参数的模型,并促成了许多研究探索。
在 PyTorch/XLA 中使用完全分片数据并行 (FSDP)
我们提供了一个包装器类 `XlaFullyShardedDataParallel`,用于给定 PyTorch 模型,以在数据并行工作器之间分片其参数。以下是一个使用示例:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
使用 `XlaFullyShardedDataParallel` 包装 `nn.Module` 实例可以在其上启用 ZeRO-2 算法,其中其梯度和优化器状态在整个训练过程中都被分片。在其前向和后向传递中,包装模块的完整参数首先从其对应的分片重建以进行计算。
可以使用**嵌套 FSDP** 包装来进一步节省内存。这使得模型在任何给定时间点都只存储单个层的完整参数。对于嵌套 FSDP,应首先使用内部 FSDP 包装其各个子模块,然后再使用外部 FSDP 包装基础模型。这使得模型在任何给定时间点都只存储单个层的完整参数。并且拥有一个外部包装器可确保处理任何剩余参数,对应于 ZeRO-3 算法。嵌套 FSDP 包装可以应用于任何深度的子模块,并且可以有超过 2 层嵌套。
**模型检查点保存和加载**(适用于模型和优化器)可以像以前一样通过保存和加载其 `.state_dict()` 来完成。同时,每个训练过程应保存其自己的分片模型参数和优化器状态的检查点文件,并在恢复时加载相应级别的检查点文件(无论 ZeRO-2 还是 ZeRO-3,即是否嵌套包装)。提供了一个命令行工具和一个 Python 接口,用于将分片模型检查点文件合并为一个完整/未分片的模型检查点文件。
**梯度检查点**(也称为“激活检查点”或“重新实例化”)是另一种常见的模型扩展技术,可以与 FSDP 结合使用。我们提供了 `checkpoint_module`,这是一个包装函数,用于给定 `nn.Module` 实例的梯度检查点(基于 `torch_xla.utils.checkpoint.checkpoint`)。
下面的 MNIST 和 ImageNet 示例说明了(普通或嵌套)FSDP、模型检查点的保存和合并以及梯度检查点的用法。
PyTorch/XLA 中 FSDP 的入门示例
使用 FSDP 训练 MNIST 和 ImageNet
MNIST 和 ImageNet 分类通常可以用作构建更复杂的深度学习模型的起点。我们提供了关于这两个数据集的以下 FSDP 示例:
- MNIST: test/test_train_mp_mnist_fsdp_with_ckpt.py(它还说明了检查点保存和合并)
- ImageNet: test/test_train_mp_imagenet_fsdp.py
将它们与 MNIST 和 ImageNet 的原始数据并行示例进行比较,可以说明如何调整训练脚本以使用 FSDP。需要记住的一个主要区别是,当在 FSDP 封装的模型上执行优化器步骤时,应该直接调用 `optimizer.step()` 而不是 `xm.optimizer_step(optimizer)`。后者会在各级别之间减少梯度,而这在 FSDP 中是不需要的,因为梯度已经经过减少和分片(通过其反向传播中的 reduce-scatter 操作)。
安装
FSDP 可从 PyTorch/XLA 1.12 及更高版本的夜间构建中获取。请参阅 https://github.com/pytorch/xla#-available-images-and-wheels 以获取安装指南以及 Cloud TPU 分配。然后,在 TPU VM 上克隆 PyTorch/XLA 仓库,如下所示:
mkdir -p ~/pytorch && cd ~/pytorch
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上训练 MNIST
两个 epoch 之后,准确率约为 98.9%。
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp
上述脚本会自动测试最后分片模型检查点的合并。您也可以通过以下方式手动合并分片检查点文件:
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet
100 个 epoch 后,准确率约为 75.9%,与不使用 FSDP 的结果相同;将 ImageNet-1k 数据集下载并预处理到 `/datasets/imagenet-1k`
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
--lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 \
--num_epochs 100 \
--use_nested_fsdp
您还可以在这两个示例中探索其他选项,例如 `--use_gradient_checkpointing` 用于在 ResNet 块上应用梯度检查点(即激活检查点),或者 `--compute_dtype bfloat16` 用于以 bfloat16 精度执行前向和后向传播。
大规模模型示例
在 TPU 上构建大型模型时,我们通常需要注意内存限制(例如,TPU v3 中每个核心 16 GB,TPU v4 中每个芯片 32 GB)。对于无法容纳在单个 TPU 内存或主机 CPU 内存中的大型模型,应该使用嵌套 FSDP 来实现 ZeRO-3 算法,将子模块构建与内部 FSDP 包装交错进行,这样在构建过程中就不需要将完整模型存储在内存中。
我们在 https://github.com/ronghanghu/ptxla_scaling_examples 中阐述了这些情况,该仓库提供了在 TPU v3 Pod(128 个核心)上训练具有 100 亿以上参数的 Vision Transformer (ViT) 模型以及其他案例的示例。
设计说明
人们可能会疑惑,为什么我们需要在 PyTorch/XLA 中开发一个单独的 FSDP 类,而不是直接重用 PyTorch 的 FSDP 类或将其扩展到 XLA 后端。在 PyTorch/XLA 中使用单独的 FSDP 类的主要动机是,原生的 PyTorch FSDP 类严重依赖 XLA 设备不支持的 CUDA 特性,而 XLA 也具有一些需要特殊处理的独特特性。这些区别需要 FSDP 的不同实现,这将更容易在单独的类中构建。
API 调用中的更改
一个显著的区别是,原生 PyTorch FSDP 建立在独立的 CUDA 流之上,用于急切模式下的异步执行,而 PyTorch/XLA 在惰性模式下运行,并且不支持流。此外,TPU 要求所有设备同构地运行相同的程序。因此,在 PyTorch/XLA FSDP 实现中,CUDA 调用和每进程异构性需要替换为 XLA API 和替代的同构实现。
张量存储处理
另一个显著区别是如何释放张量的存储,这在 XLA 中比在 CUDA 中要困难得多。为了实现 ZeRO-3,需要在模块的前向传播后释放完整参数的存储,以便下一个模块可以重用此内存缓冲区进行后续计算。PyTorch 的 FSPD 通过 `p.data.storage().resize_(0)` 释放参数 `p` 的实际存储来实现这一点。然而,XLA 张量没有这个 `.storage()` 句柄,因为 XLA HLO IR 是完全函数式的,并且不提供任何操作来解除分配张量或调整其存储大小。在 PyTorch 接口之下,只有 XLA 编译器才能决定何时释放与 XLA 张量对应的 TPU 设备内存,并且一个先决条件是只有当张量对象在 Python 中解除分配时才能释放内存——这在 FSDP 中不可能发生,因为这些参数张量被引用为模块属性,并且也被 PyTorch autograd 保存以用于反向传播。
我们解决这个问题的方法是将张量的值属性与其自动求导变量属性分离,并通过将其 `.data` 属性设置为大小为 1 的虚拟标量来释放 `nn.Parameter` 张量。这样,完整参数的实际数据张量在 Python 中被取消引用,以便 XLA 可以回收其内存用于其他计算,而自动求导仍然可以将基本的 `nn.Parameter` 跟踪为参数数据的弱引用。为了使其工作,还需要处理参数的视图,因为 PyTorch 中的视图也持有对其实际数据的引用(这需要在 PyTorch/XLA 中修复与视图相关的形状问题)。
与 XLA 编译器协同工作
如果 XLA 编译器忠实地保留了我们 PyTorch 程序中的操作及其执行顺序,上述解决方案应该足以释放完整参数。但还有另一个问题——XLA 试图通过对 HLO IR 应用公共子表达式消除 (CSE) 来优化程序以加快其执行速度。在 FSDP 的简单实现中,XLA 编译器通常会在反向传播中消除第二次全收集以重建完整参数,当它发现这是前向传播中重复的计算时,并直接持有并重用我们希望在前向传播后释放的完整参数。为了防止这种不希望的编译器行为,我们在 PyTorch/XLA 中引入了优化屏障操作,并使用它来阻止消除第二次全收集。此优化屏障也应用于梯度检查点的类似情况,以防止前向和反向传播之间的 CSE 可能会消除重新实例化。
未来,如果 CUDA 和 XLA 之间的区别不像上面提到的那样突出,那么将 PyTorch/XLA FSDP 与原生 PyTorch FSDP 合并以实现统一接口可能值得考虑。
致谢
感谢 AWS 的 Junmin Hao 审阅了 PyTorch/XLA FSDP 的拉取请求。感谢 Meta PyTorch 团队的 Brian Hirsh 对 PyTorch 核心问题的支持。感谢 Google 的 Isaack Karanja、Will Cromar 和 Blake Hechtman 对 GCP、XLA 和 TPU 问题的支持。
感谢 Meta FAIR 的 Piotr Dollar、Wan-Yen Lo、Alex Berg、Ryan Mark、Kaiming He、Xinlei Chen、Saining Xie、Shoubhik Debnath、Min Xu 和 Vaibhav Aggarwal 进行了各种与 TPU 相关的讨论。