• 文档 >
  • PyTorch XLA 中的 TorchDynamo(torch.compile) 集成
快捷方式

PyTorch XLA 中的 TorchDynamo(torch.compile) 集成

TorchDynamo 是一款 Python 级 JIT 编译器,旨在使未修改的 PyTorch 程序更快。它为编译器后端提供了简洁的 API,其最大特点是在执行 Python 字节码之前动态修改它。在 pytorch/xla 2.0 版本中,PyTorch/XLA 为 TorchDynamo 提供了一个实验性后端,用于推理和训练。

XLA 桥接器的工作方式是,Dynamo 在识别模型模式时会提供 TorchFX 图,PyTorch/XLA 将使用现有的 Lazy Tensor 技术编译 FX 图并返回已编译的函数。

集成

目前通过向 torch.compile 添加 backend='openxla' 参数来支持 PyTorch/XLA 和 Dynamo。例如

import torch
import torch_xla.core.xla_model as xm

def add(a, b):
  a_xla = a.to(xm.xla_device())
  b_xla = b.to(xm.xla_device())
  return a_xla + b_xla

compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))

推理

这是一个使用 torch.compile 运行 resnet18 的小型代码示例

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  dynamo_resnet18 = torch.compile(
    xla_resnet18, backend='openxla')
  for data, _ in loader:
    with torch.no_grad():
      output = dynamo_resnet18(data)

使用 torch.compile,您将看到 PyTorch/XLA 只在初始化时跟踪一次 resnet18 模型,并在每次调用 dynamo_resnet18 时执行编译后的二进制文件,而不是每次都跟踪模型。以下是在 Cloud TPU v4-8 上使用 torch bench 比较 Dynamo 和 Lazy 的推理速度分析

resnet18 | 2.59 resnet50 | 2.64 resnext50_32x4d | 1.91 alexnet | 1.28 mobilenet_v2 | 18.62 mnasnet1_0 | 2.68 vgg16 | 1.33 BERT_pytorch | 7.49 squeezenet1_1 | 2.29 timm_vision_transformer | 3.52 几何平均 | 3.04

训练

PyTorch/XLA 还支持用于训练的 Dynamo,但它处于实验阶段,我们正在与 PyTorch Compiler 团队合作迭代实现。以下是如何使用 torch.compile 训练 resnet18 的示例

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target, optimizer):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  optimizer.step()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  dynamo_train_model = torch.compile(
        train_model, backend='openxla')
  for data, target in loader:
    xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
    output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)

我们预计在每个训练步骤中提取和执行 3 个图,而不是使用 Lazy 张量时每个训练步骤执行 1 个图。以下是在 Cloud TPU v4-8 上使用 torch bench 比较 Dynamo 和 Lazy 的训练速度分析。

resnet50 | 1.33 resnet18 | 1.33 BERT_pytorch | 3.07 resnext50_32x4d | 1.43 alexnet | 1.12 mobilenet_v2 | 1.4 mnasnet1_0 | 1.19 vgg16 | 0.81 timm_vision_transformer | 1.87 squeezenet1_1 | 1.41 几何平均 | 1.41

**注意:**我们为每个模型的前向和后向运行单个步骤,然后收集端到端时间。在现实世界中,我们将在每个训练作业中运行多个步骤,这可以很容易地隐藏执行中的跟踪成本(因为它是异步的)。在这种情况,Lazy Tensor 将具有更好的性能。

功能差距

我们想指出一个差距,它阻止我们在更大规模的模型上使用 TorchDynamo。

  1. TorchDynamo 将前向和后向跟踪到不同的图中。对于 PyTorch/XLA,让 XLA 编译器将整个步骤作为一个图看到非常重要,以便最佳优化速度。每个设备执行还存在固定的开销,这使得每个训练步骤执行多个图变得不理想。

与 Lazy Tensor 相比,这种差距在实际训练用例中效率较低,特别是跟踪成本可以在训练中与执行重叠。

结论

TorchDynamo 为编译器后端提供了一种非常有前景的方式,可以隐藏用户端的复杂性,并轻松地以图形格式检索建模代码。与 PyTorch/XLA 传统提取图形的 Lazy Tensor 方式相比,TorchDynamo 可以跳过每次迭代的图形跟踪,从而提供更快的推理响应时间。

PyTorch/XLA 支持的大多数模型在使用新的 dynamo-xla 桥接器运行推理时都看到了显著的提速。我们的社区正在努力扩展支持的模型集。关于上面提到的训练功能差距,PyTorch/XLA 社区非常高兴能在我们即将开展的开发工作中缩小训练差距。该团队将继续大力投资 TorchDynamo,并与上游合作以完善训练故事。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源