• 文档 >
  • Eager 模式 + Compile API
快捷方式

Eager 模式 + Compile API

本文档将介绍如何将 PyTorch/XLA 新的实验性 eager 模式与 compile API 结合使用。目标是使 PyTorch/XLA 的体验更贴近原生 PyTorch,并简化开发过程。

目前,PyTorch/XLA 默认在 LazyTensor 跟踪模式下运行。在以下代码中

import torch
import torch_xla
import torchvision

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

# model tracing
res = model(input)

# model execution, same as `xm.mark_step`
torch_xla.sync()

实际的模型编译和设备执行发生在调用 torch_xla.sync 时。这种方法存在多个缺点。

  1. 用户经常对框架何时进行跟踪以及何时执行感到困惑。

  2. 非核心模型代码(例如数据预处理)通常会产生一些小的待处理执行,这些执行会泄露到主计算图(step function)中并导致重复编译。整个计算图的重复编译通常开销很大。

  3. 调试重复编译何时/为何发生是很困难的。

为了缓解上述问题,我们希望引入使用 eager 和 compile 的新用户体验。

基本用法

import torch
import torch_xla
import torchvision

# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
model = torchvision.models.resnet18().to(device)

# Mark the function to be compiled
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)

# Compilation and execution happens right away.
res = compiled_model(input)

注意

  1. 当前,用户必须通过 torch_xla.experimental.eager_mode(True) 手动启用 eager 模式。

  2. 需要编译的代码区域应该用 torch_xla.compile 包裹起来。

torch_xla.compile 的实现实际上非常直接,它在进入目标函数时禁用 eager 模式并开始跟踪。当目标函数返回时,它会调用 torch_xla.sync() 并重新启用 eager 模式。与现有的 mark_step/sync 方法相比,使用 eager + compile API 可以获得相同的性能。

推理

torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")

建议在推理时使用 torch.compile 而不是 torch_xla.compile,以减少跟踪开销。

训练

torch_xla.experimental.eager_mode(True)

def step_fn(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.compile(step_fn)

在训练中,我们建议用户将 step_fn 重构出来,因为通常最好将模型的正向传播、反向传播和优化器一起编译。长期目标也是将 torch.compile 用于训练,但目前出于性能原因,我们建议用户使用 torch_xla.compile

基准测试

我在单个 v4-8 芯片上使用模拟数据运行了一个 2 层仅解码器模型的训练(类似于 llama2),共 300 步。下面是我观察到的数据。

模式 token/秒


跟踪模式 (基线) 147 Eager 模式 65 Eager + torch_xla compile 147

: Eager 模式基准测试结果

对于仅解码器模型,Eager 模式的性能可以达到完全编译模型的约 45%。更多信息请参阅 train_decoder_only_base.pyeager 示例。请注意,eager 模式的性能非常依赖于模型。当我尝试运行 resnet50 时,eager 模式的性能约为编译模式的 1%。我们不期望用户使用 eager 模式来执行主训练循环。Eager 模式旨在用于处理训练/推理逻辑的非核心部分(如数据预处理、随机数生成等)或用于调试。

文档

查阅全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源