• 文档 >
  • 急切模式 + 编译 API
快捷方式

急切模式 + 编译 API

在本文档中,我们将介绍如何使用 PyTorch/XLA 的新的实验性 eager 模式和 compile API。目标是使 PyTorch/XLA 体验与原生 PyTorch 更加一致,并简化开发过程。

背景

目前 PyTorch/XLA 默认在 LazyTensor 追踪模式下运行。在以下代码中

import torch
import torch_xla
import torchvision

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# model tracing
res = model(input)

# model execution, same as `xm.mark_step`
torch_xla.sync()

实际的模型编译和设备执行发生在调用 torch_xla.sync 时。这种方法存在多个缺点。

  1. 用户经常对框架何时进行追踪以及何时进行执行感到困惑。

  2. 非核心模型代码(例如数据预处理)通常会生成一些小的待执行操作,这些操作会被泄漏到主图(step 函数)中并导致重新编译。整个图的重新编译通常非常昂贵。

  3. 很难调试何时/为何发生重新编译。

为了缓解上述问题,我们希望引入使用急切模式和编译的新 UX。

基本用法

import torch
import torch_xla
import torchvision

# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

# Mark the function to be compiled
compiled_model = torch_xla.experimental.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)

# Compilation and execution happens right away.
res = compiled_model(input)

请注意

  1. 目前用户必须通过 torch_xla.experimental.eager_mode(True) 手动启用急切模式。

  2. 希望编译的代码区域应使用 torch_xla.experimental.compile 进行包装。

torch_xla.experimental.compile 的实现实际上非常简单,它在进入目标函数时禁用急切模式并开始追踪。当目标函数返回时,它将调用 torch_xla.sync() 并重新启用急切模式。与现有的 mark_step/sync 方法相比,您可以期望使用 eager + compile API 获得相同的性能。

推理

torch_xla.experimental.eager_mode(True)

compiled_model = torch.compile(model, backend="openxla")

建议在推理中使用 torch.compile 而不是 torch_xla.experimental.compile 来减少追踪开销。

训练

torch_xla.experimental.eager_mode(True)

def step_fn(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.experimental.compile(step_fn)

在训练中,我们要求用户将 step_fn 重构出来,因为通常最好将模型的前向、后向和优化器一起编译。长远目标也是在训练中使用 torch.compile,但目前我们建议用户使用 torch_xla.experimental.compile(出于性能原因)。

基准测试

我在 v4-8 的单个芯片上使用伪数据对一个 2 层解码器模型进行了 300 步的训练(它基本上就是一个 llama2)。以下是我的观察结果。

token/s
追踪模式(基线) 147
急切模式 65
急切模式 + torch_xla 编译 147

对于仅解码器模型,急切模式可以达到完全编译模型性能的约 45%。我用来测试的训练器可以在 此处此处 找到。请注意,急切模式的性能与模型密切相关。当我尝试运行 resnet50 时,急切模式的性能约为编译模式的 1%。我们不希望用户使用急切模式来执行主要的训练循环。急切模式旨在用于处理训练/推理逻辑的非核心部分(数据预处理、随机数生成等)或进行调试。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源