快捷方式

torch.compile 故障排除

您正尝试在 PyTorch 模型上使用 torch.compile 来提升性能,但其表现不如预期。也许性能并未提升,或出现崩溃,或编译时间过长。本文提供了一些技巧、变通方法和调试工具,帮助您克服这些挑战。

目录

设定预期

torch.compile 被设计为一个通用的 PyTorch 编译器。与先前的编译器解决方案 TorchScript 不同,torch.compile 需要更少的代码修改,这意味着通常无需从头重写模型。它还能更优雅地处理不受支持的代码——不受支持的代码只会损失优化机会,而不会导致崩溃。

在理想情况下,只需将 torch.compile 应用于任何 PyTorch 模型即可享受自动加速。然而,实际上,代码复杂性可能导致以下三种情况之一:

  1. torch.compile 无缝工作,提供加速。

  2. 需要进行一些代码修改。torch.compile 不会崩溃或花费太长时间,但您可能未看到显著的性能提升。

  3. 需要对代码进行大量修改。

我们预计大多数代码将属于情景 (1) 和 (2)。本文档提供了一些技巧,按涉入程度排列,以帮助解决情景 (2) 中的代码问题。

编译时间

torch.compile 作为即时编译器运行,因此编译函数的最初一两次运行预计会显著变慢。重新编译(在某些条件下发生,详情见下文)也会使运行变慢。各种 torch.compile 组件会缓存结果,以减少未来调用(甚至在不同进程中)的编译时间。常见或基准模型的冷启动(未缓存)编译时间通常在几秒到几分钟不等。大型模型可能需要 30 分钟到几个小时。

术语

以下术语与排查 torch.compile 问题相关。

图中断

torch.compile 追踪您的代码,并尝试将 PyTorch 代码捕获到 PyTorch 算子的单个计算图(FX 图)中。然而,这并非总是可能。当遇到无法追踪的代码时,会发生“图中断”。图中断涉及编译已确定部分的 FX 图,运行不受支持的代码,然后在不受支持的代码之后使用新的 FX 图恢复追踪。由于计算图被中断,我们失去了优化机会,因此模型代码应尽可能避免图中断。图中断发生在以下情况:

  • 数据依赖的 if 语句

  • 许多 Python 内置函数

  • C 函数

下面是一个因 Python 内置库函数 copy.deepcopy 导致的图中断示例(确切输出可能有所不同)。

import torch

@torch.compile
def fn(x):
    x = x + 1
    with open("test.txt", "r") as f:
        return x + len(f.read())

fn(torch.ones(3, 3))
$TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in fn
    with open("test.txt", "r") as f:
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 635, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2414, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 962, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 997, in call_function
    return handler(tx, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 831, in <lambda>
    return lambda *args: unimplemented(error_msg)
                        ^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False

守卫

torch.compile 在追踪代码时会对运行时值做一些假设。在追踪过程中,我们生成“守卫”,即这些假设的运行时检查。守卫会在未来调用编译函数时运行,以确定我们是否可以重用先前编译的代码。运行时检查的例子包括常量值、类型和对象 ID。

下面是生成的守卫示例。 TENSOR_MATCH 守卫检查输入的类型、设备、dtype、形状等。

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
$ TORCH_LOGS="guards" python playground.py
GUARDS:

TREE_GUARD_MANAGER:
+- RootGuardManager
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:471 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1])  # return x + 1  # playground.py:6 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # return x + 1  # playground.py:6 in fn

重新编译

如果先前编译代码的每次实例的守卫都失败,那么 torch.compile 必须“重新编译”该函数,这需要再次追踪原始代码。

在下面的示例中,重新编译是必要的,因为检查张量参数形状的守卫失败了。

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/0: tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

动态形状

torch.compile 最初假设张量形状是静态/常量,并基于这些假设进行守卫。通过使用“动态形状”,我们可以让 torch.compile 生成可以接受不同形状张量输入的编译代码——我们避免在形状不同时每次都重新编译。默认情况下,自动动态形状已启用 torch.compile(dynamic=None)——如果因形状不匹配而编译失败,则会尝试使用动态形状进行重新编译。动态形状也可以完全启用 dynamic=True 或禁用 dynamic=False

下面,我们启用动态形状,并注意到我们不再需要重新编译。

import torch

@torch.compile(dynamic=True)
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="dynamic,recompiles" python playground.py
create_symbol s0 = 3 for L['x'].size()[0] [2, int_oo] at playground.py:5 in fn (_dynamo/variables/builder.py:2718 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
produce_guards
produce_guards

有关动态形状的更多信息,请参阅 动态形状手册

日志工具

tlparse / TORCH_TRACE

tlparse / TORCH_TRACE 是一对工具,用于生成编译报告,如下所示:https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html

追踪日志收集非常简单。要收集追踪日志,使用以下命令运行您的复现命令:

TORCH_TRACE="/tmp/tracedir" python foo.py
pip install tlparse
tlparse /tmp/tracedir

即使您正在运行分布式作业,此方法也有效,为每个进程提供一个追踪日志。它将在您的浏览器中打开与上面生成的 HTML 类似的页面。如果您正在报告一个复杂问题且没有独立的复现示例,您仍然可以通过附带在 /tmp/tracedir 中生成的追踪日志来极大地帮助 PyTorch 开发者。

警告

追踪日志包含您的所有模型代码。如果您的模型是敏感的,请勿共享追踪日志。追踪日志不包含权重。

tlparse 的输出主要面向 PyTorch 开发者,且日志格式易于上传和在 GitHub 上共享。然而,作为非 PyTorch 开发者,您仍然可以从中提取有用信息。我们建议从报告中的内联帮助文本开始,其中解释了报告内容。以下是您可以从 tlparse 中获得的一些见解:

  • 通过查看堆栈树,哪些模型代码被编译了?如果您不熟悉正在编译的代码库,这尤其有用!

  • 有多少图中断/不同的编译区域?(每个不同的编译都是一个独立的颜色编码块,如 [0/0])。可能是图中断的帧显示为浅绿色 [2/4]。如果帧很多,这很可疑,表明您遇到了一些灾难性的图中断,或者您的代码可能不太适合 torch.compile

  • 我重新编译了特定帧多少次?重新编译多次的帧会看起来像:[10/0] [10/1] [10/2]——如果某项内容被重新编译多次,则非常可疑,值得深入研究,即使它不是您问题的根本原因。

  • 是否发生了编译错误?出错的帧会看起来像 [0/1]

  • 我为给定帧生成了哪些中间编译器产品?例如,您可以查看高级生成的 FX 图或生成的 Triton 代码。

  • 特定帧是否有相关信息?您可以在 compilation_metrics 中找到这些信息。

TORCH_LOGS

您可以使用 TORCH_LOGS 环境变量选择性地启用 torch.compile 堆栈的某些部分进行日志记录。TORCH_LOGS 实际上是 tlparse 的日志来源。TORCH_LOGS 环境变量的格式如下:

TORCH_LOGS="<option1>,<option2>,..." python foo.py

有用的高级选项包括:

  • graph_breaks:记录用户代码中图中断的位置和原因

  • guards:记录生成的守卫

  • recompiles:记录哪些函数重新编译以及导致重新编译失败的守卫

  • dynamic:记录与动态形状相关的信息

此外,您可以使用 torch._logging.set_logs 以编程方式设置日志选项:

import logging
torch._logging.set_logs(graph_breaks=True)
...

更多 TORCH_LOGS 选项在下方详细说明。有关完整选项列表,请参阅torch._loggingtorch._logging.set_logs

tlparse vs. TORCH_LOGS

一般来说,我们建议在遇到问题时首先使用 tlparsetlparse 非常适合调试大型模型并获得模型编译的高级概览。另一方面,当对导致问题的 torch.compile 组件已有概念时,TORCH_LOGS 更适合小型示例和细粒度调试。

简单变通方法

在这里,我们描述了一些解决 torch.compile 问题的方法,这些方法涉及少量代码修改或更改一些 torch.compile 设置。

在哪里应用 torch.compile?

我们建议将 torch.compile 应用于不会导致过度问题的最高层函数。通常,这是您带有优化器但不包含循环的训练或评估步骤、您的顶层 nn.Module,或某些子 nn.Module``s。 ``torch.compile 对 DDP 或 FSDP 等分布式包装器模块处理得不是很好,因此请考虑将 torch.compile 应用于传递给包装器的内部模块。

# inference
model = ...
opt_model = torch.compile(model)

for _ in range(N_ITERS):
    inp = ...
    out = opt_model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
# DistributedDataParallel
model = ...
opt_model = torch.compile(model)
model_ddp = DistributedDataParallel(opt_model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)

禁用和抑制错误

对于某些模型架构,模型中有部分特别难以编译——要么有很多图中断,要么会崩溃。您可能希望明确禁用模型中这些有问题的部分,以便将 torch.compile 应用于可以正常工作的部分。您可以使用 @torch.compiler.disable 装饰器来实现这一点。当 torch.compile 尝试调用禁用的函数时,它会中断图并跳过追踪禁用函数,在调用后恢复追踪。默认情况下,从禁用函数进行的所有递归调用也会被禁用。使用 recursive=False 选项允许对递归调用进行编译。

def bad1_inner(...):
    # skipped

@torch.compiler.disable
def bad1_outer(...):
    # skipped
    bad1_inner(...)

def bad2_inner(...)
    # traced

@torch.compiler.disable(recursive=False)
def bad2_outer(...):
    # skipped
    bad2_inner(...)

@torch.compile
def fn(...):
    # graph break
    bad1_outer(...)
    ...
    # graph break
    bad2_outer(...)

例如,我们使用 torch.compiler.disable 禁用推荐模型中稀疏架构上的 torch.compile,因为稀疏架构难以编译。预处理和日志记录函数是通常会导致许多图中断且从编译中得不到价值的其他函数示例。

如果您遇到编译器崩溃并希望继续,可以将 torch._dynamo.config.suppress_errors = True。当编译器崩溃时,我们将跳过追踪该函数,稍后再试。这不是最佳实践——最好最终手动添加必要的禁用注解。

解决图中断

为了最大化优化机会,减少图中断的数量非常重要。请记住,您可以使用 tlparseTORCH_LOGS="graph_breaks" 查看发生了哪些图中断。一般来说,图中断是由以下原因之一引起的:

  1. 您尝试做一些根本无法追踪的事情,例如数据依赖的控制流。

  2. 您尝试做一些尚未支持的事情。例如,我们目前对追踪使用 Python 内置 inspect 模块的代码支持有限。

  3. 您的代码本身存在错误,或者遇到了来自 torch.compile 外部的错误。例如,您可能尝试调用函数时提供了错误的参数数量。

图中断日志会告诉您用户代码位置以及图中断的原因。不幸的是,许多图中断如果不深入了解 Dynamo 则无法处理。甚至很难确定是哪三种原因之一导致了您的图中断。我们正在努力使图中断消息更具可操作性。

此外,丢失优化机会的影响因图中断而异。例如,发生在模型 forward 中间的图中断可能比 forward 开头的预处理部分中的图中断产生更负面的影响。因此,阻止 每一个 中断并不是关键,关键是阻止那些导致显著性能损失的中断。

如果图中断消息没有提出任何建议,您怀疑您的图中断原因是 (2),并且您认为该图中断正在导致性能损失,那么请将其作为问题报告。如果一个函数有很多图中断,请考虑在该函数上禁用编译,因为图中断的开销可能会变得过高。

下面是一些常见的图中断和一些变通方法。

数据依赖操作

torch.compile 在数据依赖操作(例如数据依赖的控制流(if 语句、带有张量的循环)和直接张量数据访问(.item.data_ptr))上会发生图中断。

import torch

@torch.compile
def fn(x):
    y = x.sum()
    if y > 0:
        return x + y.item()
    return x - y.item()

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:6
Reason: Data-dependent jump
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 6, in fn
    if y > 0:

Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: Tensor.item
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in torch_dynamo_resume_in_fn_at_6
    return x + y.item()
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 616, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2288, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2282, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 838, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py", line 1038, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 773, in method_item
    unimplemented("Tensor.item")
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 304, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

这些图中断的一般变通方法是避免进行数据依赖操作。一些具体的变通方法是:

  • 如果您的控制流实际上不依赖于数据值,请考虑修改代码,使其在常量上执行控制流。

# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
    if x.sum() > 0:
        return y + x
    else:
        return y - x

# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
    if cond:
        return y + x
    else:
        return y - x
# old
@torch.compile
def fn(x):
    if x.sum() > 0:
        return x + 1
    return x - 1

# new
@torch.compile
def fn(x):
    return torch.cond(
        x.sum() > 0,
        lambda x: x + 1,
        lambda x: x - 1,
        (x,),
    )
  • 如果您调用了 .item(),请尝试 torch._dynamo.config.capture_scalar_outputs = TrueTORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1

  • 将函数的有问题部分包装在自定义算子中

自定义算子

如果您有 torch.compile 难以追踪的代码,无论是由于缺少支持还是根本不兼容,您可以考虑将有问题代码包装在自定义算子中。

自定义算子需要额外做一些工作才能与 torch.compile 兼容。请参阅 https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html 获取更多详细信息。

打印

打印/日志记录/发出警告会导致图中断。如果您的函数进行了许多日志调用(例如,记录训练迭代数据的函数),请考虑对其应用 torch.compiler.disable

或者,您可以尝试使用 torch._dynamo.config.reorderable_logging_functions。此配置用于重新排序日志记录函数,使其在追踪函数的末尾调用,从而避免图中断。但是,如果发生了变动,日志记录的内容可能会有所不同。

import torch

torch._dynamo.config.reorderable_logging_functions.add(print)

@torch.compile
def fn(x):
    x += 1
    print("log!")
    return torch.sin(x)

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
log!

错误代码

您的代码可能有误,或者遇到了来自 torch.compile 外部的错误。在下面的代码中,我们在调用 torch.sin 时因提供了额外的参数而导致了输入错误。

import torch

@torch.compile
def fn(x):
    y = torch.sin(x, x)
    return y

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:5
Reason: Unsupported: TypeError <built-in method sin of type object at 0x7fd6fd764600>: sin() takes 1 positional argument but 2 were given
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 5, in fn
    y = torch.sin(x, x)
...

从日志中很难判断错误是由您的代码引起的还是 torch.compile 的 bug 引起的。为了区分,我们建议尝试不使用 torch.compile 运行您的代码,看看是否仍然会收到错误。

处理重新编译

您可以使用 tlparseTORCH_LOGS=recompiles 查看重新编译及其原因。

动态形状是否启用?

由于形状不匹配导致的重新编译形式如下:

tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

确保 torch.compiledynamic 选项未设置为 False。默认选项 dynamic=None 只会在首次编译后尝试动态形状。您可以将 dynamic=True 设置为预先尽可能地编译为动态。

有关动态形状的更多信息,请参阅 动态形状手册

更改缓存大小限制

函数可以重新编译的次数受到 torch._dynamo.config.recompile_limittorch._dynamo.config.accumulated_recompile_limit 的限制。如果超出任一限制,我们将不再尝试重新编译该函数,而是急切地运行该函数。torch.compile 还会发出警告,包含受影响的函数和触及的限制。在下面的示例中,每次函数调用都会导致重新编译尝试。当我们触及缓存大小限制 (8) 时,我们停止尝试重新编译。

import torch

@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    fn(torch.ones(i))
$ python playground.py
torch._dynamo hit config.recompile_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:5)
    last reason: 0/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 9

如果您知道重新编译次数有一个合理的常量上限,可以提高缓存大小限制。如果重新编译的成本超过编译的好处,则可以考虑降低缓存大小限制。

将常量包装为张量

默认情况下,int / float 变量被视为常量并因此受到守卫。在下面的示例中,每次函数调用都会导致一次重新编译。

import torch

@torch.compile
def fn(x, c):
    return x + c

for i in range(1, 10):
    fn(torch.ones(i), 0.5 + i)
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/7: L['c'] == 8.5
    - 0/6: L['c'] == 7.5
    - 0/5: L['c'] == 6.5
    - 0/4: L['c'] == 5.5
    - 0/3: L['c'] == 4.5
    - 0/2: L['c'] == 3.5
    - 0/1: L['c'] == 2.5
    - 0/0: L['c'] == 1.5
torch._dynamo hit config.recompile_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:3)
    last reason: 0/0: L['c'] == 1.5

特别是对于 LR 调度器,使用常量初始化可能导致重新编译:

import torch

mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def fn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(1, 10):
    fn(torch.ones(3, 3))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function step in /data/users/williamwen/pytorch/torch/optim/adam.py:189
    triggered by the following guard failure(s):
    - 3/7: L['self'].param_groups[0]['lr'] == 0.004782969000000002
    - 3/6: L['self'].param_groups[0]['lr'] == 0.005314410000000002
    - 3/5: L['self'].param_groups[0]['lr'] == 0.005904900000000002
    - 3/4: L['self'].param_groups[0]['lr'] == 0.006561000000000002
    - 3/3: L['self'].param_groups[0]['lr'] == 0.007290000000000001
    - 3/2: L['self'].param_groups[0]['lr'] == 0.008100000000000001
    - 3/1: L['self'].param_groups[0]['lr'] == 0.009000000000000001
    - 3/0: L['self'].param_groups[0]['lr'] == 0.01
torch._dynamo hit config.recompile_limit (8)
    function: 'step' (/data/users/williamwen/pytorch/torch/optim/adam.py:189)
    last reason: 3/0: L['self'].param_groups[0]['lr'] == 0.01

在这两个示例中,我们可以将浮点变量包装在张量中,以防止重新编译。

# first example
for i in range(1, 10):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))

报告问题

如果上述提供的变通方法不足以使 torch.compile 工作,那么您应该考虑向 PyTorch 报告问题。但您可以做一些事情,使我们的工作变得更轻松。

消融

使用 torch.compilebackend= 选项检查是 torch.compile 堆栈的哪个组件导致了问题。特别是,尝试:

  • torch.compile(fn, backend="eager"),仅运行 TorchDynamo,它是 torch.compile 的图捕获组件。

  • torch.compile(fn, backend="aot_eager"),运行 TorchDynamo 和 AOTAutograd,它在编译期间额外生成反向图。

  • torch.compile(fn, backend="aot_eager_decomp_partition"),运行 TorchDynamo 和带有算子分解/分区的 AOTAutograd。

  • torch.compile(fn, backend="inductor"),运行 TorchDynamo、AOTAutograd 和 TorchInductor,它是生成编译核的后端 ML 编译器。

如果您仅在使用 Inductor 后端时失败,可以额外测试各种 Inductor 模式:

  • torch.compile(fn, backend="inductor", mode="default")

  • torch.compile(fn, backend="inductor", mode="reduce-overhead")

  • torch.compile(fn, backend="inductor", mode="max-autotune")

您还可以检查动态形状是否正在导致任何后端出现问题:

  • torch.compile(fn, dynamic=True) (始终使用动态形状)

  • torch.compile(fn, dynamic=False) (从不使用动态形状)

  • torch.compile(fn, dynamic=None) (自动动态形状)

二分法

您是否尝试过最新的每夜构建版?过去正常工作的功能现在是否无法使用了?您能否通过双向查找确定问题首次出现的每夜构建版?对于性能、准确性或编译时间回归等问题,双向查找特别有用,因为这些问题不立即明确其来源。

创建可复现示例

创建可复现示例需要大量工作,如果您没有时间做,完全没问题。但是,如果您是一位对 torch.compile 内部机制不熟悉的有意愿的用户,创建一个独立的可复现示例能极大地帮助我们修复 bug。没有可复现示例,您的 bug 报告必须包含足够的信息,以便我们找出问题的根本原因并从头编写一个可复现示例。

以下是按偏好程度从高到低排列的有用可复现示例列表

  1. 自包含的小型可复现示例: 没有外部依赖、代码少于 100 行的脚本,运行时能够复现问题。

  2. 自包含的大型可复现示例: 即使它很大,自包含也是一个巨大的优势!

  3. 非自包含但依赖易于管理的复现示例: 例如,如果您在 pip install transformers 后运行脚本可以复现问题,这是易于管理的。我们很可能能够运行它并进行调查。

  4. 非自包含且需要大量设置的复现示例: 这可能涉及下载数据集、多个环境设置步骤,或者需要特定系统库版本并需要 Docker 镜像。设置越复杂,我们越难重现环境。

    注意

    Docker 简化了设置,但使环境修改复杂化,因此它不是一个完美的解决方案,但如果必要,我们还是会使用它。

在某种程度上,与上述分类正交的是,可以在单个进程中运行的可复现示例好于需要多进程训练的可复现示例(但再次强调,如果您只有一个多进程可复现示例,我们也接受!)。

此外,以下是您的问题中需要检查的方面(可尝试在可复现示例中复现)的非详尽列表

  • 自动微分 (Autograd)。您的输入 tensor 是否设置了 requires_grad=True?您是否在输出上调用了 backward()

  • 动态形状 (Dynamic shapes)。是否设置了 dynamic=True?或者您是否使用不同形状多次运行了测试代码?

  • 自定义算子 (Custom operators)。实际工作流程中是否涉及自定义算子?能否使用 Python 自定义算子 API 复现其一些重要特性?

  • 配置 (Configuration)。是否设置了所有相同的配置?这包括 torch._dynamo.configtorch._inductor.config 设置,以及 torch.compile 的参数,如 backend / mode

  • 上下文管理器 (Context managers)。是否复现了任何活跃的上下文管理器?这可能是 torch.no_grad、自动混合精度、TorchFunctionMode / TorchDispatchMode、激活检查点、编译后的自动微分等。

  • Tensor 子类 (Tensor subclasses)。是否涉及 tensor 子类?

缩小器 (Minifier)

缩小器是一个早期的 torch.compile 工具,它给定一个在我们尝试运行或编译时崩溃的 FX graph,然后找到一个也会崩溃的子图,并输出执行该子图操作的代码。本质上,缩小器能找到针对特定类型 torch.compile 相关崩溃的最小可复现示例。这假设我们能够成功地跟踪代码。

不幸的是,如今大多数时候,缩小器无法按预期工作,可能需要替代方法。这可能是因为通过这种方式可以自动复现的 bug 通常更容易修复并且已经被解决,留下了更难复现的复杂问题。然而,尝试使用缩小器是很直接的,所以即使可能不会成功,也值得一试。

缩小器的使用说明可以在 这里 找到。如果编译器崩溃,您可以设置 TORCHDYNAMO_REPRO_AFTER="dynamo"TORCHDYNAMO_REPRO_AFTER="aot"aot 选项更可能成功,尽管它可能无法识别 AOTAutograd 的问题。这将生成 repro.py 文件,这可能有助于诊断问题。对于准确性相关的问题,考虑设置 TORCHDYNAMO_REPRO_LEVEL=4。请注意,这可能无法始终成功地识别出有问题的子图。

深入调试

本节提供用于独立调试 torch.compile 问题或更深入理解 torch.compile 技术栈的工具和技术。这些方法比上面介绍的更复杂,是 PyTorch 开发者日常用于调试实际 torch.compile 问题的方法。

以下是技术栈的高层概述

_images/td_stack.png

该技术栈包含三个主要组件:TorchDynamo、AOTAutograd 和 Inductor。我们的调试策略包括首先确定错误发生的组件,然后单独调试该组件。要确定导致问题的组件,请参阅上面 报告问题 (Reporting Issues) 部分下的 消融分析 (Ablation) 子节。关于调试特定组件的指导,请查阅以下各节。

TorchDynamo

记录 Dynamo 跟踪的内容

TORCH_LOGS=trace_bytecode 选项使您能够查看 Dynamo 正在跟踪的精确字节码指令以及 Python 解释器堆栈的符号表示。当遇到 graph break 或崩溃时,建议检查最后跟踪的几条字节码指令。

您还可以使用 TORCH_LOGS=trace_source 来查看 Dynamo 正在跟踪哪些源代码行。这与 trace_bytecode 结合使用非常有用,可以查看每条被跟踪的字节码指令对应的源代码行。

最后,您可以使用 TORCH_LOGS=graph_code 查看 Dynamo 跟踪的 FX graph 所对应的 Python 代码。您可以查看此代码,以仔细检查是否跟踪了正确的算子。

import torch

def g(x, y):
    return x + y

@torch.compile(backend="eager")
def f(x):
    x = torch.sin(x)
    x = g(x, x)
    return x

f(torch.ones(3, 3))
$ TORCH_LOGS="trace_bytecode,trace_source,graph_code" python playground.py
TRACE starts_line /data/users/williamwen/pytorch/playground.py:6 in f ()
    @torch.compile(backend="eager")
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:8 in f (f)
        x = torch.sin(x)
TRACE LOAD_GLOBAL torch []
TRACE LOAD_ATTR sin [NullVariable(), PythonModuleVariable(<module 'torch' from '/data/users/williamwen/pytorch/torch/__init__.py'>)]
TRACE LOAD_FAST x [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>)]
TRACE CALL 1 [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>), LazyVariableTracker()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:9 in f (f)
        x = g(x, x)
TRACE LOAD_GLOBAL g []
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable()]
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable(), TensorVariable()]
TRACE CALL 2 [NullVariable(), UserFunctionVariable(), TensorVariable(), TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:3 in g (g) (inline depth: 1)
    def g(x, y):
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:4 in g (g) (inline depth: 1)
        return x + y
TRACE LOAD_FAST x []
TRACE LOAD_FAST y [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE RETURN_VALUE None [TensorVariable()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:10 in f (f)
        return x
TRACE LOAD_FAST x []
TRACE RETURN_VALUE None [TensorVariable()]
TRACED GRAPH
===== __compiled_fn_1 =====
/data/users/williamwen/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_

        # File: /data/users/williamwen/pytorch/playground.py:8 in f, code: x = torch.sin(x)
        x: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

        # File: /data/users/williamwen/pytorch/playground.py:4 in g, code: return x + y
        x_1: "f32[3, 3][3, 1]cpu" = x + x;  x = None
        return (x_1,)

在 Dynamo 跟踪中设置断点

在 Dynamo/用户代码中插入断点有时有助于查看 Dynamo 在跟踪用户代码时的状态。不幸的是,以正常的 Python 方式插入断点会导致 TorchDynamo 中发生 graph break,因此我们将无法在我们打算设置断点的点查看 Dynamo 的状态。

设置断点的第一种方法是在 Dynamo 源代码中插入断点。建议放置断点的三个位置是

  • torch/_dynamo/symbolic_convert.py 中,在以有问题的字节码指令命名的函数处设置断点,例如 def CALL_FUNCTIONdef STORE_ATTR。您可以根据输入有条件地设置断点,例如指令的 argval,或栈顶对象的名称,因为有些字节码操作码使用频繁。

  • 在 graph break 或错误起源处设置断点。通常,graph break 是由调用 unimplemented(...) 发出的。

  • torch/_dynamo/variables/builder.py 中的 _wrap 函数处设置断点。您很可能需要根据输入有条件地设置断点。此函数确定如何符号化表示给定值。如果您怀疑某个值被错误地表示,请考虑在此处设置断点。

插入断点的第二种方法是使用 torch._dynamo.comptime.comptime.breakpoint

from torch._dynamo.comptime import comptime

@torch.compile
def f(...):
    ...
    comptime.breakpoint()
    ...

comptime 断点很方便,因为它使您能够在正在跟踪的用户代码中的特定位置检查 Dynamo 的状态。它不需要您在 Dynamo 源代码中插入断点,也不需要根据变量有条件地设置断点。

当 comptime 断点被触发时,您可以执行以下操作

  • ctx.print_bt() 打印用户堆栈跟踪

  • ctx.print_locals() 打印所有当前局部变量

  • ctx.print_graph() 打印当前跟踪的 graph

  • ctx.disas() 打印当前跟踪函数的字节码

  • 使用标准的 pdb 命令,例如 bt/u/d/n/s/r,您可以向上回溯 pdb 堆栈以检查更多 Dynamo 内部细节

import torch
from torch._dynamo.comptime import comptime

@torch.compile(backend="eager")
def f(x):
    y = x + 1
    comptime.breakpoint()
    y = y + 1
    return y

f(torch.ones(3, 3))
$ python playground.py
--Return--
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_bt()
File "/data/users/williamwen/pytorch/playground.py", line 7, in f
    comptime.breakpoint()

(Pdb) ctx.print_locals()
x = FakeTensor(..., size=(3, 3))
y = FakeTensor(..., size=(3, 3))
(Pdb) bt
...
/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py(826)call_function()
-> self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py(331)call_function()
-> func(ComptimeContext(tx))
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_graph()



def forward(self, L_x_: "f32[3, 3]"):
    l_x_ = L_x_

    # File: /data/users/williamwen/pytorch/playground.py:6 in f, code: y = x + 1
    y: "f32[3, 3]" = l_x_ + 1;  l_x_ = y = None

字节码生成错误

虽然不常见,但 Dynamo 可能会生成错误的字节码。如果您确定以下几点,则可能发生这种情况

  • 消融分析显示错误发生在 TorchDynamo 级别

  • 错误不是从 TorchDynamo 堆栈帧发出的

  • 错误看起来更像是用户错误而不是 Dynamo 错误,或者是一个段错误

  • 没有 torch.compile 时不会发生该错误

字节码生成 bug 通常很难修复,我们建议提交一个 issue,而不是尝试自己修复这些问题。如果您有兴趣查看 Dynamo 生成的字节码,您可以使用 TORCH_LOGS=bytecode。您可以在 这里 查看 Dynamo 生成的字节码的高层概述。

AOTAutograd

AOTAutograd 错误通常难以调试 - 我们建议直接提交一个 issue。AOTAutograd 的日志输出主要有助于查看 Inductor 的输入是什么。

TORCH_LOGS 选项总结

以下是一些有用的 TORCH_LOGS 选项总结

选项

描述

+all

输出所有 torch.compile 组件的调试日志

+dynamo

输出 TorchDynamo 的调试日志

+aot

输出 AOTAutograd 的调试日志

+inductor

输出 TorchInductor 的调试日志

dynamic

输出动态形状相关的日志

graph_code

输出 Dynamo 生成的 FX graph 的 Python 代码

graph_sizes

输出 Dynamo 生成的 FX graph 的 tensor 形状

trace_bytecode

输出 Dynamo 正在跟踪的字节码指令以及 Dynamo 正在跟踪的符号解释器堆栈

trace_source

输出 Dynamo 当前正在跟踪的原始源代码行

bytecode

输出 Dynamo 生成的字节码

guards

输出生成的 guards

recompiles

输出重新编译原因(仅第一个 guard 检查失败)

recompiles_verbose

输出重新编译时所有失败的 guard 检查

aot_graphs

输出 AOTAutograd 生成的 graph

aot_joint_graphs

输出 AOTAutograd 生成的联合前向-后向 graph

output_code

输出 Inductor 生成的代码

kernel_code

按每个 kernel 输出 Inductor 生成的代码

schedule

输出 Inductor 调度日志

perf_hints

输出 Inductor 性能提示日志

fusion

输出 Inductor 算子融合日志

完整选项列表请参阅 torch._loggingtorch._logging.set_logs

文档

获取 PyTorch 全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源