快捷方式

torch.compile 疑难解答

您正在尝试在 PyTorch 模型上使用 torch.compile 以增强其性能,但它未按预期工作。也许性能没有提高,发生崩溃,或者编译时间太长。本文提供了技巧、解决方法和调试工具,以帮助您克服这些挑战。

目录

设定预期

torch.compile 被设计为通用的 PyTorch 编译器。与之前的编译器解决方案 TorchScript 不同,torch.compile 需要更少的代码更改,这意味着模型通常不需要从头开始重写。它还能更优雅地管理不支持的代码 - 不支持的代码会导致失去优化机会,而不是崩溃。

在理想情况下,人们可以简单地将 torch.compile 应用于任何 PyTorch 模型,并享受自动加速。然而,在现实中,代码复杂性可能导致以下三种情况之一

  1. torch.compile 无缝工作,提供加速。

  2. 需要进行一些代码修改。torch.compile 不会崩溃或花费太长时间,但您可能看不到明显的性能提升。

  3. 需要对您的代码进行大量更改。

我们预计大多数代码将属于情况 (1) 和 (2)。本文档提供了按参与程度排列的技巧,以帮助解决情况 (2) 中的代码问题。

编译时间

torch.compile 作为即时编译器运行,因此编译函数的最初一两次运行预计会明显较慢。在某些条件下(详见下文)可能发生的重新编译也会使运行速度变慢。各种 torch.compile 组件会缓存结果,以减少未来调用(即使在不同进程中)的编译时间。对于常见或基准测试模型,冷启动(未缓存)编译时间通常为几秒到几分钟。较大的模型可能需要 30 分钟到几个小时以上。

术语

以下术语与 torch.compile 问题的故障排除相关。

图中断

torch.compile 跟踪您的代码,并尝试将您的 PyTorch 代码捕获到 PyTorch 运算符(FX 图)的单个计算图中。但是,这并非总是可能的。当遇到无法跟踪的代码时,会发生“图中断”。图中断涉及编译到目前为止已确定的 FX 图,运行不支持的代码,然后在不支持的代码之后使用新的 FX 图恢复跟踪。由于计算图被分解,我们失去了优化机会,因此模型代码应尽可能避免图中断。图中断发生在如下情况:

  • 数据相关的 if 语句

  • 许多 Python 内置函数

  • C 函数

以下是由于 Python 内置库中的函数 copy.deepcopy 导致的图中断示例(确切输出可能有所不同)。

import torch

@torch.compile
def fn(x):
    x = x + 1
    with open("test.txt", "r") as f:
        return x + len(f.read())

fn(torch.ones(3, 3))
$TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in fn
    with open("test.txt", "r") as f:
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 635, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2414, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 962, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 997, in call_function
    return handler(tx, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 831, in <lambda>
    return lambda *args: unimplemented(error_msg)
                        ^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False

Guard

torch.compile 在我们跟踪代码时,会对运行时值做出一些假设。在跟踪期间,我们生成“Guard”,它们是对这些假设的运行时检查。Guard 在未来对编译函数的调用中运行,以确定我们是否可以重用先前编译的代码。运行时检查的示例包括常量值、类型和对象 ID。

以下是生成的 Guard 的示例。TENSOR_MATCH Guard 检查输入的类型、设备、dtype、形状等。

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
$ TORCH_LOGS="guards" python playground.py
GUARDS:

TREE_GUARD_MANAGER:
+- RootGuardManager
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:471 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1])  # return x + 1  # playground.py:6 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # return x + 1  # playground.py:6 in fn

重新编译

如果先前编译的代码的每个实例的 Guard 都失败,则 torch.compile 必须“重新编译”该函数,需要再次跟踪原始代码。

在下面的示例中,重新编译是必要的,因为检查张量参数形状的 Guard 失败了。

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/0: tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

动态形状

torch.compile 最初假设张量形状是静态/常量,并基于这些假设进行 Guard。通过使用“动态形状”,我们可以让 torch.compile 生成可以接受具有不同形状的张量输入的编译代码 - 我们避免了每次形状不同时都重新编译。默认情况下,自动动态形状已启用 torch.compile(dynamic=None) - 如果由于形状不匹配导致编译失败,则会尝试使用动态形状重新编译。动态形状也可以完全启用 dynamic=True 或禁用 dynamic=False

下面,我们启用动态形状,并注意到我们不再需要重新编译。

import torch

@torch.compile(dynamic=True)
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="dynamic,recompiles" python playground.py
create_symbol s0 = 3 for L['x'].size()[0] [2, int_oo] at playground.py:5 in fn (_dynamo/variables/builder.py:2718 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
produce_guards
produce_guards

有关动态形状的更多信息,请参阅 动态形状手册

日志工具

tlparse / TORCH_TRACE

tlparse / TORCH_TRACE 是一对工具,用于生成如下所示的编译报告:https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html

跟踪非常容易收集。要收集跟踪,请使用以下命令运行您的复现命令

TORCH_TRACE="/tmp/tracedir" python foo.py
pip install tlparse
tlparse /tmp/tracedir

即使您正在运行分布式作业,此方法也有效,为每个 rank 提供跟踪。它将打开您的浏览器,显示类似于上面生成的 HTML。如果您正在为无法独立复现的复杂问题制作错误报告,您仍然可以通过附加在 /tmp/tracedir 中生成的跟踪日志来极大地帮助 PyTorch 开发者。

警告

跟踪日志包含您的所有模型代码。如果您正在处理的模型是敏感的,请勿共享跟踪日志。跟踪日志不包含权重。

tlparse 的输出主要面向 PyTorch 开发者,并且日志格式易于上传并在 GitHub 上共享。但是,作为非 PyTorch 开发者,您仍然可以从中提取有用的信息。我们建议从报告中的内联帮助文本开始,其中解释了其内容。以下是您可以从 tlparse 中获得的一些见解

  • 通过查看堆栈 trie,了解哪些模型代码被编译?如果您不熟悉正在编译的代码库,这将特别有用!

  • 有多少图中断/不同的编译区域?(每个不同的编译都是其自己的颜色编码块,如 [0/0])。可能是图中断的帧是浅绿色 [2/4]。如果有很多帧,那很可疑,表明您遇到了一些灾难性的图中断,或者您的代码可能与 torch.compile 不太匹配。

  • 我重新编译特定帧多少次?重新编译多次的内容将如下所示:[10/0] [10/1] [10/2] - 如果某些内容被重新编译多次,那非常可疑,值得研究,即使它不是您问题的根本原因。

  • 是否有编译错误?出错的帧将如下所示:[0/1]

  • 我为给定帧生成了哪些中间编译器产品?例如,您可以查看高级生成的 FX 图或生成的 Triton 代码。

  • 特定帧是否有相关信息?您可以在 compilation_metrics 中找到这些信息。

TORCH_LOGS

您可以使用 TORCH_LOGS 环境变量来选择性地启用 torch.compile 堆栈的某些部分进行日志记录。TORCH_LOGS 实际上是 tlparse 的日志来源。TORCH_LOGS 环境变量的格式如下所示

TORCH_LOGS="<option1>,<option2>,..." python foo.py

有用的高级选项包括

  • graph_breaks:记录用户代码中图中断的位置以及图中断的原因

  • guards:记录生成的 Guard

  • recompiles:记录重新编译的函数以及导致重新编译的失败 Guard

  • dynamic:记录与动态形状相关的日志

此外,您可以使用 torch._logging.set_logs 以编程方式设置日志记录选项

import logging
torch._logging.set_logs(graph_breaks=True)
...

更多 TORCH_LOGS 选项在 下面详细介绍。有关选项的完整列表,请参阅 torch._loggingtorch._logging.set_logs

tlparse vs. TORCH_LOGS

一般来说,我们建议在遇到问题时首先使用 tlparsetlparse 非常适合调试大型模型并获得模型编译方式的高级概述。另一方面,当我们已经知道哪个 torch.compile 组件导致问题时,TORCH_LOGS 更适合用于小型示例和细粒度的调试细节。

简单解决方法

在这里,我们描述了一些 torch.compile 问题的解决方法,这些问题涉及小的代码修改或更改一些 torch.compile 设置。

在何处应用 torch.compile?

我们建议将 torch.compile 应用于不会引起过多问题的最高级别函数。通常,它是您的训练或评估步骤(带有优化器但没有循环)、您的顶层 nn.Module 或某些子 nn.Moduletorch.compile 特别不擅长处理分布式包装器模块(如 DDP 或 FSDP),因此请考虑将 torch.compile 应用于传递给包装器的内部模块。

# inference
model = ...
opt_model = torch.compile(model)

for _ in range(N_ITERS):
    inp = ...
    out = opt_model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
# DistributedDataParallel
model = ...
opt_model = torch.compile(model)
model_ddp = DistributedDataParallel(opt_model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)

禁用和抑制错误

对于某些模型架构,模型的部分区域特别难以编译 - 要么存在许多图中断,要么发生崩溃。您可能希望显式禁用这些有问题的模型部分,以便可以将 torch.compile 应用于可以工作的部分。您可以使用 @torch.compiler.disable 装饰器来执行此操作。当 torch.compile 尝试调用禁用的函数时,它会中断图并跳过跟踪禁用的函数,并在调用后恢复跟踪。默认情况下,从禁用的函数进行的所有递归调用也会被禁用。使用 recursive=False 选项以允许对递归调用进行编译。

def bad1_inner(...):
    # skipped

@torch.compiler.disable
def bad1_outer(...):
    # skipped
    bad1_inner(...)

def bad2_inner(...)
    # traced

@torch.compiler.disable(recursive=False)
def bad2_outer(...):
    # skipped
    bad2_inner(...)

@torch.compile
def fn(...):
    # graph break
    bad1_outer(...)
    ...
    # graph break
    bad2_outer(...)

例如,我们在推荐模型中对稀疏架构使用 torch.compiler.disable 来禁用 torch.compile,因为稀疏架构难以编译。预处理和日志记录函数是通常会导致大量图中断并且无法从编译中获得价值的其他函数示例。

如果您遇到编译器崩溃并且想要继续运行,您可以设置 torch._dynamo.config.suppress_errors = True。当编译器崩溃时,我们将只跳过跟踪该函数,稍后重试。这不是最佳实践 - 最好最终手动添加必要的禁用注释。

解决图中断

为了最大化优化机会,减少图中断的数量非常重要。回想一下,您可以使用 tlparseTORCH_LOGS="graph_breaks" 查看正在发生的图中断。一般来说,图中断是由以下原因之一引起的

  1. 您正在尝试执行一些从根本上无法跟踪的操作,例如数据相关的控制流。

  2. 您正在尝试执行一些尚不支持的操作。例如,我们目前对跟踪使用内置 Python inspect 模块的代码的支持有限。

  3. 您的代码中存在错误。例如,您可能尝试使用不正确的参数数量调用函数。

图中断日志将告诉您用户代码位置和图中断的原因。不幸的是,如果不深入了解 Dynamo,许多图中断是无法采取措施的。甚至很难确定这三个原因中的哪一个是您的图中断的真正原因。我们正在努力使图中断消息更具可操作性。

此外,不同图中断之间,失去优化机会的影响也不同。例如,在模型 forward 中间发生的图中断可能比在 forward 开始的预处理部分发生的图中断具有更负面的影响。因此,防止每一个中断并非至关重要,而是防止那些导致显着性能下降的中断。

如果图中断消息没有建议任何操作,您怀疑您的图中断的原因是 (2),并且您认为图中断正在导致性能下降,请将图中断报告为问题。如果一个函数有很多图中断,请考虑禁用该函数的编译,因为图中断的开销成本可能会变得过高。

以下是一些常见的图中断和一些解决方法。

数据相关的操作

torch.compile 在数据相关的操作(如数据相关的控制流(if 语句、带有张量的循环)和直接张量数据访问(.item.data_ptr))上发生图中断。

import torch

@torch.compile
def fn(x):
    y = x.sum()
    if y > 0:
        return x + y.item()
    return x - y.item()

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:6
Reason: Data-dependent jump
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 6, in fn
    if y > 0:

Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: Tensor.item
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in torch_dynamo_resume_in_fn_at_6
    return x + y.item()
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 616, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2288, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2282, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 838, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py", line 1038, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 773, in method_item
    unimplemented("Tensor.item")
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 304, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

这些图中断的通用解决方法是避免执行数据相关的操作。一些具体的解决方法是

  • 如果您的控制流实际上不依赖于数据值,请考虑修改您的代码以对常量执行控制流。

# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
    if x.sum() > 0:
        return y + x
    else:
        return y - x

# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
    if cond:
        return y + x
    else:
        return y - x
# old
@torch.compile
def fn(x):
    if x.sum() > 0:
        return x + 1
    return x - 1

# new
@torch.compile
def fn(x):
    return torch.cond(
        x.sum() > 0,
        lambda x: x + 1,
        lambda x: x - 1,
        (x,),
    )
  • 如果您有 .item() 调用,请尝试 torch._dynamo.config.capture_scalar_outputs = TrueTORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1

  • 将函数中存在问题的部分包装在自定义操作中

自定义操作

如果您有 torch.compile 难以跟踪的代码,无论是由于缺少支持还是根本不兼容,您可以考虑将有问题的代码包装在自定义操作中。

自定义操作需要做一些额外的工作才能使其与 torch.compile 兼容。有关更多详细信息,请参阅 https://pytorch.ac.cn/tutorials/advanced/custom_ops_landing_page.html

打印

打印/日志记录/发出警告将导致图中断。如果您有一个函数进行许多日志记录调用,例如,一个记录有关训练迭代数据的函数,请考虑对其应用 torch.compiler.disable

或者,您可以尝试使用 torch._dynamo.config.reorderable_logging_functions。此配置用于重新排序日志记录函数,以便在跟踪函数的末尾调用它们,从而避免图中断。但是,如果发生突变,则记录的内容可能会有所不同。

import torch

torch._dynamo.config.reorderable_logging_functions.add(print)

@torch.compile
def fn(x):
    x += 1
    print("log!")
    return torch.sin(x)

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
log!

不正确的代码

您的代码可能存在错误,或者遇到来自 torch.compile 外部的错误。在下面的代码中,我们在 torch.sin 调用中输入了错误,提供了额外的参数。

import torch

@torch.compile
def fn(x):
    y = torch.sin(x, x)
    return y

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:5
Reason: Unsupported: TypeError <built-in method sin of type object at 0x7fd6fd764600>: sin() takes 1 positional argument but 2 were given
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 5, in fn
    y = torch.sin(x, x)
...

很难从日志中判断错误是由您的代码引起的还是由于 torch.compile 错误引起的。为了区分,我们建议尝试在不使用 torch.compile 的情况下运行您的代码,看看是否仍然会收到错误。

处理重新编译

您可以使用 tlparseTORCH_LOGS=recompiles 查看重新编译及其原因。

是否启用了动态形状?

由于形状不匹配导致的重新编译形式如下

tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

确保 torch.compiledynamic 选项未设置为 False。默认选项 dynamic=None 只会在第一次编译后尝试动态形状。您可以设置 dynamic=True 以尽可能预先编译为动态。

有关动态形状的更多信息,请参阅 动态形状手册

更改缓存大小限制

函数可以重新编译的次数是有限制的,由 torch._dynamo.config.cache_size_limittorch._dynamo.config.accumulated_cache_size_limit 确定。如果超出任一限制,我们将不再尝试重新编译该函数,而是会急切地运行该函数。torch.compile 还会发出警告,其中包含受影响的函数以及命中的限制。在下面的示例中,每个函数调用都会导致重新编译尝试。当我们达到缓存大小限制 (8) 时,我们停止尝试重新编译。

import torch

@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    fn(torch.ones(i))
$ python playground.py
torch._dynamo hit config.cache_size_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:5)
    last reason: 0/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 9

如果您知道重新编译的次数具有合理的常数上限,您可以提高缓存大小限制。如果重新编译的成本超过了编译的好处,那么您可以考虑降低缓存大小限制。

使用张量包装常量

默认情况下,int / float 变量被视为常量,并以此方式进行 Guard。在下面的示例中,我们为每个函数调用都进行了重新编译。

import torch

@torch.compile
def fn(x, c):
    return x + c

for i in range(1, 10):
    fn(torch.ones(i), 0.5 + i)
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/7: L['c'] == 8.5
    - 0/6: L['c'] == 7.5
    - 0/5: L['c'] == 6.5
    - 0/4: L['c'] == 5.5
    - 0/3: L['c'] == 4.5
    - 0/2: L['c'] == 3.5
    - 0/1: L['c'] == 2.5
    - 0/0: L['c'] == 1.5
torch._dynamo hit config.cache_size_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:3)
    last reason: 0/0: L['c'] == 1.5

特别是,对于 LR 调度器,使用常量初始化可能会导致重新编译

import torch

mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def fn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(1, 10):
    fn(torch.ones(3, 3))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function step in /data/users/williamwen/pytorch/torch/optim/adam.py:189
    triggered by the following guard failure(s):
    - 3/7: L['self'].param_groups[0]['lr'] == 0.004782969000000002
    - 3/6: L['self'].param_groups[0]['lr'] == 0.005314410000000002
    - 3/5: L['self'].param_groups[0]['lr'] == 0.005904900000000002
    - 3/4: L['self'].param_groups[0]['lr'] == 0.006561000000000002
    - 3/3: L['self'].param_groups[0]['lr'] == 0.007290000000000001
    - 3/2: L['self'].param_groups[0]['lr'] == 0.008100000000000001
    - 3/1: L['self'].param_groups[0]['lr'] == 0.009000000000000001
    - 3/0: L['self'].param_groups[0]['lr'] == 0.01
torch._dynamo hit config.cache_size_limit (8)
    function: 'step' (/data/users/williamwen/pytorch/torch/optim/adam.py:189)
    last reason: 3/0: L['self'].param_groups[0]['lr'] == 0.01

在这两个示例中,我们可以将浮点变量包装在张量中,以防止重新编译。

# first example
for i in range(1, 10):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))

报告问题

如果上面提供的解决方法不足以使 torch.compile 工作,那么您应该考虑向 PyTorch 报告问题。但是,您可以做一些事情来使我们的工作轻松得多。

消融

使用 torch.compilebackend= 选项检查 torch.compile 堆栈的哪个组件导致了问题。特别是,尝试

  • torch.compile(fn, backend="eager"),它仅运行 TorchDynamo,即 torch.compile 的图捕获组件。

  • torch.compile(fn, backend="aot_eager"),它运行 TorchDynamo 和 AOTAutograd,后者还在编译期间生成反向图。

  • torch.compile(fn, backend="aot_eager_decomp_partition"),它运行 TorchDynamo 和 AOTAutograd,并带有运算符分解/分区。

  • torch.compile(fn, backend="inductor"),它运行 TorchDynamo、AOTAutograd 和 TorchInductor,后者是后端 ML 编译器,用于生成编译后的内核。

如果您仅在使用 Inductor 后端时失败,您可以额外测试各种 Inductor 模式

  • torch.compile(fn, backend="inductor", mode="default")

  • torch.compile(fn, backend="inductor", mode="reduce-overhead")

  • torch.compile(fn, backend="inductor", mode="max-autotune")

您还可以检查动态形状是否导致任何后端出现问题

  • torch.compile(fn, dynamic=True) (始终使用动态形状)

  • torch.compile(fn, dynamic=False) (从不使用动态形状)

  • torch.compile(fn, dynamic=None) (自动动态形状)

二分法

您是否尝试过最新的 nightly 版本?过去某些功能正常,但现在无法正常工作了吗?您能否通过二分法确定问题首次出现的 nightly 版本?二分法对于性能、准确性或编译时间衰退尤其有用,因为这些问题的原因通常不明显。

创建复现器

创建复现器需要大量工作,如果您没有时间这样做,完全没有问题。但是,如果您是一位积极的用户,不熟悉 torch.compile 的内部机制,那么创建一个独立的复现器可以极大地帮助我们修复错误。如果没有复现器,您的错误报告必须包含足够的信息,以便我们能够识别问题的根本原因并从头开始编写复现器。

以下是常用复现器的列表,按首选程度从高到低排序

  1. 自包含、小型复现器: 一个不依赖外部依赖项的脚本,代码行数少于 100 行,运行时可以重现问题。

  2. 自包含、大型复现器: 即使它很大,自包含也是一个巨大的优势!

  3. 非自包含、但依赖项可管理的复现器: 例如,如果您可以通过在 pip install transformers 之后运行脚本来重现问题,那是可以管理的。我们很可能可以运行它并进行调查。

  4. 非自包含、需要大量设置的复现器: 这可能涉及下载数据集、多个环境设置步骤或需要 Docker 镜像的特定系统库版本。设置越复杂,我们越难重现环境。

    注意

    Docker 简化了设置,但使环境更改复杂化,因此它不是一个完美的解决方案,尽管在必要时我们会使用它。

在某种程度上正交地,可以在单个进程中运行的复现器比需要多进程训练的复现器更好(但再次强调,如果您只有多进程复现器,我们也会接受!)。

此外,以下是在您的问题中需要检查的非详尽列表,您可以尝试在复现器中复制这些方面

  • Autograd。您是否有 requires_grad=True 的张量输入?您是否在输出上调用了 backward()

  • 动态形状。您是否设置了 dynamic=True?或者您是否使用不同的形状多次运行测试代码?

  • 自定义运算符。实际工作流程中是否涉及自定义运算符?您可以使用 Python 自定义运算符 API 复制其某些重要特征吗?

  • 配置。您是否设置了所有相同的配置?这包括 torch._dynamo.configtorch._inductor.config 设置,以及 torch.compile 的参数,如 backend / mode

  • 上下文管理器。您是否复制了任何活动的上下文管理器?这可能是 torch.no_grad、自动混合精度、TorchFunctionMode / TorchDispatchMode、激活检查点、编译后的 autograd 等。

  • 张量子类。是否涉及张量子类?

Minifier

Minifier 是一个早期的 torch.compile 工具,给定一个在尝试运行或编译时崩溃的 FX 图,它可以找到一个也崩溃的子图,并输出执行该子图操作的代码。本质上,Minifier 为某一类 torch.compile 相关崩溃找到一个最小的复现。这假设我们能够成功地追踪代码。

不幸的是,现在大多数时候,Minifier 的工作效果不如预期,可能需要替代方法。这可能是因为可以自动以这种方式重现的错误通常更容易修复,并且已经得到解决,剩下的问题更复杂,不容易重现。但是,尝试使用 Minifier 很简单,因此即使可能不成功,也值得尝试。

有关操作 Minifier 的说明,请参见此处。如果编译器崩溃,您可以设置 TORCHDYNAMO_REPRO_AFTER="dynamo"TORCHDYNAMO_REPRO_AFTER="aot"aot 选项更可能成功,尽管它可能无法识别 AOTAutograd 问题。这将生成 repro.py 文件,这可能有助于诊断问题。对于与准确性相关的问题,请考虑设置 TORCHDYNAMO_REPRO_LEVEL=4。请注意,这可能并不总是成功识别有问题的子图。

更深层次的调试

本节提供用于独立调试 torch.compile 问题或更深入了解 torch.compile 堆栈的工具和技术。这些方法比上面介绍的方法更复杂,PyTorch 开发人员经常使用这些方法来调试真实的 torch.compile 问题。

以下是堆栈的概要概述

_images/td_stack.png

该堆栈包含三个主要组件:TorchDynamo、AOTAutograd 和 Inductor。我们的调试策略首先是识别错误发生的组件,然后单独调试该组件。要确定哪个组件负责该问题,请参阅上面报告问题下的消融部分。有关调试特定组件的指导,请查阅以下部分。

TorchDynamo

记录 Dynamo 正在追踪的内容

TORCH_LOGS=trace_bytecode 选项使您能够查看 Dynamo 正在追踪的精确字节码指令,以及 Python 解释器堆栈的符号表示。当遇到图中断或崩溃时,建议检查最后几个追踪的字节码指令。

您还可以使用 TORCH_LOGS=trace_source 来查看 Dynamo 正在追踪的源代码行。这与 trace_bytecode 结合使用很有用,可以查看每个追踪的字节码指令对应的源代码行。

最后,您可以使用 TORCH_LOGS=graph_code 来查看表示 Dynamo 追踪的 FX 图的 Python 代码。您可以查看此代码以仔细检查是否追踪了正确的 ops。

import torch

def g(x, y):
    return x + y

@torch.compile(backend="eager")
def f(x):
    x = torch.sin(x)
    x = g(x, x)
    return x

f(torch.ones(3, 3))
$ TORCH_LOGS="trace_bytecode,trace_source,graph_code" python playground.py
TRACE starts_line /data/users/williamwen/pytorch/playground.py:6 in f ()
    @torch.compile(backend="eager")
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:8 in f (f)
        x = torch.sin(x)
TRACE LOAD_GLOBAL torch []
TRACE LOAD_ATTR sin [NullVariable(), PythonModuleVariable(<module 'torch' from '/data/users/williamwen/pytorch/torch/__init__.py'>)]
TRACE LOAD_FAST x [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>)]
TRACE CALL 1 [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>), LazyVariableTracker()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:9 in f (f)
        x = g(x, x)
TRACE LOAD_GLOBAL g []
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable()]
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable(), TensorVariable()]
TRACE CALL 2 [NullVariable(), UserFunctionVariable(), TensorVariable(), TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:3 in g (g) (inline depth: 1)
    def g(x, y):
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:4 in g (g) (inline depth: 1)
        return x + y
TRACE LOAD_FAST x []
TRACE LOAD_FAST y [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE RETURN_VALUE None [TensorVariable()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:10 in f (f)
        return x
TRACE LOAD_FAST x []
TRACE RETURN_VALUE None [TensorVariable()]
TRACED GRAPH
===== __compiled_fn_1 =====
/data/users/williamwen/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_

        # File: /data/users/williamwen/pytorch/playground.py:8 in f, code: x = torch.sin(x)
        x: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

        # File: /data/users/williamwen/pytorch/playground.py:4 in g, code: return x + y
        x_1: "f32[3, 3][3, 1]cpu" = x + x;  x = None
        return (x_1,)

断点调试 Dynamo 追踪

有时在 Dynamo/用户代码中插入断点有助于查看 Dynamo 在追踪用户代码时的状态。不幸的是,以正常的 Python 方式插入断点会导致 TorchDynamo 中的图中断,因此我们将无法在预期断点的位置查看 Dynamo 的状态。

设置断点的第一种方法是在 Dynamo 源代码中插入断点。建议放置断点的三个位置是

  • torch/_dynamo/symbolic_convert.py 中,在以有问题的字节码指令命名的函数处设置断点,例如 def CALL_FUNCTIONdef STORE_ATTR。您可以根据输入有条件地设置断点,例如指令的 argval,或堆栈顶部的对象的名称,因为某些字节码操作码经常使用。

  • 在图中断或错误发生的位置设置断点。通常,图中断是从对 unimplemented(...) 的调用发出的。

  • torch/_dynamo/variables/builder.py, function:_wrap 中设置断点。您可能需要根据输入有条件地设置断点。此函数确定如何以符号方式表示给定值。如果您怀疑某个值的表示不正确,请考虑在此处设置断点。

插入断点的第二种方法是使用 torch._dynamo.comptime.comptime.breakpoint

from torch._dynamo.comptime import comptime

@torch.compile
def f(...):
    ...
    comptime.breakpoint()
    ...

comptime 断点很方便,因为它使您能够在被追踪的用户代码中的特定位置检查 Dynamo 状态。它不需要您在 Dynamo 源代码中插入断点或根据变量有条件地设置断点。

当触发 comptime 断点时,您可以执行以下操作

  • ctx.print_bt() 打印用户堆栈跟踪

  • ctx.print_locals() 打印所有当前局部变量

  • ctx.print_graph() 打印当前追踪的图

  • ctx.disas() 打印当前追踪函数的字节码

  • 使用标准 pdb 命令,例如 bt/u/d/n/s/r - 您可以向上移动 pdb 堆栈以检查更多 Dynamo 内部结构

import torch
from torch._dynamo.comptime import comptime

@torch.compile(backend="eager")
def f(x):
    y = x + 1
    comptime.breakpoint()
    y = y + 1
    return y

f(torch.ones(3, 3))
$ python playground.py
--Return--
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_bt()
File "/data/users/williamwen/pytorch/playground.py", line 7, in f
    comptime.breakpoint()

(Pdb) ctx.print_locals()
x = FakeTensor(..., size=(3, 3))
y = FakeTensor(..., size=(3, 3))
(Pdb) bt
...
/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py(826)call_function()
-> self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py(331)call_function()
-> func(ComptimeContext(tx))
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_graph()



def forward(self, L_x_: "f32[3, 3]"):
    l_x_ = L_x_

    # File: /data/users/williamwen/pytorch/playground.py:6 in f, code: y = x + 1
    y: "f32[3, 3]" = l_x_ + 1;  l_x_ = y = None

字节码生成错误

虽然不常见,但 Dynamo 可能会生成不正确的字节码。如果您确定以下情况,可能会发生这种情况

  • 消融显示错误发生在 TorchDynamo 级别

  • 错误不是从 TorchDynamo 堆栈帧发出的

  • 错误看起来更像是用户错误而不是 Dynamo 错误,或者是段错误

  • 没有 torch.compile 时不会发生错误

字节码生成错误通常很难修复,我们建议提交问题而不是尝试自己修复。如果您有兴趣查看 Dynamo 生成的字节码,可以使用 TORCH_LOGS=bytecode。您可以在此处查看 Dynamo 生成的字节码的高级概述。

AOTAutograd

AOTAutograd 错误通常难以调试 - 我们建议直接提交问题。AOTAutograd 日志输出主要有助于查看 Inductor 的输入是什么。

TORCH_LOGS 选项摘要

有用的 TORCH_LOGS 选项摘要如下

选项

描述

+all

输出所有 torch.compile 组件的调试日志

+dynamo

输出 TorchDynamo 的调试日志

+aot

输出 AOTAutograd 的调试日志

+inductor

输出 TorchInductor 的调试日志

dynamic

输出动态形状的日志

graph_code

输出 Dynamo 生成的 FX 图的 Python 代码

graph_sizes

输出 Dynamo 生成的 FX 图的张量大小

trace_bytecode

输出 Dynamo 正在追踪的字节码指令以及 Dynamo 正在跟踪的符号解释器堆栈

trace_source

输出 Dynamo 当前正在追踪的原始源代码中的代码行

bytecode

输出 Dynamo 生成的字节码

guards

输出生成的 guards

recompiles

输出重新编译原因(仅限第一个失败的 guard 检查)

recompiles_verbose

输出重新编译发生时所有失败的 guard 检查

aot_graphs

输出 AOTAutograd 生成的图

aot_joint_graphs

输出 AOTAutograd 生成的联合前向-后向图

output_code

输出 Inductor 生成的代码

kernel_code

输出 Inductor 基于内核生成的代码

schedule

输出 Inductor 调度日志

perf_hints

输出 Inductor 性能提示日志

fusion

输出 Inductor fusion 日志

有关选项的完整列表,请参见torch._loggingtorch._logging.set_logs

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源