常见问题解答¶
作者: Mark Saroufim
torch.compile
支持训练吗?¶
torch.compile
支持训练,使用 AOTAutograd 来捕获反向传播
The
.forward()
图和optimizer.step()
由 TorchDynamo 的 pythonevalframe
前端捕获。对于 TorchDynamo 捕获的每个
.forward()
段,它使用 AOTAutograd 生成一个反向传播图段。每对正向和反向图(可选)被最小割分区以保存正向和反向之间的最小状态。
正向和反向对被包装在
autograd.function
模块中。用户代码调用
.backward()
仍然会触发 eager 的 autograd 引擎,它会将每个编译后的反向传播图作为单个操作运行,还会运行任何未编译的 eager 操作的.backward()
函数。
您支持分布式代码吗?¶
torch.compile
支持 DistributedDataParallel
(DDP)。正在考虑对其他分布式训练库的支持。
Dynamo 分布式代码面临的主要挑战在于 AOTAutograd 会展开前向和反向传播,并为后端提供两个图进行优化。这对分布式代码来说是个问题,因为我们希望理想情况下将通信操作与计算重叠。Eager PyTorch 通过不同的方式(使用 autograd 钩子、模块钩子和模块状态的修改/变异)在 DDP/FSDP 中实现这一点。在 Dynamo 的简单应用中,应该在反向传播期间操作之后立即运行的钩子可能会延迟到整个编译的反向传播操作区域之后,这是由于 AOTAutograd 编译函数与调度程序钩子的交互方式造成的。
在 distributed.py 中概述了使用 Dynamo 优化 DDP 的基本策略,其主要思想是在 DDP 桶边界 上进行图断裂。
当 DDP 中的每个节点需要与其它的节点同步权重时,它会将它的梯度和参数组织成桶,这减少了通信时间,并允许一个节点将它的一部分梯度广播到其他等待的节点。
分布式代码中的图断裂意味着你可以预期 Dynamo 及其后端会优化分布式程序的计算开销,但不会优化它的通信开销。如果减少的图大小剥夺了编译器融合的机会,图断裂可能会干扰编译加速。然而,随着图大小的增加,收益递减,因为大多数当前的计算优化都是局部融合。因此,在实践中,这种方法可能就足够了。
我是否仍然需要导出整个图?¶
对于绝大多数模型,你可能不需要,你可以像往常一样使用 torch.compile()
,但有一些情况需要完整的图,你可以通过简单地运行 torch.compile(..., nopython=True)
来确保完整的图。这些情况包括
大规模训练运行,例如需要管道并行和其他高级分片策略的 $250K+。
像 TensorRT 或 AITemplate 这样的推理优化器,它们比训练优化器更积极地进行融合。
移动训练或推理。
未来的工作将包括将通信操作跟踪到图中,协调这些操作与计算优化,并优化通信操作。
为什么我的代码崩溃了?¶
如果您的代码在没有 torch.compile
的情况下运行良好,但在启用它后开始崩溃,那么最重要的第一步是确定故障发生在堆栈的哪个部分。要排除故障,请按照以下步骤操作,只有在上一步骤成功后才尝试下一步。
torch.compile(..., backend="eager")
仅运行 TorchDynamo 正向图捕获,然后使用 PyTorch 运行捕获的图。如果这失败了,则说明 TorchDynamo 存在问题。torch.compile(..., backend="aot_eager")
运行 TorchDynamo 捕获正向图,然后运行 AOTAutograd 跟踪反向图,而无需任何额外的后端编译步骤。然后,PyTorch eager 将用于运行正向和反向图。如果这失败了,则说明 AOTAutograd 存在问题。torch.compile(..., backend="inductor")
运行 TorchDynamo 捕获正向图,然后运行 AOTAutograd 跟踪反向图,并使用 TorchInductor 编译器。如果这失败了,则说明 TorchInductor 存在问题。
为什么编译速度很慢?¶
Dynamo 编译– TorchDynamo 具有一个内置的统计函数,用于收集和显示每个编译阶段花费的时间。这些统计信息可以通过在执行
torch._dynamo
后调用torch._dynamo.utils.compile_times()
来访问。默认情况下,这将返回每个 TorchDynamo 函数按名称花费的编译时间的字符串表示形式。电感编译 - TorchInductor 具有内置的统计和跟踪功能,用于显示每个编译阶段花费的时间、输出代码、输出图形可视化和 IR 转储。
env TORCH_COMPILE_DEBUG=1 python repro.py
。这是一个调试工具,旨在简化 TorchInductor 内部机制的调试/理解,其输出类似于 this。该调试跟踪中的每个文件都可以通过torch._inductor.config.trace.*
启用/禁用。配置文件和图表默认情况下都处于禁用状态,因为它们生成成本很高。有关更多示例,请参见 example debug directory output。过度重新编译 当 TorchDynamo 编译函数(或其一部分)时,它会对局部变量和全局变量做出某些假设,以允许编译器优化,并将这些假设表示为在运行时检查特定值的保护。如果这些保护中的任何一个失败,Dynamo 将重新编译该函数(或部分)最多
torch._dynamo.config.cache_size_limit
次。如果您的程序达到缓存限制,您首先需要确定哪个保护失败以及您的程序的哪一部分触发了它。 recompilation profiler 自动执行将 TorchDynamo 的缓存限制设置为 1 并使用仅观察的“编译器”运行您的程序的过程,该编译器记录任何保护失败的原因。您应该确保运行您的程序至少与您遇到问题时运行的时间(迭代次数)一样长,并且分析器将在此期间累积统计信息。
from torch._dynamo.utils import CompileProfiler
def my_model():
...
with CompileProfiler() as prof:
profiler_model = torch.compile(my_model, backend=prof)
profiler_model()
print(prof.report())
为什么您在生产中重新编译?¶
在某些情况下,您可能不希望在程序预热后出现意外的编译。例如,如果您在延迟敏感的应用程序中提供生产流量。为此,TorchDynamo 提供了一种备用模式,其中使用先前编译的图形,但不会生成新的图形。
frozen_toy_example = dynamo.run(toy_example)
frozen_toy_example(torch.randn(10), torch.randn(10))
您如何加速我的代码?¶
有 3 种主要方法可以加速 PyTorch 代码
通过垂直融合进行内核融合,将顺序操作融合在一起以避免过度读写。例如,融合 2 个后续的余弦函数意味着您可以执行 1 次读 1 次写,而不是 2 次读 2 次写 2 次。水平融合:最简单的例子是批处理,其中单个矩阵与一批示例相乘,但更一般的情况是分组 GEMM,其中一组矩阵乘法被一起调度。
乱序执行:编译器的一种通用优化,通过提前查看图中的确切数据依赖关系,我们可以决定执行节点的最合适时间以及哪些缓冲区可以重用
自动工作放置:类似于乱序执行点,但通过将图的节点与物理硬件或内存等资源匹配,我们可以设计一个合适的调度方案
以上是加速 PyTorch 代码的通用原则,但不同的后端会在优化方面做出不同的权衡。例如,Inductor 首先会尽可能地进行融合,然后才会生成 Triton 内核。它还可以
Triton 另外还提供了加速功能,因为它可以自动进行内存合并、内存管理和每个流式多处理器内的调度,并且被设计用于处理平铺计算。
但是,无论您使用哪种后端,最好使用基准测试和方法,因此请尝试使用 PyTorch 分析器,直观地检查生成的内核,并尝试自己了解发生了什么。
为什么我没有看到加速?¶
图中断¶
您没有看到使用 dynamo 想要得到的加速的主要原因是图中断过多。那么什么是图中断呢?
给定一个程序,例如
def some_fun(x):
...
torch.compile(some_fun)(x)
...
Torchdynamo 将尝试将 some_fun()
中的所有 torch/tensor 操作编译成一个单独的 FX 图,但它可能无法将所有内容捕获到一个图中。
一些图中断的原因是 TorchDynamo 无法克服的,例如调用非 PyTorch 的 C 扩展对 TorchDynamo 是不可见的,并且可以执行任意操作,而 TorchDynamo 无法引入必要的保护措施来确保编译后的程序可以安全地重用。
为了最大限度地提高性能,重要的是要尽可能减少图中断。
识别图中断的原因¶
要识别程序中的所有图中断以及中断的相关原因,可以使用 torch._dynamo.explain
。此工具在提供的函数上运行 TorchDynamo,并汇总遇到的图中断。以下是一个示例用法
import torch
import torch._dynamo as dynamo
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
print("woo")
if b.sum() < 0:
b = b * -1
return x * b
explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
print(explanation)
"""
Dynamo produced 3 graphs, with 2 graph break and 6 ops.
Break reasons:
1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
File "t2.py", line 16, in toy_example
print("woo")
2. generic_jump
File "t2.py", line 17, in toy_example
if b.sum() < 0:
"""
要对遇到的第一个图中断抛出错误,您可以使用 nopython=True
禁用 Python 回退,如果您使用过基于导出的编译器,应该很熟悉这一点。
def toy_example(a, b):
...
torch.compile(toy_example, fullgraph=True, backend=<compiler>)
为什么我的代码在我更改它时没有重新编译?¶
如果您通过设置 env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py
来启用动态形状,那么您的代码在形状发生变化时将不会重新编译。我们添加了对动态形状的支持,这在形状变化幅度小于 2 倍的情况下可以避免重新编译。这在 CV 中图像大小变化或 NLP 中序列长度可变等场景中特别有用。在推理场景中,通常无法事先知道批次大小,因为您会从不同的客户端应用程序中获取尽可能多的数据。
一般来说,TorchDynamo 会尽力避免不必要的重新编译,例如,如果 TorchDynamo 找到 3 个图,而您的更改只修改了一个图,那么只有该图会重新编译。因此,另一个避免潜在的缓慢编译时间的技巧是通过编译一次来预热模型,之后后续的编译将快得多。冷启动编译时间仍然是我们跟踪的可见指标。
为什么我得到不正确的结果?¶
如果您设置环境变量 TORCHDYNAMO_REPRO_LEVEL=4
,也可以最小化精度问题,它使用类似于 git bisect 的模型,完整的重现可能类似于 TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4
,我们需要这样做是因为下游编译器会生成代码,无论是 Triton 代码还是 C++ 后端,这些下游编译器的数值可能在细微之处有所不同,但会对您的训练稳定性产生重大影响。因此,精度调试器对我们检测代码生成或后端编译器中的错误非常有用。
如果您想确保随机数生成在 torch 和 triton 中相同,您可以启用 torch._inductor.config.fallback_random = True
为什么我遇到 OOM 错误?¶
Dynamo 仍然是一个 alpha 产品,因此存在一些 OOM 错误的来源,如果您遇到 OOM 错误,请尝试按照以下顺序禁用以下配置,然后在 GitHub 上创建一个问题,以便我们解决根本问题 1. 如果您使用动态形状,请尝试禁用它们,我们默认情况下已禁用它们:env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py
2. CUDA 图与 Triton 在 inductor 中默认启用,但移除它们可能会缓解一些 OOM 问题:torch._inductor.config.triton.cudagraphs = False
。
torch.func
是否与 torch.compile
一起使用(用于 grad
和 vmap
变换)?¶
将 torch.func
变换应用于使用 torch.compile
的函数不起作用
import torch
@torch.compile
def f(x):
return torch.sin(x)
def g(x):
return torch.grad(f)(x)
x = torch.randn(2, 3)
g(x)
此代码将无法正常工作。有一个 问题,您可以跟踪此问题。
作为解决方法,请在 torch.func
函数之外使用 torch.compile
注意
这是一个实验性功能,可以通过设置 torch._dynamo.config.capture_func_transforms=True
来使用
import torch
torch._dynamo.config.capture_func_transforms=True
def f(x):
return torch.sin(x)
@torch.compile
def g(x):
return torch.vmap(f)(x)
x = torch.randn(2, 3)
g(x)
在使用 torch.compile
处理的函数内部调用 torch.func
变换¶
使用 torch.compile
编译 torch.func.grad
¶
import torch
torch._dynamo.config.capture_func_transforms=True
def wrapper_fn(x):
return torch.func.grad(lambda x: x.sin().sum())(x)
x = torch.randn(3, 3, 3)
grad_x = torch.compile(wrapper_fn)(x)
使用 torch.compile
编译 torch.vmap
¶
import torch
torch._dynamo.config.capture_func_transforms=True
def my_fn(x):
return torch.vmap(lambda x: x.sum(1))(x)
x = torch.randn(3, 3, 3)
output = torch.compile(my_fn)(x)
限制¶
目前有一些情况不受支持,会导致图断裂(也就是说,torch.compile 在这些情况下会回退到急切模式的 PyTorch)。我们正在努力改进下一版本(PyTorch 2.2)的情况
1. 被变换的函数的输入和输出必须是张量。我们目前还不支持诸如张量元组之类的东西。
import torch
torch._dynamo.config.capture_func_transforms=True
def fn(x):
x1, x2 = x
return x1 + x2
def my_fn(x):
return torch.func.vmap(fn)(x)
x1 = torch.randn(3, 3, 3)
x2 = torch.randn(3, 3, 3)
# Unsupported, falls back to eager-mode PyTorch
output = torch.compile(my_fn)((x1, x2))
不支持关键字参数。
import torch
torch._dynamo.config.capture_func_transforms=True
def fn(x, y):
return (x + y).sum()
def my_fn(x, y):
return torch.func.grad(fn)(x, y=y)
x = torch.randn(3, 3)
y = torch.randn(3, 3)
# Unsupported, falls back to eager-mode PyTorch
output = torch.compile(my_fn)(x, y)
3. 具有可观察副作用的函数。例如,修改在函数中创建的列表是可以的,但修改在函数外部创建的列表是不可以的。
import torch
torch._dynamo.config.capture_func_transforms=True
some_list = []
def f(x, y):
some_list.append(1)
return x + y
def my_fn(x, y):
return torch.func.vmap(f)(x, y)
x = torch.ones(2, 3)
y = torch.randn(2, 3)
# Unsupported, falls back to eager-mode PyTorch
output = torch.compile(my_fn)(x, y)
torch.vmap
在调用以下列表中的一个或多个运算符的函数上。
注意
‘stride’,‘requires_grad’,‘storage_offset’,‘layout’,‘data’,‘is_coalesced’,‘is_complex’,‘is_conj’,‘is_contiguous’,‘is_cpu’,‘is_cuda’,‘is_distributed’,‘is_floating_point’,‘is_inference’,‘is_ipu’,‘is_leaf’,‘is_meta’,‘is_mkldnn’,‘is_mps’,‘is_neg’,‘is_nested’,‘is_nonzero’,‘is_ort’,‘is_pinned’,‘is_quantized’,‘is_same_size’,‘is_set_to’,‘is_shared’,‘is_signed’,‘is_sparse’,‘is_sparse_csr’,‘is_vulkan’,‘is_xla’,‘is_xpu’
import torch
torch._dynamo.config.capture_func_transforms=True
def bad_fn(x):
x.stride()
return x
def my_fn(x):
return torch.func.vmap(bad_fn)(x)
x = torch.randn(3, 3, 3)
# Unsupported, falls back to eager-mode PyTorch
output = torch.compile(my_fn)(x)
编译除支持的函数之外的函数(逃生舱)¶
对于其他转换,作为解决方法,请使用 torch._dynamo.allow_in_graph
allow_in_graph
是一个逃生舱。如果您的代码无法与 torch.compile
(它会内省 Python 字节码)一起使用,但您认为它可以通过符号跟踪方法(如 jax.jit
)工作,那么请使用 allow_in_graph
。
通过使用 allow_in_graph
来注释函数,您必须确保您的代码满足以下要求
函数中的所有输出仅取决于输入,而不依赖于任何捕获的张量。
您的函数是函数式的。也就是说,它不会改变任何状态。这可能会放宽;我们实际上支持从外部看起来是函数式的函数:它们可能具有就地 PyTorch 操作,但可能不会改变全局状态或函数的输入。
您的函数不会引发数据相关的错误。
import torch
@torch.compile
def f(x):
return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)
x = torch.randn(2, 3)
f(x)
一个常见的陷阱是使用 allow_in_graph
来注释调用 nn.Module
的函数。这是因为输出现在取决于 nn.Module
的参数。要使此方法起作用,请使用 torch.func.functional_call
来提取模块状态。
NumPy 是否与 torch.compile
一起使用?¶
从 2.1 开始,torch.compile
了解在 NumPy 数组上运行的原生 NumPy 程序,以及通过 x.numpy()
、torch.from_numpy
和相关函数在 PyTorch 和 NumPy 之间进行转换的混合 PyTorch-NumPy 程序。
torch.compile
支持哪些 NumPy 功能?¶
torch.compile
中的 NumPy 遵循 NumPy 2.0 预发布版。
通常,torch.compile
能够跟踪大多数 NumPy 构造,当它无法跟踪时,它会回退到急切模式并让 NumPy 执行该部分代码。即使这样,也有一些功能,torch.compile
的语义略微偏离了 NumPy 的语义
NumPy 标量:我们将其建模为 0 维数组。也就是说,
np.float32(3)
在torch.compile
下返回一个 0 维数组。为了避免图中断,最好使用这个 0 维数组。如果这破坏了你的代码,你可以通过将 NumPy 标量转换为相关的 Python 标量类型bool/int/float
来解决这个问题。负步长:
np.flip
和使用负步长的切片会返回一个副本。类型提升:NumPy 的类型提升将在 NumPy 2.0 中发生变化。新规则在 NEP 50 中描述。
torch.compile
实现 NEP 50 而不是当前即将弃用的规则。{tril,triu}_indices_from/{tril,triu}_indices
返回数组而不是数组元组。
还有其他一些功能,我们不支持跟踪,我们会优雅地回退到 NumPy 来执行它们。
非数值数据类型,如日期时间、字符串、字符、空值、结构化数据类型和记录数组。
长数据类型
np.float128/np.complex256
和一些无符号数据类型np.uint16/np.uint32/np.uint64
。ndarray
子类。掩码数组。
深奥的 ufunc 机制,如
axes=[(n,k),(k,m)->(n,m)]
和 ufunc 方法(例如,np.add.reduce
)。排序/排列
complex64/complex128
数组。NumPy
np.poly1d
和np.polynomial
。具有 2 个或更多返回值的函数中的位置
out1, out2
参数(out=tuple
是有效的)。__array_function__
、__array_interface__
和__array_wrap__
。ndarray.ctypes
属性。
我可以用 torch.compile
编译 NumPy 代码吗?¶
当然可以!torch.compile
本地理解 NumPy 代码,并将其视为 PyTorch 代码。为此,只需用 torch.compile
装饰器包装 NumPy 代码即可。
import torch
import numpy as np
@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)
如果在执行此示例时设置环境变量 TORCH_LOGS=output_code
,我们可以看到 torch.compile
能够将乘法和加法融合成一个 C++ 内核。它还能够使用 OpenMP 并行执行它们(原生 NumPy 是单线程的)。这可以轻松地使您的 NumPy 代码速度提高 n
倍,其中 n
是您处理器中的核心数量!
以这种方式跟踪 NumPy 代码还支持编译代码中的图中断。
我是否可以通过 torch.compile
在 CUDA 上执行 NumPy 代码并计算梯度?¶
是的,您可以!为此,您只需在 torch.device("cuda")
上下文中执行您的代码即可。考虑以下示例
import torch
import numpy as np
@torch.compile
def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
with torch.device("cuda"):
Z = numpy_fn(X, Y)
assert isinstance(Z, np.ndarray)
在这个例子中,numpy_fn
将在 CUDA 上执行。为了实现这一点,torch.compile
会自动将 X
和 Y
从 CPU 移动到 CUDA,然后将结果 Z
从 CUDA 移动到 CPU。如果我们在同一个程序运行中多次执行此函数,我们可能希望避免所有这些相当昂贵的内存复制。为此,我们只需要调整我们的 numpy_fn
,使其接受 CUDA 张量并返回张量。我们可以使用 torch.compiler.wrap_numpy
来做到这一点
@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
在这里,我们显式地在 CUDA 内存中创建张量,并将它们传递给函数,该函数在 CUDA 设备上执行所有计算。 wrap_numpy
负责在 torch.compile
级别将任何 torch.Tensor
输入标记为具有 np.ndarray
语义的输入。在编译器内部标记张量是一个非常便宜的操作,因此在运行时不会发生任何数据复制或数据移动。
使用此装饰器,我们还可以通过 NumPy 代码进行微分!
@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
X = torch.randn(1024, 64, device="cuda", requires_grad=True)
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
Z.backward()
# X.grad now holds the gradient of the computation
print(X.grad)
我们一直在使用 fullgraph=True
,因为图中断在这种情况下存在问题。当发生图中断时,我们需要将 NumPy 数组具体化。由于 NumPy 数组没有 device
或 requires_grad
的概念,因此此信息在图中断期间会丢失。
我们无法通过图中断传播梯度,因为图中断代码可能会执行不知道如何微分的任意代码。另一方面,在 CUDA 执行的情况下,我们可以像在第一个示例中那样解决这个问题,方法是使用 torch.device("cuda")
上下文管理器
@torch.compile
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
prod = X[:, :, None] * Y[:, None, :]
print("oops, a graph break!")
return np.sum(prod, axis=(-2, -1))
X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
with torch.device("cuda"):
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
在图中断期间,中间张量仍然需要移动到 CPU,但是当跟踪在图中断后恢复时,图的其余部分仍然在 CUDA 上跟踪。鉴于这种 CUDA <> CPU 和 CPU <> CUDA 移动,图中断在 NumPy 上下文中相当昂贵,应该避免,但至少它们允许通过复杂的代码片段进行跟踪。
如何在 torch.compile
下调试 NumPy 代码?¶
调试 JIT 编译的代码具有挑战性,因为现代编译器的复杂性和它们引发的令人望而生畏的错误。 关于如何在 torch.compile 中诊断运行时错误的教程 包含一些关于如何解决此任务的技巧和窍门。
如果以上内容不足以确定问题的根源,我们仍然可以使用一些其他特定于 NumPy 的工具。我们可以通过禁用 NumPy 函数的跟踪来判断错误是否完全在 PyTorch 代码中。
from torch._dynamo import config
config.trace_numpy = False
如果错误位于跟踪的 NumPy 代码中,我们可以通过导入 import torch._numpy as np
,以 PyTorch 作为后端,急切地执行 NumPy 代码(不使用 torch.compile
)。这仅应用于 **调试目的**,绝不是 PyTorch API 的替代品,因为它 **性能要低得多**,并且作为私有 API,**可能会在未经通知的情况下更改**。无论如何,torch._numpy
是用 PyTorch 实现的 NumPy 的 Python 版本,它被 torch.compile
内部用于将 NumPy 代码转换为 Pytorch 代码。它很容易阅读和修改,因此如果您发现任何错误,请随时提交修复它的 PR 或简单地打开一个问题。
如果程序在导入 torch._numpy as np
时可以正常工作,那么很可能是 TorchDynamo 中存在错误。如果是这种情况,请随时打开一个包含 最小可重现示例 的问题。
我 torch.compile
了一些 NumPy 代码,但我没有看到任何加速。¶
最好的起点是 包含如何调试此类 torch.compile 问题的通用建议的教程。
由于使用不支持的功能,可能会出现一些图中断。请参阅 torch.compile 支持哪些 NumPy 功能?。更一般地说,需要注意的是,一些广泛使用的 NumPy 功能与编译器不兼容。例如,就地修改会使编译器难以推理,并且通常比其非就地对应物性能更差。因此,最好避免它们。使用 out=
参数也是如此。相反,请优先使用非就地操作,并让 torch.compile
优化内存使用。对于数据相关的操作,例如通过布尔掩码进行的掩码索引,或数据相关的控制流,例如 if
或 while
结构,也是如此。
使用哪个 API 进行细粒度跟踪?¶
在某些情况下,您可能需要将代码的某些小部分排除在 torch.compile 编译之外。本节提供了一些答案,您可以在 TorchDynamo API 用于细粒度跟踪 中找到更多信息。
如何在函数上进行图中断?¶
在函数上进行图中断不足以充分表达您希望 PyTorch 执行的操作。您需要更具体地说明您的用例。您可能需要考虑的一些最常见的用例
如果您想在此函数帧和递归调用的帧上禁用编译,请使用
torch._dynamo.disable
。如果您想让特定运算符(例如
fbgemm
)使用急切模式,请使用torch._dynamo.disallow_in_graph
。
一些不常见的用例包括
如果您想在函数帧上禁用 TorchDynamo,但在递归调用的帧上重新启用它,请使用
torch._dynamo.disable(recursive=False)
。如果您想阻止函数帧内联,请在要阻止内联的函数开头使用
torch._dynamo.graph_break
。
torch._dynamo.disable
和 torch._dynamo.disallow_in_graph
之间的区别是什么?¶
Disallow-in-graph 在操作符级别起作用,更准确地说,是在 TorchDynamo 提取的图中看到的操作符级别。
Disable 在函数帧级别起作用,并决定 TorchDynamo 是否应该查看函数帧。
torch._dynamo.disable
和 torch._dynamo_skip
之间的区别是什么?¶
注意
torch._dynamo_skip
已弃用。
您很可能需要 torch._dynamo.disable
。但在极少数情况下,您可能需要更精细的控制。假设您想禁用 a_fn
函数上的跟踪,但想继续在 aa_fn
和 ab_fn
中进行跟踪。下图演示了这种用例

在这种情况下,您可以使用 torch._dynamo.disable(recursive=False)
。在以前的版本中,此功能由 torch._dynamo.skip
提供。现在,torch._dynamo.disable
内部的 recursive
标志支持此功能。