跳转到主要内容
博客

PyTorch 2.0 & XLA — 最新的尖端功能

今天,我们很高兴分享我们在 PyTorch/XLA 2.0 方面的最新工作。PyTorch 2.0 的发布是这个历史悠久的社区的又一个重要里程碑,我们很高兴能继续成为其中的一部分。当 Google 和 Meta 于 2018 年启动 PyTorch/XLA 项目时,重点是引入尖端云 TPU,以帮助支持 PyTorch 社区。在此过程中,亚马逊等社区中的其他成员也加入了该项目,社区迅速扩大。我们对 XLA 的 发展方向 以及该项目继续为 PyTorch 社区带来的好处感到兴奋。在本博客中,我们希望展示一些正在开发中的关键功能,展示代码片段,并通过一些基准测试来说明其优势。

TorchDynamo / torch.compile (实验性)

TorchDynamo (Dynamo) 是一个 Python 级别的 JIT 编译器,旨在加速未经修改的 PyTorch 程序。它为编译器后端提供了清晰的 API 以进行挂钩;其最大的特点是在执行前动态修改 Python 字节码。在 PyTorch/XLA 2.0 版本中,为推理和训练提供了 Dynamo 的实验性后端。

当 Dynamo 识别出模型模式时,它会提供一个 Torch FX (FX) 图,PyTorch/XLA 使用 Lazy Tensor 方法编译 FX 图并返回编译后的函数。要更深入地了解 PyTorch/XLA Dynamo 实现的技术细节,请查看 这篇 dev-discuss 帖子和 Dynamo 文档

这是一个使用 torch.compile 运行 ResNet18 的小代码示例

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  dynamo_resnet18 = torch.compile(
      xla_resnet18, backend='torchxla_trace_once')
  for data, _ in loader:
    output = dynamo_resnet18(data)

使用 torch.compile,PyTorch/XLA 只在初始化时追踪 ResNet18 模型一次,并在每次调用 dynamo_resnet18 时执行编译后的二进制文件,而不是每一步都追踪模型。为了说明 Dynamo+XLA 的优势,下面是对云 TPU v4-8 上使用 TorchBench 比较 Dynamo 和 LazyTensor(没有 Dynamo)的推理加速分析,其中 Y 轴是加速倍数。

Inference Speedup - PyTorch/XLA Dynamo on TPU

用于训练的 Dynamo 处于开发阶段,其实现比推理阶段更早。欢迎开发人员测试此早期功能,但在 2.0 版本中,PyTorch/XLA 支持前向和后向传播图,而不支持优化器图;优化器图在每夜构建中可用,并将发布在 PyTorch/XLA 2.1 版本中。下面是使用 ResNet18 示例和 torch.compile 进行训练的示例

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  dynamo_train_model = torch.compile(
        train_model, backend='aot_torchxla_trace_once')
  for data, target in loader:
    output = dynamo_train_model(xla_resnet18, data, target)

请注意,训练的后端是 aot_torchxla_trace_once(API 将在稳定版本中更新),而推理后端是 torchxla_trace_once(名称可能会更改)。如果您使用 Lazy tensor,我们预计每个训练步骤将提取并执行 3 个图,而不是 1 个训练步骤。下面是在云 TPU v4-8 上使用 TorchBench 比较 Dynamo 和 Lazy 的训练加速分析。

Training Speedup - PyTorch/XLA Dynamo on TPU

PJRT 运行时 (Beta)

PyTorch/XLA 正在从 XRT 迁移到新的 PJRT 运行时。PJRT 是一个维护更好的堆栈,具有已证实的性能优势,包括在 TorchBench 2.0 模型上训练的平均性能提高 35%。它还支持更丰富的功能集,支持 SPMD 等技术。在 PyTorch/XLA 2.0 版本中,PJRT 是 TPU 和 CPU 的默认运行时;GPU 支持处于实验状态。PyTorch/XLA 2.0 版本中包含的 PJRT 功能有:

  • 使用 PJRT 插件 APIlibtpu 中实现 TPU 运行时,性能提高高达 30%。
  • torch.distributed 支持 TPU v2 和 v3,包括 pjrt:// init_method(实验性)。
  • 单主机 GPU 支持。多主机支持即将推出。(实验性)

切换到 PJRT 不需要(或对 GPU 来说只需极少)更改用户代码(有关更多详细信息,请参阅 pjrt.md)。运行时配置就像将 PJRT_DEVICE 环境变量设置为本地设备类型(即 TPUGPUCPU)一样简单。下面是在不同设备上使用 PJRT 运行时的示例。

# TPU Device
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
# TPU Pod Device
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git"

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
# GPU Device (Experimental)
PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

下面是 Cloud TPU v4-8 上 TorchBench 2.0 中 XRT 和 PJRT 按任务进行的性能比较。要了解有关 PJRT 与 XRT 的更多信息,请查看文档

TorchBench Training Time

并行化

GSPMD (实验性)

我们很高兴地在 PyTorch 中引入通用可扩展的 ML 计算图并行化 (GSPMD),作为一种新的实验性数据和模型分片解决方案。GSPMD 为常见的 ML 工作负载提供自动并行化,允许开发人员像在单个大型设备上一样编写 PyTorch 程序,无需自定义分片计算操作和/或集体通信操作。XLA 编译器根据用户提供的分片提示,将单个设备程序转换为带有所需集体操作的分区程序。API (RFC) 将在 PyTorch/XLA 2.0 版本中作为单个 TPU VM 主机上的实验性功能提供。

GSPMD 的下一步计划

GSPMD 在 2.0 版本中是实验性的。为了使其达到稳定状态,我们计划在后续版本中解决一些功能差距和已知问题,包括多主机支持、DTensor 集成、部分复制分片、异步数据加载和检查点。

FSDP (Beta)

PyTorch/XLA 在 1.12 版本中引入了完全分片数据并行 (FSDP) 实验性支持。此功能是 PyTorch FSDP 的并行表示,XLA 和上游 CUDA 内核的设置方式存在细微差别。auto_wrap_policy 是一个新参数,它使开发人员能够自动指定将分区规范传播到神经网络子模块的条件。auto_wrap_policy 可以简单地作为参数传递给 FSDP 封装模型时。值得注意的两个 auto_wrap_policy 可调用对象是:size_based_auto_wrap_policytransformer_auto_wrap_policy

size_based_auto_wrap_policy 允许用户封装具有最少参数数量的子模块。以下示例封装了至少具有 1000 万个参数的模型子模块。

auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e7)

transformer_auto_wrap_policy 允许用户封装所有与特定层类型匹配的子模块。下面的示例封装了名为 torch.nn.Conv2d 的模型子模块。要了解更多信息,请查看 Ronghang Hu 的 此 ResNet 示例

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

PyTorch/XLA FSDP 现已集成到 HuggingFace 训练器类 (PR) 中,使用户能够在 PyTorch/XLA 上训练更大的模型 (Hugging Face 官方文档)。使用此 FSDP 配置在 Cloud TPU v4-64 上训练的 16B 参数 GPT2 模型达到了 39% 的硬件利用率。

TPU 加速器 – 设备数量v4-64
GPT2 参数计数16B
用 FSDP 封装的层GPT2Block
TFLOPS / 芯片275
PFLOPS / 步50
硬件利用率39%

FSDP 和 GSPMD 之间的差异

FSDP 是一种数据并行技术,通过对模型参数、优化器状态和梯度进行分片来减少设备内存占用。请注意,实际计算仍局限于设备本地,并且需要为前向和后向传播聚合所有分片模型参数,因此得名“数据并行”。FSDP 是 PyTorch/XLA 中用于扩展大型模型训练的最新功能之一。

另一方面,GSPMD 是一种通用并行化系统,支持各种类型的并行化,包括数据并行和模型并行。PyTorch/XLA 提供了一个分片注释 API 和 XLAShardedTensor 抽象,因此用户可以在 PyTorch 程序中用分片规范注释任何张量。开发人员无需手动实现分片计算或注入集体通信操作即可使其正常工作。XLA 编译器会完成这项工作,以便每个计算都可以在多个设备上以分布式方式运行。

示例和初步结果

要了解 PyTorch/XLA 并行分片 API,请访问我们的 RFC 并查看 示例代码 参考。下面是一个启用数据和模型并行的简单示例。

model = SimpleLinear().to(xm.xla_device())
# Sharding annotate the linear layer weights.
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
# Training loop
model.train()
for step, (data, target) in enumerate(loader):
  optimizer.zero_grad()
  data = data.to(xm.xla_device())
  target = target.to(xm.xla_device())
  # Sharding annotate input data, we can shard any input
  # dimensions. Sharidng the batch dimension enables 
  # data parallelism, sharding the feature dimension enables
  # spatial partitioning.
  xs.mark_sharding(data, mesh, partition_spec)
  ouput = model(data)
  loss = loss_fn(output, target)
  optimizer.step()
  xm.mark_step()

下图突出显示了 PyTorch/XLA FSDP 和 SPMD 在运行 ResNet50 的 Cloud TPU v4-8 上的内存效率优势。

Batch Size Scaling with Spatial Partitioning

总结……

我们很高兴能将这些功能带给 PyTorch 社区,而这仅仅是个开始。动态形状、对 OpenXLA 的更深入支持以及许多其他领域正在开发中,我们计划发布更多博客来深入探讨细节。PyTorch/XLA 完全开源开发,我们邀请您通过在 GitHub 上提交问题、提交拉取请求和发送 RFC 来加入开发人员社区。您可以在各种 XLA 设备(包括 TPU 和 GPU)上试用 PyTorch/XLA。这里是入门方法。

再次祝贺 PyTorch 社区达到这个里程碑!

祝好,

Google 的 PyTorch 团队