引言
近年来,研究界在 NLP、计算机视觉和其他领域的大型模型方面取得了许多成功。这些成功中的许多都得益于 Cloud TPU,它是用于分布式训练的强大硬件。为了在 PyTorch 中支持 TPU,PyTorch/XLA 库为 XLA 设备(最值得注意的是 TPU)提供了后端,并为在 TPU 上扩展大型 PyTorch 模型奠定了基础。
然而,PyTorch 生态系统中大多数现有的模型扩展工具都假定使用 GPU(或 CPU)设备,并且通常依赖于 CUDA 中的特定功能,而无法直接在 TPU 上工作。缺乏扩展工具使得构建无法容纳在单个 TPU 芯片内存中的大型模型具有挑战性。
为了支持在 TPU 上进行模型扩展,我们作为 PyTorch/XLA 1.12 版本的一部分,为 XLA 设备实现了广泛采用的全分片数据并行 (FSDP) 算法。我们提供了一个 FSDP 接口,其高级设计与基于 CUDA 的 PyTorch FSDP 类相似,同时也处理了 XLA 中的一些限制(更多详情请参见下面的设计说明)。这个 FSDP 接口使我们能够轻松地在 TPU 上构建具有例如 100 亿以上参数的模型,并支持了许多研究探索。
在 PyTorch/XLA 中使用全分片数据并行 (FSDP)
我们提供了一个封装类 XlaFullyShardedDataParallel
,用于包装给定的 PyTorch 模型,以在其数据并行 worker 之间分片其参数。一个使用示例如下:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
使用 XlaFullyShardedDataParallel
包装 nn.Module
实例,可以在其上启用 ZeRO-2 算法,其中其梯度和优化器状态在整个训练过程中被分片。在其前向和反向传播过程中,首先从其相应的分片重建被包装模块的完整参数以进行计算。
可以使用嵌套 FSDP 包装来进一步节省内存。这允许模型在任何给定时间只存储一个单独层的完整参数。对于嵌套 FSDP,应该首先使用内部 FSDP 包装其单个子模块,然后使用外部 FSDP 包装基础模型。这允许模型在任何给定时间只存储一个单独层的完整参数。并且拥有一个外部包装确保处理任何剩余参数,这对应于 ZeRO-3 算法。嵌套 FSDP 包装可以应用于任何深度的子模块,并且可以有超过 2 层的嵌套。
模型和优化器的模型检查点保存和加载可以像以前一样,通过保存和加载它们的 .state_dict()
来完成。同时,每个训练进程应该保存其自己的分片模型参数和优化器状态的检查点文件,并在恢复时加载相应 rank 的检查点文件(无论是否使用 ZeRO-2 或 ZeRO-3,即是否嵌套包装)。提供了一个命令行工具和一个 Python 接口,用于将分片的模型检查点文件合并到一个完整/未分片的模型检查点文件中。
梯度检查点(也称为“激活检查点”或“重计算”)是另一种常见的模型扩展技术,可以与 FSDP 结合使用。我们提供 checkpoint_module
,一个包装函数,用于对给定的 nn.Module
实例应用梯度检查点(基于 torch_xla.utils.checkpoint.checkpoint
)。
下面的 MNIST 和 ImageNet 示例提供了(普通或嵌套)FSDP、模型检查点保存和合并以及梯度检查点的使用示例。
PyTorch/XLA 中 FSDP 的入门示例
使用 FSDP 训练 MNIST 和 ImageNet
MNIST 和 ImageNet 分类通常可以用作构建更复杂的深度学习模型的起点。我们在这两个数据集上提供了以下 FSDP 示例:
- MNIST:test/test_train_mp_mnist_fsdp_with_ckpt.py(它还演示了检查点保存和合并)
- ImageNet:test/test_train_mp_imagenet_fsdp.py
将它们与 MNIST 和 ImageNet 的普通数据并行示例进行比较,可以了解如何调整训练脚本以使用 FSDP。需要记住的一个主要区别是,在 FSDP 包装的模型上进行优化器步进时,应该直接调用 optimizer.step()
,而不是 xm.optimizer_step(optimizer)
。后者会跨 rank 缩减梯度,这在 FSDP 中不是我们需要的,因为在 FSDP 中梯度已经经过缩减和分片(通过其反向传播中的 reduce-scatter 操作)。
安装
FSDP 可从 PyTorch/XLA 1.12 及更新的 nightly 版本中获取。有关安装以及 Cloud TPU 分配的指南,请参阅https://github.com/pytorch/xla#-available-images-and-wheels。然后,在 TPU VM 上克隆 PyTorch/XLA 仓库,如下所示:
mkdir -p ~/pytorch && cd ~/pytorch
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上训练 MNIST
经过 2 个 epoch 后,准确率约为 98.9%
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp
上面的脚本会在结束时自动测试分片模型检查点的合并。您也可以通过以下方式手动合并分片检查点文件:
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet
经过 100 个 epoch 后,准确率约为 75.9%,这与不使用 FSDP 时获得的结果相同;将 ImageNet-1k 数据集下载并预处理到 /datasets/imagenet-1k
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
--lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 \
--num_epochs 100 \
--use_nested_fsdp
您也可以在这些示例中探索其他选项,例如 --use_gradient_checkpointing
以在 ResNet 块上应用梯度检查点(即激活检查点),或 --compute_dtype bfloat16
以使用 bfloat16 精度执行前向和反向传播。
大规模模型示例
在 TPU 上构建大型模型时,我们经常需要注意内存限制(例如,TPU v3 中每个核心 16 GB,TPU v4 中每个芯片 32 GB)。对于无法容纳在单个 TPU 内存或主机 CPU 内存中的大型模型,应使用嵌套 FSDP 来实现 ZeRO-3 算法,将子模块构建与内部 FSDP 包装交错进行,以便在构建过程中无需在内存中存储完整模型。
我们在 https://github.com/ronghanghu/ptxla_scaling_examples 中说明了这些情况,其中提供了在 TPU v3 Pod(具有 128 个核心)上训练具有 100 亿以上参数的 Vision Transformer (ViT) 模型以及其他案例的示例。
设计说明
有人可能会想,为什么我们需要在 PyTorch/XLA 中开发一个单独的 FSDP 类,而不是直接重用 PyTorch 的 FSDP 类或将其扩展到 XLA 后端。在 PyTorch/XLA 中使用单独 FSDP 类的主要原因是,原生的 PyTorch FSDP 类严重依赖于 XLA 设备不支持的 CUDA 特性,而 XLA 也有几个需要特殊处理的独特特性。这些差异需要一个不同的 FSDP 实现,在一个单独的类中构建起来会容易得多。
API 调用的变更
一个显著的区别是,原生的 PyTorch FSDP 构建在单独的 CUDA 流上,用于在 eager 模式下进行异步执行,而 PyTorch/XLA 在 lazy 模式下运行,并且不支持流。此外,TPU 要求所有设备同质地运行相同的程序。因此,在 PyTorch/XLA FSDP 实现中,CUDA 调用和每进程异构性需要被 XLA API 和替代的同质实现所取代。
Tensor 存储处理
另一个显著的区别是如何释放 tensor 的存储,这在 XLA 中比在 CUDA 中困难得多。为了实现 ZeRO-3,需要在模块的前向传播之后释放完整参数的存储,以便下一个模块可以重用该内存缓冲区进行后续计算。PyTorch 的 FSDP 通过 p.data.storage().resize_(0)
在 CUDA 上实现这一点,从而释放参数 p
的实际存储。然而,XLA tensor 没有这个 .storage()
句柄,因为 XLA HLO IR 是完全函数式的,并且不提供任何操作来释放 tensor 或调整其存储大小。在 PyTorch 接口之下,只有 XLA 编译器可以决定何时释放与 XLA tensor 对应的 TPU 设备内存,并且前提条件是该内存只有在该 tensor 对象在 Python 中被释放时才能释放——这在 FSDP 中不会发生,因为这些参数 tensor 被引用为模块属性,并且也被 PyTorch autograd 保存用于反向传播。
我们解决这个问题的方法是,将 tensor 的值属性与其 autograd Variable 属性分离,并通过将其 .data
属性设置为大小为 1 的 dummy 标量来释放 nn.Parameter
tensor。这样,完整参数的实际数据 tensor 在 Python 中会被解除引用,以便 XLA 可以回收其内存用于其他计算,而 autograd 仍然可以将基础 nn.Parameter
追踪为参数数据的弱引用。为了使这项工作成功,还需要处理对参数的视图,因为 PyTorch 中的视图也持有对其实际数据的引用(这需要修复 PyTorch/XLA 中与视图相关的形状问题)。
与 XLA 编译器协作
如果 XLA 编译器忠实地保留了我们的 PyTorch 程序中的操作及其执行顺序,那么上述解决方案应该足以释放完整参数。但还有一个问题——XLA 试图通过对 HLO IR 应用公共子表达式消除 (CSE) 来优化程序以加快执行速度。在 FSDP 的朴素实现中,XLA 编译器通常会在看到它是前向传播中的重复计算时,消除反向传播中的第二次 all-gather 以重建完整参数,并直接保留并重用我们在前向传播后想要释放的完整参数。为了防止这种不期望的编译器行为,我们在 PyTorch/XLA 中引入了优化屏障操作,并用它来阻止消除第二次 all-gather。这个优化屏障也应用于梯度检查点的类似情况,以防止前向和反向传播之间的 CSE 消除重计算。
未来,如果 CUDA 和 XLA 之间的差异不像上述那么突出,则可以考虑将 PyTorch/XLA FSDP 与原生的 PyTorch FSDP 合并,以实现统一接口。
致谢
感谢 AWS 的 Junmin Hao 审查 PyTorch/XLA FSDP 的拉取请求。感谢 Meta PyTorch 团队的 Brian Hirsh 对 PyTorch 核心问题的支持。感谢 Google 的 Isaack Karanja、Will Cromar 和 Blake Hechtman 对 GCP、XLA 和 TPU 问题的支持。
感谢 Meta FAIR 的 Piotr Dollar、Wan-Yen Lo、Alex Berg、Ryan Mark、Kaiming He、Xinlei Chen、Saining Xie、Shoubhik Debnath、Min Xu 和 Vaibhav Aggarwal 在各种 TPU 相关讨论中提供的帮助。