引言
近年来,研究界见证了大型模型在自然语言处理 (NLP)、计算机视觉及其他领域取得的巨大成功。这些成功很大程度上得益于云 TPU——一种用于分布式训练的强大硬件。为了在 PyTorch 中支持 TPU,PyTorch/XLA 库为 XLA 设备(主要是 TPU)提供了后端,并为在 TPU 上扩展大型 PyTorch 模型奠定了基础。
然而,PyTorch 生态系统中现有的模型扩展工具大多假设使用 GPU(或 CPU)设备,通常依赖于 CUDA 的特定功能,无法直接在 TPU 上运行。缺乏扩展工具使得构建那些无法放入单个 TPU 芯片内存的大型模型变得极具挑战性。
为了支持在 TPU 上进行模型扩展,作为 PyTorch/XLA 1.12 版本的一部分,我们为 XLA 设备实现了广泛采用的全分片数据并行 (FSDP) 算法。我们提供了一个 FSDP 接口,其高级设计与基于 CUDA 的 PyTorch FSDP 类相似,同时处理了 XLA 中的若干限制(详见下文“设计说明”)。该 FSDP 接口使我们能够轻松地在 TPU 上构建参数量在 10B 以上的模型,并促进了许多研究探索。
在 PyTorch/XLA 中使用全分片数据并行 (FSDP)
我们提供了一个针对给定 PyTorch 模型的包装类 XlaFullyShardedDataParallel,用于跨数据并行工作节点分片其参数。示例用法如下:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
使用 XlaFullyShardedDataParallel 包装 nn.Module 实例可在其上启用 ZeRO-2 算法,在该算法中,其梯度和优化器状态在整个训练过程中被分片。在前向和反向传播过程中,被包装模块的完整参数首先会从其对应的分片中重构出来以进行计算。
嵌套 FSDP 包装可用于进一步节省内存。这允许模型在任何给定时间仅存储单个层的完整参数。对于嵌套 FSDP,应先用内部 FSDP 包装单个子模块,然后再用外部 FSDP 包装基础模型。外部包装器确保处理任何剩余参数,这对应于 ZeRO-3 算法。嵌套 FSDP 包装可以应用于子模块的任何深度,且嵌套层数可以超过 2 层。
模型检查点保存与加载:模型和优化器的操作与以前一样,通过保存和加载它们的 .state_dict() 来完成。同时,每个训练进程应保存其自己的分片模型参数和优化器状态的检查点文件,并在恢复训练时为相应的 rank 加载检查点文件(无论是否使用 ZeRO-2 或 ZeRO-3,即是否使用嵌套包装)。我们提供了一个命令行工具和 Python 接口,用于将分片后的模型检查点文件合并为一个完整/未分片的模型检查点文件。
梯度检查点(也称为“激活检查点”或“重计算”)是另一种常见的模型扩展技术,可以与 FSDP 结合使用。我们提供了 checkpoint_module,这是一个针对给定 nn.Module 实例的包装函数,用于梯度检查点(基于 torch_xla.utils.checkpoint.checkpoint)。
下方的 MNIST 和 ImageNet 示例提供了(普通或嵌套)FSDP、模型检查点保存与合并以及梯度检查点的演示用法。
PyTorch/XLA 中 FSDP 的入门示例
使用 FSDP 训练 MNIST 和 ImageNet
MNIST 和 ImageNet 分类通常可作为构建更复杂深度学习模型的起点。我们在这两个数据集上提供了以下 FSDP 示例:
- MNIST:test/test_train_mp_mnist_fsdp_with_ckpt.py(它还演示了检查点保存和合并)
- ImageNet:test/test_train_mp_imagenet_fsdp.py
通过将它们与 MNIST 和 ImageNet 的原始数据并行示例进行对比,可以了解如何调整训练脚本以使用 FSDP。需要记住的一个主要区别是:在对 FSDP 包装的模型进行优化器步进时,应直接调用 optimizer.step(),而不是 xm.optimizer_step(optimizer)。后者会在各个 rank 之间规约梯度,而在 FSDP 中这并非我们所需要的,因为梯度已经在反向传播中的 reduce-scatter 操作中被规约并分片了。
安装
FSDP 可在 PyTorch/XLA 1.12 及更新的 nightly 版本中使用。请参考 https://github.com/pytorch/xla#-available-images-and-wheels 获取安装指南以及云 TPU 分配指南。然后,在 TPU VM 上克隆 PyTorch/XLA 仓库如下:
mkdir -p ~/pytorch && cd ~/pytorch
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上训练 MNIST
训练 2 个 epoch 后,准确率约为 98.9%。
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp
上述脚本会在最后自动测试分片模型检查点的合并。您也可以通过以下方式手动合并分片后的检查点文件:
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet
训练 100 个 epoch 后,准确率约为 75.9%,与不使用 FSDP 时获得的结果相同;请下载并预处理 ImageNet-1k 数据集至 /datasets/imagenet-1k。
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
--lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 \
--num_epochs 100 \
--use_nested_fsdp
您还可以在这两个示例中探索其他选项,例如 --use_gradient_checkpointing 以在 ResNet 块上应用梯度检查点(即激活检查点),或 --compute_dtype bfloat16 以在 bfloat16 精度下执行前向和反向传播。
大规模模型示例
在 TPU 上构建大型模型时,我们通常需要注意内存限制(例如,TPU v3 每个核心 16 GB,TPU v4 每个芯片 32 GB)。对于无法放入单个 TPU 内存或主机 CPU 内存的大型模型,应使用嵌套 FSDP 来实现 ZeRO-3 算法,将子模块构建与内部 FSDP 包装交错进行,这样在构建过程中就不需要将完整模型保存在内存中。
我们在 https://github.com/ronghanghu/ptxla_scaling_examples 中演示了这些案例,该仓库提供了在 TPU v3 pod(拥有 128 个核心)上训练参数量超过 10B 的视觉 Transformer (ViT) 模型等案例。
设计说明
人们可能会问,为什么我们需要在 PyTorch/XLA 中开发一个独立的 FSDP 类,而不是直接复用 PyTorch 的 FSDP 类 或将其扩展到 XLA 后端。PyTorch/XLA 中独立 FSDP 类的核心动机在于:原生 PyTorch 的 FSDP 类严重依赖 XLA 设备不支持的 CUDA 特性,而 XLA 也有一些需要特殊处理的独特属性。这些差异要求 FSDP 的实现必须有所不同,在一个独立的类中实现会更容易。
API 调用中的变更
一个显著的区别在于,原生 PyTorch FSDP 构建在独立的 CUDA 流上,用于 eager 模式下的异步执行,而 PyTorch/XLA 运行在 lazy 模式下且不支持流。此外,TPU 要求所有设备同构地运行相同的程序。因此,在 PyTorch/XLA FSDP 实现中,CUDA 调用和跨进程的异构性需要被 XLA API 和替代的同构实现所取代。
张量存储处理
另一个显著区别是如何释放张量的存储,这在 XLA 中比在 CUDA 中困难得多。要实现 ZeRO-3,需要在模块的前向传播后释放完整参数的存储,以便下一个模块可以为后续计算重用该内存缓冲区。PyTorch 的 FSDP 在 CUDA 上是通过 p.data.storage().resize_(0) 来释放参数 p 的实际存储来实现这一点的。然而,由于 XLA HLO IR 是完全函数式的,且不提供任何用于解除分配张量或调整其存储大小的操作,XLA 张量没有这个 .storage() 句柄。在 PyTorch 接口之下,只有 XLA 编译器可以决定何时释放对应于 XLA 张量的 TPU 设备内存,前提是只有当张量对象在 Python 中被去分配时才能释放内存——而在 FSDP 中这不会发生,因为这些参数张量作为模块属性被引用,并被 PyTorch autograd 保存以用于反向传播。
我们解决此问题的方法是将张量的值属性与其自动微分变量属性分离,并通过将其 .data 属性设置为大小为 1 的虚拟标量来释放 nn.Parameter 张量。这样,完整参数的实际数据张量在 Python 中就会被去引用,使得 XLA 可以回收其内存用于其他计算,而自动微分仍可将基础 nn.Parameter 追踪为对参数数据的弱引用。为了使其生效,还需要处理参数的视图 (view),因为 PyTorch 中的视图也持有对其实际数据的引用(这需要修复 PyTorch/XLA 中与视图相关的形状问题)。
与 XLA 编译器的协作
上述解决方案在 XLA 编译器忠实地保留我们 PyTorch 程序中的操作及其执行顺序时,足以释放完整参数。但还有一个问题——XLA 试图通过对 HLO IR 应用公共子表达式消除 (CSE) 来优化程序,从而加快执行速度。在 FSDP 的简单实现中,XLA 编译器通常会在看到反向传播中的第二次 all-gather 是前向传播的重复计算时将其消除,并直接保留和重用我们想要在前向传播后释放的完整参数。为了防止这种不期望的编译器行为,我们在 PyTorch/XLA 中引入了 优化屏障操作 (optimization barrier op),并使用它来阻止消除第二次 all-gather。该优化屏障也被应用于类似的梯度检查点案例,以防止前向和反向传播之间发生可能消除重计算的 CSE。
未来,如果 CUDA 和 XLA 之间的差异不再像上述那样显著,可以考虑将 PyTorch/XLA FSDP 与原生 PyTorch FSDP 合并,以实现统一的接口。
致谢
感谢 AWS 的 Junmin Hao 审阅 PyTorch/XLA FSDP 的合并请求 (PR)。感谢 Meta PyTorch 团队的 Brian Hirsh 在 PyTorch 核心问题上提供的支持。感谢 Google 的 Isaack Karanja、Will Cromar 和 Blake Hechtman 在 GCP、XLA 和 TPU 问题上提供的支持。
感谢 Meta FAIR 的 Piotr Dollar、Wan-Yen Lo、Alex Berg、Ryan Mark、Kaiming He、Xinlei Chen、Saining Xie、Shoubhik Debnath、Min Xu 和 Vaibhav Aggarwal 参与的各种 TPU 相关讨论。