跳转到主要内容
博客

PyTorch 2.0 & XLA — 最新的尖端功能

今天,我们很高兴分享我们在 PyTorch/XLA 2.0 方面的新进展。PyTorch 2.0 的发布是这个历史悠久的社区又一个重要的里程碑,我们很高兴能继续成为其中的一部分。2018 年,当 Google 和 Meta 启动 PyTorch/XLA 项目时,重点是引入前沿的 Cloud TPU 以支持 PyTorch 社区。在此过程中,亚马逊等社区其他成员也加入了该项目,社区很快就扩大了。我们对 XLA 的发展方向以及该项目继续为 PyTorch 社区带来的好处感到兴奋。在本博客中,我们希望展示一些正在开发中的关键功能,展示代码片段,并通过一些基准测试来说明其益处。

TorchDynamo / torch.compile (实验性)

TorchDynamo (Dynamo) 是一个 Python 级别的 JIT 编译器,旨在加快未修改的 PyTorch 程序的运行速度。它为编译器后端提供了清晰的 API;其最大的特点是在执行前动态修改 Python 字节码。在 PyTorch/XLA 2.0 版本中,为推理和训练提供了 Dynamo 的实验性后端。

当 Dynamo 识别出模型模式时,它会提供一个 Torch FX (FX) 图,PyTorch/XLA 使用 Lazy Tensor 方法编译 FX 图并返回编译后的函数。要了解有关 PyTorch/XLA 的 Dynamo 实现技术细节的更多信息,请查看 这篇 dev-discuss 帖子和 Dynamo 文档

下面是使用 torch.compile 运行 ResNet18 的一个小型代码示例

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  dynamo_resnet18 = torch.compile(
      xla_resnet18, backend='torchxla_trace_once')
  for data, _ in loader:
    output = dynamo_resnet18(data)

使用 torch.compile,PyTorch/XLA 只在初始化时跟踪 ResNet18 模型一次,每次调用 dynamo_resnet18 时执行编译后的二进制文件,而不是每一步都跟踪模型。为了说明 Dynamo+XLA 的优势,下面是对使用 Cloud TPU v4-8 上的 TorchBench 比较 Dynamo 和 LazyTensor(不带 Dynamo)的推理加速分析,其中 y 轴是加速倍数。

Inference Speedup - PyTorch/XLA Dynamo on TPU

用于训练的 Dynamo 处于开发阶段,其实现比推理更早。欢迎开发人员测试此早期功能,但在 2.0 版本中,PyTorch/XLA 支持前向和后向传播图,但不支持优化器图;优化器图在每夜构建中可用,并将发布在 PyTorch/XLA 2.1 版本中。下面是使用 torch.compile 的 ResNet18 示例训练的示例

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  dynamo_train_model = torch.compile(
        train_model, backend='aot_torchxla_trace_once')
  for data, target in loader:
    output = dynamo_train_model(xla_resnet18, data, target)

请注意,训练的后端是 aot_torchxla_trace_once(API 将更新以用于稳定版本),而推理后端是 torchxla_trace_once(名称可能会更改)。如果您使用 Lazy Tensor,我们预计每个训练步骤提取并执行 3 个图,而不是 1 个训练步骤。下面是使用 Cloud TPU v4-8 上的 TorchBench 比较 Dynamo 和 Lazy 的训练加速分析。

Training Speedup - PyTorch/XLA Dynamo on TPU

PJRT 运行时(Beta)

PyTorch/XLA 正在从 XRT 迁移到新的 PJRT 运行时。PJRT 是一个维护更好的堆栈,具有已证实的性能优势,包括平均而言,在 TorchBench 2.0 模型上训练性能提高了 35%。它还支持更丰富的功能集,支持 SPMD 等技术。在 PyTorch/XLA 2.0 版本中,PJRT 是 TPU 和 CPU 的默认运行时;GPU 支持处于实验状态。PyTorch/XLA 2.0 版本中包含的 PJRT 功能是

  • 使用 PJRT 插件 APIlibtpu 中实现 TPU 运行时,性能提高了 30%
  • 支持 TPU v2 和 v3 的 torch.distributed,包括 pjrt:// init_method(实验性)
  • 单主机 GPU 支持。多主机支持即将推出。(实验性)

切换到 PJRT 不需要(或对 GPU 来说只需极少)更改用户代码(有关更多详细信息,请参阅 pjrt.md)。运行时配置就像将 PJRT_DEVICE 环境变量设置为本地设备类型(即 TPUGPUCPU)一样简单。下面是在不同设备上使用 PJRT 运行时的示例。

# TPU Device
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
# TPU Pod Device
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git"

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
# GPU Device (Experimental)
PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

下面是在 v4-8 TPU 上,按任务对 TorchBench 2.0 上的 XRT 和 PJRT 进行的性能比较。要了解有关 PJRT 与 XRT 的更多信息,请查阅文档

TorchBench Training Time

并行化

GSPMD(实验性)

我们很高兴在 PyTorch 中引入通用可扩展的 ML 计算图并行化 (GSPMD),作为一种新的实验性数据和模型分片解决方案。GSPMD 为常见的 ML 工作负载提供自动并行化,允许开发人员像在单个大型设备上一样编写 PyTorch 程序,而无需自定义分片计算操作和/或集体通信操作。XLA 编译器根据用户提供的分片提示,将单个设备程序转换为带有适当集合操作的分布式程序。API (RFC) 将在 PyTorch/XLA 2.0 版本中作为单个 TPU VM 主机上的实验性功能提供。

GSPMD 的下一步计划

GSPMD 在 2.0 版本中处于实验阶段。为了使其达到稳定状态,我们计划在后续版本中解决一些功能差距和已知问题,包括多主机支持、DTensor 集成、部分复制分片、异步数据加载和检查点。

FSDP(Beta)

PyTorch/XLA 在 1.12 版本中引入了完全分片数据并行 (FSDP) 的实验性支持。此功能是 PyTorch FSDP 的并行表示,XLA 和上游 CUDA 内核的设置方式存在细微差异。auto_wrap_policy 是一个新参数,使开发人员能够自动指定将分区规范传播到神经网络子模块的条件。auto_wrap_policys 可以简单地作为参数在用 FSDP 包装模型时传入。值得注意的两个 auto_wrap_policy 可调用函数是:size_based_auto_wrap_policytransformer_auto_wrap_policy

size_based_auto_wrap_policy 允许用户包装具有最小参数数量的子模块。下面的示例包装了至少有 10M 参数的模型子模块。

auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e7)

transformer_auto_wrap_policy 允许用户包装所有与特定层类型匹配的子模块。下面的示例包装了名为 torch.nn.Conv2d 的模型子模块。要了解更多信息,请查看 Ronghang Hu 的此 ResNet 示例

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

PyTorch/XLA FSDP 现已集成到 HuggingFace 训练器类 (PR) 中,允许用户在 PyTorch/XLA 上训练更大的模型(官方 Hugging Face 文档)。一个 16B 参数的 GPT2 模型在 Cloud TPU v4-64 上使用此 FSDP 配置进行训练,实现了 39% 的硬件利用率。

TPU 加速器 – 设备数量v4-64
GPT2 参数计数16B
使用 FSDP 封装的层数GPT2Block
TFLOPs / 芯片275
PFLOPs / 步50
硬件利用率39%

FSDP 与 GSPMD 的区别

FSDP 是一种数据并行技术,通过将模型参数、优化器状态和梯度全部进行分片来减少设备内存占用。请注意,实际计算仍局限于设备本地,并且需要为前向和后向传播都收集所有分片模型参数,因此得名“数据并行”。FSDP 是 PyTorch/XLA 中最新的功能之一,用于扩展大型模型训练。

另一方面,GSPMD 是一种通用并行化系统,支持各种类型的并行化,包括数据并行和模型并行。PyTorch/XLA 提供分片注解 API 和 XLAShardedTensor 抽象,因此用户可以在 PyTorch 程序中对任何张量进行分片规范注解。开发人员无需手动实现分片计算或注入集体通信操作即可使其正确运行。XLA 编译器会完成这项工作,以便每个计算都可以在多个设备上以分布式方式运行。

示例与初步结果

要了解 PyTorch/XLA 并行化分片 API,请访问我们的 RFC 并查看 示例代码 参考。下面是启用数据并行和模型并行的简单示例。

model = SimpleLinear().to(xm.xla_device())
# Sharding annotate the linear layer weights.
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
# Training loop
model.train()
for step, (data, target) in enumerate(loader):
  optimizer.zero_grad()
  data = data.to(xm.xla_device())
  target = target.to(xm.xla_device())
  # Sharding annotate input data, we can shard any input
  # dimensions. Sharidng the batch dimension enables 
  # data parallelism, sharding the feature dimension enables
  # spatial partitioning.
  xs.mark_sharding(data, mesh, partition_spec)
  ouput = model(data)
  loss = loss_fn(output, target)
  optimizer.step()
  xm.mark_step()

下图强调了 PyTorch/XLA FSDP 和 SPMD 在 Cloud TPU v4-8 上运行 ResNet50 时的内存效率优势。

Batch Size Scaling with Spatial Partitioning

总结...

我们很高兴能将这些功能带给 PyTorch 社区,而这仅仅是个开始。动态形状、对 OpenXLA 的更深层次支持以及许多其他领域都在开发中,我们计划发布更多博客来深入探讨细节。PyTorch/XLA 完全开源开发,我们邀请您通过在 GitHub 上提交问题、拉取请求和发送 RFC 来加入开发人员社区。您可以在各种 XLA 设备(包括 TPU 和 GPU)上试用 PyTorch/XLA。此处是入门指南。

再次祝贺 PyTorch 社区取得这一里程碑!

祝好,

Google PyTorch 团队