快捷方式

常见问题

作者: Mark Saroufim

torch.compile 是否支持训练?

torch.compile 支持训练,使用 AOTAutograd 来捕获反向传播

  1. .forward() 图和 optimizer.step() 由 TorchDynamo 的 python evalframe 前端捕获。

  2. 对于 torchdynamo 捕获的每个 .forward() 段,它使用 AOTAutograd 生成一个反向图段。

  3. 每对正向和反向图(可选)进行最小割分区,以在正向和反向之间保存最小状态。

  4. 正向和反向对都封装在 autograd.function 模块中。

  5. 调用 .backward() 的用户代码仍然会触发 eager 的 autograd 引擎,该引擎运行每个编译后的反向图,就像它是一个操作一样,同时还运行任何未编译的 eager 操作的 .backward() 函数。

你们是否支持分布式代码?

torch.compile 支持 DistributedDataParallel (DDP)。正在考虑支持其他分布式训练库。

Dynamo 在分布式代码方面具有挑战性的主要原因是 AOTAutograd 展开了正向和反向传播,并为后端提供了 2 个图以进行优化。这对于分布式代码来说是一个问题,因为我们理想情况下希望将通信操作与计算操作重叠。Eager PyTorch 以不同的方式为 DDP/FSDP 完成此操作 - 使用 autograd 钩子、模块钩子以及模块状态的修改/突变。在 Dynamo 的朴素应用中,应该在反向传播期间直接在操作之后运行的钩子可能会延迟到整个编译后的反向操作区域之后,这是由于 AOTAutograd 编译函数与调度器钩子交互的方式所致。

使用 Dynamo 优化 DDP 的基本策略在 distributed.py 中概述,其中主要思想是在 DDP 桶边界上进行图中断。

当 DDP 中的每个节点需要与其他节点同步其权重时,它会将其梯度和参数组织成桶,这减少了通信时间,并允许节点将其梯度的一部分广播到其他等待节点。

分布式代码中的图中断意味着您可以期望 Dynamo 及其后端优化分布式程序的计算开销,而不是其通信开销。如果减小的图大小剥夺了编译器的融合机会,则图中断可能会干扰编译加速。但是,随着图大小的增加,收益会递减,因为当前的大多数计算优化都是本地融合。因此,在实践中,这种方法可能就足够了。

我仍然需要导出整个图吗?

对于绝大多数模型,您可能不需要,您可以按原样使用 torch.compile(),但在少数情况下,完整图是必要的,您可以通过简单地运行 torch.compile(..., fullgraph=True) 来确保完整图。这些情况包括

  • 大规模训练运行,例如需要流水线并行和其他高级分片策略的 25 万美元以上的运行。

  • 推理优化器,例如 TensorRTAITemplate,它们比训练优化器更积极地依赖于融合。

  • 移动端训练或推理。

未来的工作将包括将通信操作跟踪到图中,协调这些操作与计算优化,以及优化通信操作。

为什么我的代码崩溃了?

如果您的代码在没有 torch.compile 的情况下运行良好,但在启用它后开始崩溃,那么最重要的第一步是弄清楚您的故障发生在堆栈的哪个部分。要解决此问题,请按照以下步骤操作,只有在前一步成功后才尝试下一步。

  1. torch.compile(..., backend="eager") 仅运行 TorchDynamo 正向图捕获,然后使用 PyTorch 运行捕获的图。如果此操作失败,则 TorchDynamo 存在问题。

  2. torch.compile(..., backend="aot_eager") 运行 TorchDynamo 以捕获正向图,然后运行 AOTAutograd 以跟踪反向图,而无需任何额外的后端编译器步骤。然后将使用 PyTorch eager 来运行正向和反向图。如果此操作失败,则 AOTAutograd 存在问题。

  3. torch.compile(..., backend="inductor") 运行 TorchDynamo 以捕获正向图,然后运行 AOTAutograd 以使用 TorchInductor 编译器跟踪反向图。如果此操作失败,则 TorchInductor 存在问题

为什么编译速度很慢?

  • Dynamo 编译– TorchDynamo 具有内置的统计函数,用于收集和显示每个编译阶段所花费的时间。可以通过在执行 torch._dynamo 后调用 torch._dynamo.utils.compile_times() 来访问这些统计信息。默认情况下,这将返回一个字符串表示,其中包含每个 TorchDynamo 函数按名称花费的编译时间。

  • Inductor 编译– TorchInductor 具有内置的统计和跟踪函数,用于显示每个编译阶段花费的时间、输出代码、输出图可视化和 IR 转储。env TORCH_COMPILE_DEBUG=1 python repro.py。这是一个调试工具,旨在使调试/理解 TorchInductor 的内部结构更容易,其输出将类似于 此链接。该调试跟踪中的每个文件都可以通过 torch._inductor.config.trace.* 启用/禁用。默认情况下,配置文件和图表均已禁用,因为它们生成成本很高。有关更多示例,请参见 示例调试目录输出

  • 过度重新编译 当 TorchDynamo 编译函数(或函数的一部分)时,它会对局部变量和全局变量做出某些假设,以便允许编译器优化,并将这些假设表示为保护,这些保护在运行时检查特定值。如果任何这些保护失败,Dynamo 将重新编译该函数(或部分函数),最多 torch._dynamo.config.cache_size_limit 次。如果您的程序达到了缓存限制,您首先需要确定哪个保护失败以及程序的哪一部分触发了它。重新编译分析器自动化了将 TorchDynamo 的缓存限制设置为 1 并在观察者模式的“编译器”下运行程序的过程,该“编译器”记录任何保护失败的原因。您应确保程序运行的时间至少与您遇到问题时运行的时间(迭代次数)一样长,并且分析器将在此期间累积统计信息。

为什么在生产环境中重新编译?

在某些情况下,您可能不希望在程序预热后出现意外编译。例如,如果您在延迟关键型应用程序中提供生产流量。为此,TorchDynamo 提供了一种备用模式,其中使用先前的编译图,但不生成新的图

frozen_toy_example = dynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))

你们是如何加速我的代码的?

加速 PyTorch 代码主要有 3 种方法

  1. 通过垂直融合进行内核融合,垂直融合将顺序操作融合在一起,以避免过多的读/写。例如,融合 2 个后续的余弦运算意味着您可以执行 1 次读取 1 次写入,而不是 2 次读取 2 次写入 2 次。水平融合:最简单的示例是批处理,其中单个矩阵与一批示例相乘,但更一般的情况是分组 GEMM,其中一组矩阵乘法被一起调度

  2. 乱序执行:编译器的通用优化,通过查看图中的精确数据依赖性,我们可以决定执行节点的最佳时机以及可以重用哪些缓冲区

  3. 自动工作放置:与乱序执行点类似,但是通过将图的节点与物理硬件或内存等资源匹配,我们可以设计合适的调度

以上是加速 PyTorch 代码的一般原则,但不同的后端将对优化内容做出不同的权衡。例如,Inductor 首先处理它可以融合的任何内容,然后才生成 Triton 内核。

Triton 除了提供加速之外,还因为它在每个流式多处理器中实现了自动内存合并、内存管理和调度,并且被设计为处理平铺计算。

但是,无论您使用哪个后端,最好都使用基准测试和查看方法,因此请尝试 PyTorch 分析器,直观地检查生成的内核,并尝试自己查看发生了什么。

为什么我没有看到加速?

图中断

您在使用 Dynamo 时看不到您期望的加速的主要原因是过多的图中断。那么什么是图中断呢?

给定一个程序,例如

def some_fun(x):
    ...

torch.compile(some_fun)(x)
...

TorchDynamo 将尝试将 some_fun() 中的所有 torch/tensor 操作编译成单个 FX 图,但它可能无法将所有内容捕获到一个图中。

一些图中断原因是 TorchDynamo 无法克服的,例如调用 PyTorch 以外的 C 扩展对于 TorchDynamo 是不可见的,并且可能会执行任意操作,而 TorchDynamo 无法引入必要的保护来确保编译后的程序可以安全地重用。

为了最大限度地提高性能,尽可能减少图中断非常重要。

识别图中断的原因

要识别程序中的所有图中断以及中断的相关原因,可以使用 torch._dynamo.explain。此工具在提供的函数上运行 TorchDynamo,并聚合遇到的图中断。以下是示例用法

import torch
import torch._dynamo as dynamo
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    print("woo")
    if b.sum() < 0:
        b = b * -1
    return x * b
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation)
"""
Graph Count: 3
Graph Break Count: 2
Op Count: 5
Break Reasons:
  Break Reason 1:
    Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
    User Stack:
      <FrameSummary file foo.py, line 5 in toy_example>
  Break Reason 2:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
Ops per Graph:
  ...
Out Guards:
  ...
"""

要在遇到的第一个图中断时抛出错误,您可以使用 fullgraph=True 禁用 python 回退,如果您使用过基于导出的编译器,则应该对此很熟悉。

def toy_example(a, b):
   ...

torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)

为什么当我更改代码时没有重新编译?

如果您通过设置 env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py 启用了动态形状,那么您的代码在形状更改时将不会重新编译。我们添加了对动态形状的支持,这避免了形状变化小于 2 倍时重新编译。这在 CV 中图像大小变化或 NLP 中可变序列长度等场景中尤其有用。在推理场景中,通常不可能预先知道批大小,因为您会从不同的客户端应用程序中获取尽可能多的数据。

总的来说,TorchDynamo 非常努力地避免不必要的重新编译,因此例如,如果 TorchDynamo 找到 3 个图,而您的更改仅修改了一个图,那么只有该图会被重新编译。因此,避免潜在的缓慢编译时间的另一个技巧是预热模型,方法是在编译一次后编译它,之后后续编译将快得多。冷启动编译时间仍然是我们显式跟踪的指标。

为什么我得到不正确的结果?

如果您设置环境变量 TORCHDYNAMO_REPRO_LEVEL=4,也可以最大限度地减少准确性问题,它以类似的 git bisect 模型运行,完整的 repro 可能类似于 TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4,我们需要这样做是因为下游编译器将生成代码,无论是 Triton 代码还是 C++ 后端,这些下游编译器的数值可能以细微的方式有所不同,但会对您的训练稳定性产生巨大影响。因此,准确性调试器对于我们检测代码生成或后端编译器中的错误非常有用。

如果您想确保 torch 和 triton 之间的随机数生成相同,则可以启用 torch._inductor.config.fallback_random = True

为什么我得到 OOM?

Dynamo 仍然是一个 alpha 产品,因此存在一些 OOM 源,如果您看到 OOM,请尝试按以下顺序禁用以下配置,然后在 GitHub 上打开一个 issue,以便我们解决根本问题 1. 如果您正在使用动态形状,请尝试禁用它们,我们默认情况下已禁用它们:env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py 2. 默认情况下,Inductor 中启用了带有 Triton 的 CUDA 图,但删除它们可能会缓解一些 OOM 问题:torch._inductor.config.triton.cudagraphs = False

torch.func 是否与 torch.compile 一起使用(用于 gradvmap 转换)?

torch.func 转换应用于使用 torch.compile 的函数确实有效

import torch

@torch.compile
def f(x):
    return torch.sin(x)

def g(x):
    return torch.grad(f)(x)

x = torch.randn(2, 3)
g(x)

在使用 torch.compile 处理的函数内部调用 torch.func 转换

使用 torch.compile 编译 torch.func.grad

import torch

def wrapper_fn(x):
    return torch.func.grad(lambda x: x.sin().sum())(x)

x = torch.randn(3, 3, 3)
grad_x = torch.compile(wrapper_fn)(x)

使用 torch.compile 编译 torch.vmap

import torch

def my_fn(x):
    return torch.vmap(lambda x: x.sum(1))(x)

x = torch.randn(3, 3, 3)
output = torch.compile(my_fn)(x)

编译除受支持函数以外的函数(逃生舱口)

对于其他转换,作为一种解决方法,请使用 torch._dynamo.allow_in_graph

allow_in_graph 是一个逃生舱口。如果您的代码不适用于内省 Python 字节码的 torch.compile,但您认为它可以通过符号跟踪方法(如 jax.jit)工作,则可以使用 allow_in_graph

通过使用 allow_in_graph 注释函数,您必须确保您的代码满足以下要求

  • 函数中的所有输出仅依赖于输入,而不依赖于任何捕获的张量。

  • 您的函数是函数式的。也就是说,它不会改变任何状态。这可能会放宽;我们实际上支持从外部看起来是函数式的函数:它们可能具有就地 PyTorch 操作,但可能不会改变全局状态或函数的输入。

  • 您的函数不会引发数据相关的错误。

import torch

@torch.compile
def f(x):
    return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

x = torch.randn(2, 3)
f(x)

一个常见的陷阱是使用 allow_in_graph 注释调用 nn.Module 的函数。这是因为输出现在依赖于 nn.Module 的参数。要使其工作,请使用 torch.func.functional_call 提取模块状态。

NumPy 是否与 torch.compile 一起使用?

从 2.1 开始,torch.compile 理解在 NumPy 数组上运行的本机 NumPy 程序,以及通过 x.numpy()torch.from_numpy 和相关函数从 PyTorch 转换为 NumPy 并返回的混合 PyTorch-NumPy 程序。

torch.compile 支持哪些 NumPy 功能?

torch.compile 中的 NumPy 遵循 NumPy 2.0 预发布版。

通常,torch.compile 能够跟踪大多数 NumPy 构造,当它不能跟踪时,它会回退到 eager 并让 NumPy 执行该段代码。即便如此,在某些功能中,torch.compile 语义与 NumPy 的语义略有偏差

  • NumPy 标量:我们将它们建模为 0-D 数组。也就是说,np.float32(3)torch.compile 下返回 0-D 数组。为了避免图中断,最好使用此 0-D 数组。如果这破坏了您的代码,您可以通过将 NumPy 标量强制转换为相关的 Python 标量类型 bool/int/float 来解决此问题。

  • 负步长:np.flip 和使用负步长的切片返回副本。

  • 类型提升:NumPy 的类型提升将在 NumPy 2.0 中更改。新规则在 NEP 50 中描述。torch.compile 实现了 NEP 50,而不是当前即将弃用的规则。

  • {tril,triu}_indices_from/{tril,triu}_indices 返回数组,而不是数组的元组。

还有其他一些功能我们不支持跟踪,我们会优雅地回退到 NumPy 以执行它们

  • 非数字 dtype,如 datetimes、strings、chars、void、structured dtypes 和 recarrays。

  • Long dtypes np.float128/np.complex256 和一些无符号 dtypes np.uint16/np.uint32/np.uint64

  • ndarray 子类。

  • 掩码数组。

  • 深奥的 ufunc 机制,如 axes=[(n,k),(k,m)->(n,m)] 和 ufunc 方法(例如,np.add.reduce)。

  • 排序/订购 complex64/complex128 数组。

  • NumPy np.poly1dnp.polynomial

  • 具有 2 个或更多返回值的函数中的位置 out1, out2 参数(out=tuple 有效)。

  • __array_function____array_interface____array_wrap__

  • ndarray.ctypes 属性。

我可以使用 torch.compile 编译 NumPy 代码吗?

当然可以!torch.compile 本机理解 NumPy 代码,并将其视为 PyTorch 代码。为此,只需使用 torch.compile 装饰器包装 NumPy 代码即可。

import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)

使用环境变量 TORCH_LOGS=output_code 执行此示例,我们可以看到 torch.compile 能够将乘法和求和融合到一个 C++ 内核中。它还能够使用 OpenMP 并行执行它们(本机 NumPy 是单线程的)。这可以轻松地使您的 NumPy 代码速度提高 n 倍,其中 n 是处理器中的内核数!

以这种方式跟踪 NumPy 代码也支持编译代码中的图中断。

我可以在 CUDA 上执行 NumPy 代码并通过 torch.compile 计算梯度吗?

是的,可以!要做到这一点,您只需在 torch.device("cuda") 上下文中执行您的代码。考虑以下示例

import torch
import numpy as np

@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)

在此示例中,numpy_fn 将在 CUDA 中执行。为了实现这一点,torch.compile 会自动将 XY 从 CPU 移动到 CUDA,然后将结果 Z 从 CUDA 移动到 CPU。如果我们在同一个程序运行中多次执行此函数,我们可能希望避免所有这些相当昂贵的内存复制。为此,我们只需要调整我们的 numpy_fn,使其接受 CUDA 张量并返回张量。我们可以通过使用 torch.compiler.wrap_numpy 来做到这一点

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"

在这里,我们显式地在 CUDA 内存中创建张量,并将它们传递给函数,该函数在 CUDA 设备上执行所有计算。wrap_numpy 负责在 torch.compile 级别将任何 torch.Tensor 输入标记为具有 np.ndarray 语义的输入。在编译器内部标记张量是一个非常廉价的操作,因此在运行时不会发生数据复制或数据移动。

使用此装饰器,我们还可以通过 NumPy 代码进行微分!

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.mean(np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1)))

X = torch.randn(1024, 64, device="cuda", requires_grad=True)
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
Z.backward()
# X.grad now holds the gradient of the computation
print(X.grad)

我们一直在使用 fullgraph=True,因为在这种情况下,图中断是有问题的。当发生图中断时,我们需要物化 NumPy 数组。由于 NumPy 数组没有 devicerequires_grad 的概念,因此这些信息在图中断期间会丢失。

我们无法通过图中断传播梯度,因为图中断代码可能会执行任意代码,而这些代码不知道如何微分。另一方面,在 CUDA 执行的情况下,我们可以像第一个示例中那样,通过使用 torch.device("cuda") 上下文管理器来解决此问题

@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    prod = X[:, :, None] * Y[:, None, :]
    print("oops, a graph break!")
    return np.sum(prod, axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")

with torch.device("cuda"):
    Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"

在图中断期间,中间张量仍然需要移动到 CPU,但是当跟踪在图中断后恢复时,图的其余部分仍然在 CUDA 上跟踪。鉴于这种 CUDA <> CPU 和 CPU <> CUDA 移动,图中断在 NumPy 上下文中相当昂贵,应该避免,但至少它们允许通过复杂的代码片段进行跟踪。

如何在 torch.compile 下调试 NumPy 代码?

考虑到现代编译器的复杂性以及它们引发的令人生畏的错误,调试 JIT 编译的代码是具有挑战性的。torch.compile 故障排除文档 包含一些关于如何解决此任务的技巧和窍门。

如果以上方法不足以查明问题的根源,我们仍然可以使用一些其他特定于 NumPy 的工具。我们可以通过禁用通过 NumPy 函数的跟踪来辨别错误是否完全在 PyTorch 代码中

from torch._dynamo import config
config.trace_numpy = False

如果错误存在于跟踪的 NumPy 代码中,我们可以通过导入 import torch._numpy as np,使用 PyTorch 作为后端,急切地(不使用 torch.compile)执行 NumPy 代码。这应该仅用于 调试目的,绝不能替代 PyTorch API,因为它 性能要差得多,并且作为私有 API,可能会在没有通知的情况下更改。无论如何,torch._numpy 是 NumPy 的 Python 实现,它使用 PyTorch 术语,并且在内部被 torch.compile 用于将 NumPy 代码转换为 Pytorch 代码。它相当容易阅读和修改,因此如果您发现任何错误,请随时提交 PR 修复它或直接打开一个 issue。

如果程序在使用导入 torch._numpy as np 时可以工作,则错误很可能在 TorchDynamo 中。如果是这种情况,请随时打开一个 issue 并提供 最小可复现示例

我对一些 NumPy 代码进行了 torch.compile,但没有看到任何加速。

最好的起点是 关于如何调试这些 torch.compile 问题的通用建议的教程

一些图中断可能是由于使用了不支持的功能而发生的。请参阅 torch.compile 支持哪些 NumPy 功能?。更一般地说,记住一些广泛使用的 NumPy 功能与编译器配合不佳是很有用的。例如,就地修改使得编译器内部的推理变得困难,并且通常会产生比其异地对应物更差的性能。因此,最好避免它们。对于使用 out= 参数也是如此。相反,最好使用异地操作,并让 torch.compile 优化内存使用。对于数据相关的操作(如通过布尔掩码进行掩码索引)或数据相关的控制流(如 ifwhile 结构)也是如此。

使用哪个 API 进行细粒度跟踪?

在某些情况下,您可能需要从 torch.compile 编译中排除代码的小部分。本节提供了一些答案,您可以在 TorchDynamo 用于细粒度跟踪的 API 中找到更多信息。

如何在函数上进行图中断?

仅在函数上进行图中断不足以充分表达您希望 PyTorch 执行的操作。您需要更具体地说明您的用例。您可能要考虑的一些最常见的用例

  • 如果您想在此函数帧和递归调用的帧上禁用编译,请使用 torch._dynamo.disable

  • 如果您希望特定运算符(例如 fbgemm)使用 eager 模式,请使用 torch._dynamo.disallow_in_graph

一些不常见的用例包括

  • 如果您想在函数帧上禁用 TorchDynamo,但在递归调用的帧上重新启用它 - 请使用 torch._dynamo.disable(recursive=False)

  • 如果您想防止函数帧的内联 - 请在您想要防止内联的函数开头使用 torch._dynamo.graph_break

torch._dynamo.disabletorch._dynamo.disallow_in_graph 之间有什么区别?

Disallow-in-graph 在运算符级别工作,更具体地说,是在您在 TorchDynamo 提取的图中看到的运算符级别工作。

Disable 在函数帧级别工作,并决定 TorchDynamo 是否应该查看函数帧。

torch._dynamo.disabletorch._dynamo_skip 之间有什么区别?

注意

torch._dynamo_skip 已弃用。

您最有可能需要 torch._dynamo.disable。但在不太可能的情况下,您可能需要更精细的控制。假设您只想禁用 a_fn 函数上的跟踪,但希望在 aa_fnab_fn 中继续跟踪。下图演示了此用例

diagram of torch.compile + disable(a_fn, recursive=False)

在这种情况下,您可以使用 torch._dynamo.disable(recursive=False)。在以前的版本中,此功能由 torch._dynamo.skip 提供。现在,torch._dynamo.disable 内的 recursive 标志支持此功能。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题的解答

查看资源