
torch.export Tutorial

Created On: Oct 02, 2023 | Last Updated: Jan 27, 2025 | Last Verified: Nov 05, 2024

Author: William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan

Warning

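torch.export and its related features are in prototype status and are subject to backwards-incompatible changes. This tutorial provides a snapshot of torch.export usage as of PyTorch 2.5.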

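torch.export() is the PyTorch 2.X way to export PyTorch models into a standardized model representation, intended to be run in different (i.e., Python-less) environments. The official documentation can be found here.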

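In this tutorial, you will learn how to use torch.export() to extract an ExportedProgram (i.e., a single-graph representation) from a PyTorch program. We also detail some considerations and modifications you may need to make so that your model is compatible with torch.export.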


Basic Usage

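torch.export extracts a single-graph representation from a PyTorch program by tracing the target function, given example inputs. torch.export.export() is the main entry point for torch.export.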

在本教程中,torch.exporttorch.export.export() 在实践中是同义的,尽管 torch.export 通常指 PyTorch 2.X 导出过程,而 torch.export.export() 通常指实际的函数调用。

The signature of torch.export.export() is:

export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram

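torch.export.export() traces the tensor computation graph from calling mod(*args, **kwargs) and wraps it in an ExportedProgram, which can be serialized or executed later with different inputs. To execute the ExportedProgram, we can call .module() on it to get back a callable torch.nn.Module, just like the original program. We will cover the dynamic_shapes argument later in the tutorial.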

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
<class 'torch.export.exported_program.ExportedProgram'>
tensor([[0.8632, 0.8407, 0.0407, 0.0000, 0.4132, 0.0000, 0.0000, 0.1538, 0.6111,
         0.0000],
        [0.0000, 0.0000, 0.0273, 0.8057, 0.0000, 1.0162, 0.8042, 0.0000, 0.2660,
         0.0000],
        [0.9481, 0.1396, 1.0225, 0.9563, 0.5832, 0.2546, 0.4095, 0.4591, 0.0000,
         2.0053],
        [1.1300, 0.4873, 0.0000, 0.9663, 1.2275, 1.4015, 0.0000, 0.9444, 0.0000,
         0.0000],
        [0.0000, 0.8724, 1.1648, 0.6867, 0.0000, 0.2833, 0.3202, 0.5848, 0.0000,
         0.0833],
        [1.1311, 0.1324, 0.0000, 1.7842, 0.0000, 0.3474, 0.9916, 0.3571, 0.0000,
         0.0000],
        [1.4348, 1.0570, 0.1771, 0.0000, 0.9510, 0.0000, 0.0000, 0.0000, 0.2618,
         0.0000],
        [0.8853, 0.0000, 0.0000, 0.4486, 0.0000, 0.0000, 0.5841, 0.7604, 0.0000,
         0.0000]], grad_fn=<ReluBackward0>)

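Let's review some attributes of ExportedProgram that are of interest.

The graph attribute is an FX graph traced from the function we exported, that is, the computation graph of all PyTorch operations. The FX graph is in "ATen IR", meaning that it contains only "ATen-level" operations.

The graph_signature attribute gives a more detailed description of the input and output nodes in the exported graph, describing which ones are parameters, buffers, user inputs, or user outputs.

The range_constraints attribute will be covered later.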

print(exported_mod)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_lin_weight: "f32[10, 100]", p_lin_bias: "f32[10]", x: "f32[8, 100]", y: "f32[8, 100]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
            add: "f32[8, 100]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            linear: "f32[8, 10]" = torch.ops.aten.linear.default(add, p_lin_weight, p_lin_bias);  add = p_lin_weight = p_lin_bias = None
            relu_: "f32[8, 10]" = torch.ops.aten.relu_.default(linear);  linear = None
            return (relu_,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_weight'), target='lin.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_bias'), target='lin.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu_'), target=None)])
Range constraints: {}

See the torch.export documentation for more details.

Graph Breaks

尽管 torch.exporttorch.compile 共享组件,但 torch.export 的主要限制(特别是与 torch.compile 相比)在于它不支持图中断。这是因为处理图中断涉及使用默认的 Python 评估来解释不受支持的操作,这与导出用例不兼容。因此,为了使您的模型代码与 torch.export 兼容,您需要修改代码以移除图中断。

A graph break is necessary in cases such as:

  • data-dependent control flow

class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_

         # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:116 in forward, code: if x.sum() > 0:
        sum_1: "f32[][]cpu" = l_x_.sum();  l_x_ = None
        gt: "b8[][]cpu" = sum_1 > 0;  sum_1 = gt = None

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 122, in <module>
    export(Bad1(), (torch.randn(3, 3),))
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 360, in export
    return _export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2112, in _export
    ep = _export_for_training(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1975, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
    raise e.with_traceback(None) from None
torch._dynamo.exc.Unsupported: Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()


from user code:
   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
    if x.sum() > 0:

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
  • accessing tensor data with .data

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • calling unsupported functions (such as many built-in functions)

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 148, in <module>
    export(Bad3(), (torch.randn(3, 3),))
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 360, in export
    return _export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2112, in _export
    ep = _export_for_training(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1975, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
    raise e.with_traceback(None) from None
torch._dynamo.exc.Unsupported: call_id not supported for sourceless TensorVariable

from user code:
   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 145, in forward
    return x + id(x)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Non-Strict Export

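To trace the program, torch.export by default uses TorchDynamo, a bytecode analysis engine, to symbolically analyze the Python code and build a graph from the results. This analysis allows torch.export to provide stronger safety guarantees, but not all Python code is supported, which causes the graph breaks shown above.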
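As a rough, torch-free illustration of what "bytecode analysis" means, the standard-library dis module can list the instructions of a function. The function below is a hypothetical plain-Python stand-in for Bad1.forward; its data-dependent if compiles to a conditional jump, which is presumably the kind of instruction the earlier message "attempted to jump with TensorVariable()" refers to. Opcode names differ between Python versions, so the printed list is illustrative only.

```python
import dis


def forward_like(xs):
    # Hypothetical stand-in for Bad1.forward: which branch runs depends on the
    # *values* in xs, not only on its length.
    if sum(xs) > 0:
        return "sin"
    return "cos"


def conditional_jumps(fn):
    """Return the opnames of the conditional-jump instructions in fn's bytecode."""
    return [
        ins.opname
        for ins in dis.get_instructions(fn)
        if "JUMP" in ins.opname and "IF" in ins.opname
    ]


print(conditional_jumps(forward_like))  # exact names vary by Python version
print(forward_like([1.0, -0.5]), forward_like([-1.0]))
```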

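To address this, PyTorch 2.3 introduced a new export mode called non-strict mode. In this mode, the program is traced using the Python interpreter, executing it exactly as it would in eager mode, which allows unsupported Python features to be skipped. Non-strict mode is enabled by passing the strict=False flag.

Revisiting some of the examples that previously caused graph breaks: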

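  • Calling unsupported functions (such as many built-in functions) now traces through, but in this case id(x) is specialized into the graph as a constant integer. This is because id(x) is not a tensor operation, so the operation is not recorded in the graph.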

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:179 in forward, code: x = x + 1
            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, 1);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:180 in forward, code: return x + id(x)
            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, 139813962166368);  add = None
            return (add_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {}

tensor([[1.3981e+14, 1.3981e+14, 1.3981e+14],
        [1.3981e+14, 1.3981e+14, 1.3981e+14],
        [1.3981e+14, 1.3981e+14, 1.3981e+14]])
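The specialization can be reproduced with a toy tracer. The sketch below is a hypothetical, heavily simplified model in plain Python (no torch) and is not how torch.export is implemented; like non-strict mode, though, it runs the function with the real Python interpreter, recording operations on a proxy value while evaluating everything else eagerly. The result of id(x) is therefore captured once as a constant and reused on every later run.

```python
import operator


class Node:
    """One recorded operation: a name, a callable (None for inputs), and its args."""

    def __init__(self, name, op, args):
        self.name, self.op, self.args = name, op, args


class Proxy:
    """Stands in for a tensor during tracing and records every `+` it sees."""

    def __init__(self, graph, node):
        self.graph, self.node = graph, node

    def __add__(self, other):
        # Proxies become edges in the graph; any other Python value (such as the
        # int returned by id()) was already computed eagerly and is stored as a constant.
        arg = other.node if isinstance(other, Proxy) else other
        node = Node(f"add_{len(self.graph)}", operator.add, (self.node, arg))
        self.graph.append(node)
        return Proxy(self.graph, node)


def trace(fn, input_name):
    graph = [Node(input_name, None, ())]
    out = fn(Proxy(graph, graph[0]))
    return graph, out.node


def run(graph, output, inputs):
    """Replay the recorded graph on concrete inputs."""
    env = dict(inputs)
    for node in graph:
        if node.op is None:  # input placeholder; the caller supplies its value
            continue
        vals = [env[a.name] if isinstance(a, Node) else a for a in node.args]
        env[node.name] = node.op(*vals)
    return env[output.name]


def bad3_forward(x):
    x = x + 1
    return x + id(x)  # id() runs once, in Python, while tracing


graph, out = trace(bad3_forward, "x")
baked = out.args[1]  # the integer captured from id(x) during tracing
print(run(graph, out, {"x": 0}) - baked, run(graph, out, {"x": 10}) - baked)  # 1 11
```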

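However, there are still some features that require rewrites to the original module:

Control Flow Ops

torch.export does in fact support data-dependent control flow, but it must be expressed using control flow ops. For example, we can fix the control flow example above using the cond op, like so: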

class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed)
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:205 in forward, code: return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
            sum_1: "f32[]" = torch.ops.aten.sum.default(x)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None

             # File: /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_higher_order_ops/cond.py:137 in cond, code: return cond_op(pred, true_fn, false_fn, operands)
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x]);  gt = true_graph_0 = false_graph_0 = x = None
            getitem: "f32[3, 3]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:202 in true_fn, code: return torch.sin(x)
                sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None
                return (sin,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:204 in false_fn, code: return torch.cos(x)
                cos: "f32[3, 3]" = torch.ops.aten.cos.default(x);  x = None
                return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

tensor([[0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415]])
tensor([[0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403]])

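There are some limitations of cond to be aware of:

  • The predicate (i.e., x.sum() > 0) must result in a boolean or a single-element tensor.

  • The operands (i.e., [x]) must be tensors.

  • The signatures of the branch functions (i.e., true_fn and false_fn) must match the operands, and both must return a single tensor with the same metadata (for example, dtype, shape, etc.).

  • Branch functions cannot mutate inputs or global variables.

  • Branch functions cannot access closure variables, except for self if the function is defined in the scope of a method.

For more details about cond, check out the cond documentation.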
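To make these rules concrete, here is a hypothetical plain-Python model of the cond contract that uses lists in place of tensors and a small Meta record in place of tensor metadata. Unlike the real operator, it simply evaluates both branches eagerly to compare their outputs, roughly mirroring the fact that export captures both branches as subgraphs (true_graph_0 and false_graph_0 above).

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Meta:
    """Hypothetical stand-in for tensor metadata: shape plus element type."""
    shape: tuple
    dtype: str


def meta_of(values):
    return Meta((len(values),), type(values[0]).__name__ if values else "none")


def cond_sketch(pred, true_fn, false_fn, operands):
    """Toy model of the cond contract using lists instead of tensors."""
    # The predicate must be a bool or a single-element container.
    if isinstance(pred, (list, tuple)):
        if len(pred) != 1:
            raise ValueError("predicate must have exactly one element")
        pred = pred[0]
    if not isinstance(pred, bool):
        raise TypeError("predicate must be a bool")
    # Both branches are evaluated (a tracer would capture both as subgraphs),
    # and their outputs must agree on metadata.
    t_out, f_out = true_fn(*operands), false_fn(*operands)
    if meta_of(t_out) != meta_of(f_out):
        raise ValueError(f"branch outputs differ: {meta_of(t_out)} vs {meta_of(f_out)}")
    return t_out if pred else f_out


def true_fn(x):
    return [math.sin(v) for v in x]


def false_fn(x):
    return [math.cos(v) for v in x]


x = [1.0, 1.0, 1.0]
print(cond_sketch(sum(x) > 0, true_fn, false_fn, [x]))
```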

We can also use map, which applies a function across the first dimension of the first tensor argument.

from torch._higher_order_ops.map import map as torch_map

class MapModule(torch.nn.Module):
    def forward(self, xs, y, z):
        def body(x, y, z):
            return x + y + z

        return torch_map(body, xs, y, z)

inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
exported_map_example = export(MapModule(), inps)
print(exported_map_example)
print(exported_map_example.module()(*inps))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[6, 4]", y: "i64[]", z: "i64[]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:236 in forward, code: return torch_map(body, xs, y, z)
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y, z]);  body_graph_0 = xs = y = z = None
            getitem: "f32[6, 4]" = map_impl[0];  map_impl = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[4]", y: "i64[]", z: "i64[]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:234 in body, code: return x + y + z
                add: "f32[4]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None
                add_1: "f32[4]" = torch.ops.aten.add.Tensor(add, z);  add = z = None
                return (add_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='xs'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

tensor([[10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.]])
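The semantics of map can be sketched without torch: the body is called once per slice along the first dimension of xs, while the remaining arguments are passed through unchanged. This plain-list version only illustrates what the exported map_impl computes for the inputs above; it says nothing about how the operator is actually executed.

```python
def map_sketch(body, xs, *args):
    """Call body on each slice of xs along its first dimension.

    xs is a list of rows standing in for a [6, 4] tensor; the remaining
    args are passed unchanged to every call, like y and z in MapModule.
    """
    return [body(row, *args) for row in xs]


def body(x, y, z):
    # Row-wise `x + y + z` where y and z are scalars.
    return [v + y + z for v in x]


xs = [[1.0] * 4 for _ in range(6)]
out = map_sketch(body, xs, 5, 4)
print(len(out), out[0])  # 6 [10.0, 10.0, 10.0, 10.0]
```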

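Other control flow ops include while_loop, associative_scan, and scan. For more documentation on each operator, please refer to this page.

Constraints/Dynamic Shapes

This section covers the dynamic behavior and representation of exported programs. Dynamic behavior depends on the particular model being exported, so for most of this tutorial we'll focus on this particular toy model (with the resulting tensor shapes annotated):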

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3

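By default, torch.export produces a static program. One consequence is that at runtime, the program won't work on inputs with different shapes, even if those shapes are valid in eager mode.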

w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)
model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))
try:
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 286, in <module>
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 830, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 406, in __call__
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 393, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
    return inner()
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_unlift.py", line 55, in _check_input_constraints_pre_hook
    _check_input_constraints_for_graph(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_export/utils.py", line 398, in _check_input_constraints_for_graph
    raise RuntimeError(
RuntimeError: Expected input at *args[2].shape[0] to be equal to 8, but got 3
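The failure above comes from an input check that runs before the graph. The following plain-Python sketch is a hypothetical helper (not the function in torch/_export/utils.py) that shows the idea for a fully static program: every dimension is compared with the value seen at export time, and the error message imitates the one above.

```python
def check_static_shapes(expected, actual):
    """Reject inputs whose shapes differ from those recorded at export time."""
    if len(expected) != len(actual):
        raise RuntimeError(f"Expected {len(expected)} inputs, but got {len(actual)}")
    for i, (exp, act) in enumerate(zip(expected, actual)):
        if len(exp) != len(act):
            raise RuntimeError(
                f"Expected input at *args[{i}] to have rank {len(exp)}, but got {len(act)}"
            )
        for d, (e, a) in enumerate(zip(exp, act)):
            if e != a:
                raise RuntimeError(
                    f"Expected input at *args[{i}].shape[{d}] to be equal to {e}, but got {a}"
                )


exported_with = [(6, 5), (4,), (8, 4), (32,)]  # shapes of w, x, y, z above
check_static_shapes(exported_with, [(6, 5), (4,), (8, 4), (32,)])  # passes
try:
    check_static_shapes(exported_with, [(6, 5), (4,), (3, 4), (12,)])
except RuntimeError as err:
    print(err)
```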

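Basic Concepts: Symbols and Guards

To enable dynamism, export() provides a dynamic_shapes argument. The easiest way to work with dynamic shapes is to use Dim.AUTO and look at the program that's returned. Dynamic behavior is specified at the level of input dimensions; for each input, we can specify a tuple of values: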

from torch.export.dynamic_shapes import Dim

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

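Before we look at the program that's produced, let's understand what specifying dynamic_shapes entails and how it interacts with export. For every input dimension where a Dim object is specified, a symbol is allocated, taking on a range of [2, inf] (why not [0, inf] or [1, inf]? We'll explain later in the 0/1 specialization section).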
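A minimal sketch of the allocation step, assuming a hypothetical AUTO marker in place of Dim.AUTO and plain shape tuples in place of tensors: each marked dimension gets a fresh symbol, with the example size recorded as its hint and [2, inf] as its range. The printed lines mirror the create_symbol log lines shown later in this tutorial; the sketch does not model the 0/1 specialization.

```python
import math
from dataclasses import dataclass

AUTO = "AUTO"  # hypothetical marker playing the role of Dim.AUTO


@dataclass
class Symbol:
    name: str
    hint: int            # example size observed when tracing started
    lo: int = 2          # allocated symbols start with the range [2, inf]
    hi: float = math.inf


def allocate_symbols(example_shapes, dynamic_shapes):
    """Give every dimension marked AUTO a fresh symbol s0, s1, ... in input order."""
    symbols = {}
    for arg, shape in example_shapes.items():
        marks = dynamic_shapes.get(arg, (None,) * len(shape))
        for dim, (size, mark) in enumerate(zip(shape, marks)):
            if mark == AUTO:
                symbols[(arg, dim)] = Symbol(f"s{len(symbols)}", size)
    return symbols


shapes = {"w": (6, 5), "x": (4,), "y": (8, 4), "z": (32,)}
spec = {"w": (AUTO, AUTO), "x": (AUTO,), "y": (AUTO, AUTO), "z": (AUTO,)}
for (arg, dim), sym in allocate_symbols(shapes, spec).items():
    print(f"{sym.name} = {sym.hint} for {arg}.size()[{dim}] [{sym.lo}, {sym.hi}]")
```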

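Export then runs model tracing, looking at each operation performed by the model. Each individual operation can emit what are called "guards": boolean conditions that must hold for the program to be valid. When guards involve the symbols allocated for input dimensions, the program contains restrictions on which input shapes are valid; i.e., the program's dynamic behavior. The symbolic shapes subsystem is responsible for taking in all the emitted guards and producing a final program representation that adheres to all of them. Before we see this "final representation" in an ExportedProgram, let's look at the guards emitted by the toy model we're tracing.

Here, each forward input tensor is annotated with the symbol allocated at the start of tracing: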

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()  # no guard added here
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3

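Let's go through each operation and the guards it emits:

  • x0 = x + y: This is an element-wise add with broadcasting, since x is a 1-D tensor and y is a 2-D tensor. x is broadcast along the last dimension of y, emitting the guard s2 == s4.

  • x1 = self.l(w): Calling nn.Linear() performs a matrix multiplication with the model parameters. In export, parameters, buffers, and constants are considered program state and are treated as static, so this is a matmul between a dynamic input (w: [s0, s1]) and a statically shaped tensor. This emits the guard s1 == 5.

  • x2 = x0.flatten(): This call actually doesn't emit any guards! (At least none relevant to input shapes.)

  • x3 = x2 + z: After flattening, x2 has shape [s3*s4], and this element-wise add emits s3 * s4 == s5.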

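Writing all of these guards down and summarizing them is almost like a mathematical proof, which is what the symbolic shapes subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid: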

  • w: [s0, 5]

  • x: [s2]

  • y: [s3, s2]

  • z: [s2*s3]
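The summarized shapes translate directly into a validity check. The plain-Python function below is a hypothetical illustration (not the check torch generates) that accepts the input shapes described by the list above, including the [2, inf] range of each free symbol, and returns the solved symbol values.

```python
def check_inputs(w, x, y, z):
    """Validate concrete shapes against w: [s0, 5], x: [s2], y: [s3, s2], z: [s2*s3]."""
    s0, w1 = w
    (s2,) = x
    s3, y1 = y
    (z0,) = z
    if w1 != 5:
        raise ValueError(f"w.shape[1] must be 5, got {w1}")
    if y1 != s2:
        raise ValueError(f"y.shape[1] must equal x.shape[0] ({s2}), got {y1}")
    if z0 != s2 * s3:
        raise ValueError(f"z.shape[0] must equal s2*s3 = {s2 * s3}, got {z0}")
    for name, value in (("s0", s0), ("s2", s2), ("s3", s3)):
        if value < 2:
            raise ValueError(f"{name} = {value} is outside [2, inf]")
    return {"s0": s0, "s2": s2, "s3": s3}


print(check_inputs((6, 5), (4,), (8, 4), (32,)))  # the export-time example
print(check_inputs((6, 5), (4,), (3, 4), (12,)))  # rejected by the static program earlier
```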

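When we finally print the exported program to see the result, those shapes are exactly what we see annotated on the corresponding inputs: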

print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_l_weight: "f32[3, 5]", p_l_bias: "f32[3]", w: "f32[s0, 5]", x: "f32[s2]", y: "f32[s3, s2]", z: "f32[s2*s3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward, code: x0 = x + y  # [8, 4]
            add: "f32[s3, s2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward, code: x1 = self.l(w)  # [6, 3]
            linear: "f32[s0, 3]" = torch.ops.aten.linear.default(w, p_l_weight, p_l_bias);  w = p_l_weight = p_l_bias = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:270 in forward, code: x2 = x0.flatten()  # [32]
            flatten: "f32[s2*s3]" = torch.ops.aten.flatten.using_ints(add);  add = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward, code: x3 = x2 + z  # [32]
            add_1: "f32[s2*s3]" = torch.ops.aten.add.Tensor(flatten, z);  flatten = z = None
            return (linear, add_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_weight'), target='l.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_bias'), target='l.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='w'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {s0: VR[2, int_oo], s2: VR[2, int_oo], s3: VR[2, int_oo], s2*s3: VR[4, int_oo]}

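Another feature to notice is the range_constraints field above, which contains the valid range for each symbol. This isn't very interesting yet, since this export call doesn't emit any guards related to symbol bounds and each base symbol has a generic bound, but it will come up later.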
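The derived bound s2*s3: VR[4, int_oo] follows from simple interval arithmetic on the base ranges, and a guard such as s1 == 5 narrows a range to a single value (compare the "_update_var_to_range s1 = VR[5, 5]" line in the verbose log below). The plain-Python sketch below is a simplified illustration of that reasoning, using math.inf for int_oo; PyTorch's own value-range analysis handles far more cases.

```python
import math


def mul_ranges(a, b):
    """Multiply closed ranges [lo, hi] of non-negative integers (hi may be inf)."""
    if a[0] < 0 or b[0] < 0:
        raise ValueError("this sketch only handles non-negative ranges")
    lo = a[0] * b[0]
    # Avoid 0 * inf (which is nan): a range capped at 0 keeps the product at 0.
    hi = 0 if 0 in (a[1], b[1]) else a[1] * b[1]
    return (lo, hi)


def intersect(a, b):
    """Refine a range with another one, e.g. after a guard pins a symbol."""
    lo, hi = max(a[0], b[0]), min(a[1], b[1])
    if lo > hi:
        raise ValueError(f"empty range [{lo}, {hi}]")
    return (lo, hi)


base = (2, math.inf)             # every allocated symbol starts here
print(mul_ranges(base, base))    # s2*s3 -> (4, inf)
print(intersect(base, (5, 5)))   # s1 after the guard s1 == 5 -> (5, 5)
```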

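So far, because we've been exporting this toy model, the experience hasn't been representative of how hard it typically is to debug dynamic shapes guards and issues. In most cases, it isn't obvious which guards are being emitted, or which operations and parts of user code are responsible. For this toy model we can pinpoint the exact lines, and the guards are rather intuitive.

In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment variable TORCH_LOGS="+dynamic", or interactively with torch._logging.set_logs(dynamic=10):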

torch._logging.set_logs(dynamic=10)
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
I0423 16:38:38.478000 660 torch/fx/experimental/symbolic_shapes.py:3334] [12/0] create_env
I0423 16:38:38.480000 660 torch/fx/experimental/symbolic_shapes.py:4606] [12/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.480000 660 torch/fx/experimental/symbolic_shapes.py:4606] [12/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0423 16:38:38.481000 660 torch/fx/experimental/symbolic_shapes.py:7018] [12/0] runtime_assert True == True [statically known]
I0423 16:38:38.483000 660 torch/fx/experimental/symbolic_shapes.py:4606] [12/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.485000 660 torch/fx/experimental/symbolic_shapes.py:4606] [12/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.486000 660 torch/fx/experimental/symbolic_shapes.py:4606] [12/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.488000 660 torch/fx/experimental/symbolic_shapes.py:4606] [12/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0423 16:38:38.490000 660 torch/fx/experimental/symbolic_shapes.py:6787] [12/0] eval size_oblivious(Eq(s2, 1)) == False [statically known]
V0423 16:38:38.491000 660 torch/fx/experimental/symbolic_shapes.py:7018] [12/0] runtime_assert True == True [statically known]
V0423 16:38:38.492000 660 torch/fx/experimental/symbolic_shapes.py:6787] [12/0] eval size_oblivious(Eq(s4, 1)) == False [statically known]
I0423 16:38:38.492000 660 torch/fx/experimental/symbolic_shapes.py:6630] [12/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # [8, 4]  # workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
I0423 16:38:38.493000 660 torch/fx/experimental/symbolic_shapes.py:6234] [12/0] set_replacement s4 = s2 (solve) VR[2, int_oo]
V0423 16:38:38.495000 660 torch/fx/experimental/symbolic_shapes.py:6787] [12/0] eval size_oblivious(Ne(s2, 1)) == True [statically known]
V0423 16:38:38.496000 660 torch/fx/experimental/symbolic_shapes.py:6787] [12/0] eval size_oblivious(Ne(s3, 1)) == True [statically known]
I0423 16:38:38.501000 660 torch/fx/experimental/symbolic_shapes.py:6630] [12/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2236 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0423 16:38:38.502000 660 torch/fx/experimental/symbolic_shapes.py:6071] [12/0] _update_var_to_range s1 = VR[5, 5] (update)
I0423 16:38:38.503000 660 torch/fx/experimental/symbolic_shapes.py:6234] [12/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
V0423 16:38:38.505000 660 torch/fx/experimental/symbolic_shapes.py:6787] [12/0] eval size_oblivious(Eq(s0, 1)) == False [statically known]
V0423 16:38:38.510000 660 torch/fx/experimental/symbolic_shapes.py:6787] [12/0] eval size_oblivious(Eq(s2*s3, 1)) == False [statically known]
V0423 16:38:38.511000 660 torch/fx/experimental/symbolic_shapes.py:6787] [12/0] eval size_oblivious(Eq(s5, 1)) == False [statically known]
I0423 16:38:38.511000 660 torch/fx/experimental/symbolic_shapes.py:6630] [12/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0423 16:38:38.512000 660 torch/fx/experimental/symbolic_shapes.py:6071] [12/0] _update_var_to_range s5 = VR[4, int_oo] (update)
I0423 16:38:38.513000 660 torch/fx/experimental/symbolic_shapes.py:6234] [12/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo]
V0423 16:38:38.515000 660 torch/fx/experimental/symbolic_shapes.py:6787] [12/0] eval size_oblivious(Ne(s2*s3, 1)) == True [statically known]
I0423 16:38:38.520000 660 torch/fx/experimental/symbolic_shapes.py:4734] [12/0] produce_guards
V0423 16:38:38.521000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['w'].size()[0] s0 None
V0423 16:38:38.521000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['w'].size()[1] 5 None
V0423 16:38:38.522000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['w'].stride()[0] 5 None
V0423 16:38:38.522000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['w'].stride()[1] 1 None
V0423 16:38:38.522000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['w'].storage_offset() 0 None
V0423 16:38:38.522000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['x'].size()[0] s2 None
V0423 16:38:38.523000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['x'].stride()[0] 1 None
V0423 16:38:38.523000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['x'].storage_offset() 0 None
V0423 16:38:38.523000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['y'].size()[0] s3 None
V0423 16:38:38.524000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['y'].size()[1] s2 None
V0423 16:38:38.524000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['y'].stride()[0] s2 None
V0423 16:38:38.524000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['y'].stride()[1] 1 None
V0423 16:38:38.525000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['y'].storage_offset() 0 None
V0423 16:38:38.525000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['z'].size()[0] s2*s3 None
V0423 16:38:38.525000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['z'].stride()[0] 1 None
V0423 16:38:38.526000 660 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['z'].storage_offset() 0 None
V0423 16:38:38.558000 660 torch/fx/experimental/symbolic_shapes.py:6787] eval size_oblivious(Ne(s0, 1)) == True [statically known]
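When this output grows, a few regular expressions can help pull out just the symbol allocations and added guards. The sketch below is a hypothetical helper, not a PyTorch tool: its patterns were written against the log lines shown in this tutorial, and the log format is internal to PyTorch and may change between versions.

```python
import re

# Written against the log lines shown in this tutorial; PyTorch's log format is
# internal and may change between versions.
CREATE_SYMBOL = re.compile(
    r"create_symbol (?P<sym>s\d+) = (?P<hint>\d+) for (?P<source>\S+) "
    r"\[(?P<lo>\d+), (?P<hi>\w+)\]"
)
GUARD_ADDED = re.compile(r"runtime_assert (?P<expr>.+?) \[guard added\] (?P<code>.+?)  #")


def summarize_log(lines):
    """Collect allocated symbols and added guards from `dynamic` log lines."""
    symbols, guards = {}, []
    for line in lines:
        m = CREATE_SYMBOL.search(line)
        if m:
            symbols[m["sym"]] = (m["source"], int(m["hint"]), (int(m["lo"]), m["hi"]))
            continue
        m = GUARD_ADDED.search(line)
        if m:  # lines marked [statically known] do not match and are skipped
            guards.append((m["expr"], m["code"]))
    return symbols, guards


SAMPLE = [
    "I0423 16:38:38.480000 660 torch/fx/experimental/symbolic_shapes.py:4606] [12/0] "
    "create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>)",
    "V0423 16:38:38.481000 660 torch/fx/experimental/symbolic_shapes.py:7018] [12/0] "
    "runtime_assert True == True [statically known]",
    "I0423 16:38:38.492000 660 torch/fx/experimental/symbolic_shapes.py:6630] [12/0] "
    "runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # [8, 4]  # "
    "workspace/intermediate_source/torch_export_tutorial.py:268 in forward",
]
print(summarize_log(SAMPLE))
```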

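Even for this simple toy model, this produces quite a lot of output. The log lines here have been truncated at the front and back to omit unnecessary information, but looking through the logs we can find the lines relevant to what we described above; for example, the allocation of symbols: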

"""
create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert True == True [statically known]
create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
"""

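When the output is this long, it can help to pull the create_symbol lines out programmatically. The sketch below is a hypothetical helper (parse_create_symbol is not part of PyTorch); its regular expression is fitted to the excerpt shown above and assumes the log format stays the same, which is not guaranteed across PyTorch versions.

```python
import re
from typing import Dict, Tuple

# Regex fitted to the excerpted log lines above (an assumption, not a stable format), e.g.
#   create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (...)
CREATE_SYMBOL_RE = re.compile(
    r"create_symbol (?P<sym>s\d+) = (?P<hint>\d+) for (?P<source>L\['\w+'\]\.size\(\)\[\d+\])"
)


def parse_create_symbol(log_text: str) -> Dict[str, Tuple[int, str]]:
    """Map each allocated symbol to (example value, tensor dimension it stands for)."""
    symbols = {}
    for match in CREATE_SYMBOL_RE.finditer(log_text):
        symbols[match["sym"]] = (int(match["hint"]), match["source"])
    return symbols


LOG_EXCERPT = """
create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert True == True [statically known]
create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
"""

# Lines that are not create_symbol records (e.g. runtime_assert) are simply skipped.
for sym, (hint, source) in parse_create_symbol(LOG_EXCERPT).items():
    print(f"{sym} -> {source} (example value {hint})")
```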
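Lines containing create_symbol show when a new symbol has been allocated, and the log also identifies the tensor variable name and dimension it was allocated for. In other lines we can also see the guards that were emitted: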

"""
runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # output shape: [8, 4]  # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""

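Each guard is a predicate over the symbols. As a toy illustration (hand-written, not how torch.export evaluates guards internally), the three guards above can be checked against concrete input shapes, binding the symbols in the order the create_symbol lines allocated them. The names bind_symbols, GUARDS and violated_guards are made up for this sketch.

```python
from typing import Dict, List, Tuple


def bind_symbols(w: Tuple[int, int], x: Tuple[int], y: Tuple[int, int], z: Tuple[int]) -> Dict[str, int]:
    """Bind s0..s5 to concrete sizes, in the order the log above allocated them."""
    return {"s0": w[0], "s1": w[1], "s2": x[0], "s3": y[0], "s4": y[1], "s5": z[0]}


# The three guards from the log, rewritten by hand as Python predicates.
GUARDS = {
    "Eq(s2, s4)": lambda s: s["s2"] == s["s4"],              # from x0 = x + y
    "Eq(s1, 5)": lambda s: s["s1"] == 5,                      # from x1 = self.l(w)
    "Eq(s2*s3, s5)": lambda s: s["s2"] * s["s3"] == s["s5"],  # from x3 = x2 + z
}


def violated_guards(bindings: Dict[str, int]) -> List[str]:
    """Names of the guards that a given set of input shapes would violate."""
    return [name for name, holds in GUARDS.items() if not holds(bindings)]


# The example inputs: w [6, 5], x [4], y [8, 4], z [32]
print(violated_guards(bind_symbols((6, 5), (4,), (8, 4), (32,))))
# Different sizes that still respect every guard
print(violated_guards(bind_symbols((10, 5), (7,), (3, 7), (21,))))
# w.shape[1] != 5 breaks the guard that came from self.l(w)
print(violated_guards(bind_symbols((6, 4), (4,), (8, 4), (32,))))
```

The second call shows why these guards still leave the program dynamic: any sizes that satisfy the relations are accepted, not just the example values.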
[guard added] 消息旁边,我们还可以看到负责的用户代码行 - 幸运的是,这里的模型足够简单。在许多实际案例中,情况并非如此直接:高级 torch 操作可能有复杂的 fake-kernel 实现或操作符分解,这使得 guards 的发出位置和内容变得复杂。在这种情况下,深入挖掘和调查的最佳方法是遵循日志的建议,并使用环境变量 TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..." 重新运行,以进一步追溯相关的 guard。

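Dim.AUTO is just one of the options available for dynamic_shapes; at the time of writing there are two others: Dim.DYNAMIC and Dim.STATIC. Dim.STATIC simply marks a dimension as static, while Dim.DYNAMIC behaves like Dim.AUTO in every way except one: it raises an error if the dimension is specialized to a constant, which is designed to preserve dynamism. For example, here is what happens when a static guard is emitted on a dimension marked as dynamic: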

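# self.l(w) requires w.shape[1] == 5 (guard Eq(s1, 5) above), so marking
# that dimension Dim.DYNAMIC is expected to fail
dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()  # tb: the traceback module imported earlier in the tutorial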
I0423 16:38:38.579000 660 torch/fx/experimental/symbolic_shapes.py:3334] [13/0] create_env
I0423 16:38:38.581000 660 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.581000 660 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0423 16:38:38.582000 660 torch/fx/experimental/symbolic_shapes.py:7018] [13/0] runtime_assert True == True [statically known]
I0423 16:38:38.584000 660 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.586000 660 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.587000 660 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.589000 660 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0423 16:38:38.591000 660 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Eq(s2, 1)) == False [statically known]
V0423 16:38:38.592000 660 torch/fx/experimental/symbolic_shapes.py:7018] [13/0] runtime_assert True == True [statically known]
V0423 16:38:38.592000 660 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Eq(s4, 1)) == False [statically known]
I0423 16:38:38.593000 660 torch/fx/experimental/symbolic_shapes.py:6630] [13/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # [8, 4]  # workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
I0423 16:38:38.594000 660 torch/fx/experimental/symbolic_shapes.py:6234] [13/0] set_replacement s4 = s2 (solve) VR[2, int_oo]
V0423 16:38:38.596000 660 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Ne(s2, 1)) == True [statically known]
V0423 16:38:38.596000 660 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Ne(s3, 1)) == True [statically known]
I0423 16:38:38.602000 660 torch/fx/experimental/symbolic_shapes.py:6630] [13/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2236 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0423 16:38:38.603000 660 torch/fx/experimental/symbolic_shapes.py:6071] [13/0] _update_var_to_range s1 = VR[5, 5] (update)
I0423 16:38:38.604000 660 torch/fx/experimental/symbolic_shapes.py:6234] [13/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
V0423 16:38:38.605000 660 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Eq(s0, 1)) == False [statically known]
V0423 16:38:38.610000 660 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Eq(s2*s3, 1)) == False [statically known]
V0423 16:38:38.611000 660 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Eq(s5, 1)) == False [statically known]
I0423 16:38:38.612000 660 torch/fx/experimental/symbolic_shapes.py:6630] [13/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0423 16:38:38.613000 660 torch/fx/experimental/symbolic_shapes.py:6071] [13/0] _update_var_to_range s5 = VR[4, int_oo] (update)
I0423 16:38:38.614000 660 torch/fx/experimental/symbolic_shapes.py:6234] [13/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo]
V0423 16:38:38.615000 660 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Ne(s2*s3, 1)) == True [statically known]
I0423 16:38:38.621000 660 torch/fx/experimental/symbolic_shapes.py:4734] [13/0] produce_guards
V0423 16:38:38.621000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['w'].size()[0] s0 None
V0423 16:38:38.622000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['w'].size()[1] 5 RelaxedUnspecConstraint(warn_only=False)
V0423 16:38:38.622000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['w'].stride()[0] 5 None
V0423 16:38:38.623000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['w'].stride()[1] 1 None
V0423 16:38:38.623000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['w'].storage_offset() 0 None
V0423 16:38:38.623000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['x'].size()[0] s2 None
V0423 16:38:38.624000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['x'].stride()[0] 1 None
V0423 16:38:38.624000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['x'].storage_offset() 0 None
V0423 16:38:38.624000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].size()[0] s3 None
V0423 16:38:38.625000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].size()[1] s2 None
V0423 16:38:38.625000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].stride()[0] s2 None
V0423 16:38:38.625000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].stride()[1] 1 None
V0423 16:38:38.626000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].storage_offset() 0 None
V0423 16:38:38.626000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['z'].size()[0] s2*s3 None
V0423 16:38:38.626000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['z'].stride()[0] 1 None
V0423 16:38:38.627000 660 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['z'].storage_offset() 0 None
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0] Error while creating guard:
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0] Name: ''
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     Source: shape_env
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     Create Function: SHAPE_ENV
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     Guard Types: None
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     Code List: None
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     Object Weakref: None
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     Guarded Class Weakref: None
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0] Traceback (most recent call last):
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_guards.py", line 357, in create
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     return self.create_fn(builder, self)
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     python_code_parts, verbose_code_parts = _get_code_parts(
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     return output_graph.shape_env.produce_guards_verbose(
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]     raise ConstraintViolationError(
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
E0423 16:38:38.628000 660 torch/_guards.py:359] [13/0]   - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).
E0423 16:38:38.631000 660 torch/_guards.py:361] [13/0] Created at:
E0423 16:38:38.631000 660 torch/_guards.py:361] [13/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 694, in transform
E0423 16:38:38.631000 660 torch/_guards.py:361] [13/0]     tracer = InstructionTranslator(
E0423 16:38:38.631000 660 torch/_guards.py:361] [13/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3329, in __init__
E0423 16:38:38.631000 660 torch/_guards.py:361] [13/0]     output=OutputGraph(
E0423 16:38:38.631000 660 torch/_guards.py:361] [13/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 358, in __init__
E0423 16:38:38.631000 660 torch/_guards.py:361] [13/0]     self.init_ambient_guards()
E0423 16:38:38.631000 660 torch/_guards.py:361] [13/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 512, in init_ambient_guards
E0423 16:38:38.631000 660 torch/_guards.py:361] [13/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1722, in inner
    raise constraint_violation_error
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
    return self._torchdynamo_orig_callable(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
    return _compile(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 906, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 2481, in __init__
    guard.create(builder)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_guards.py", line 357, in create
    return self.create_fn(builder, self)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV
    python_code_parts, verbose_code_parts = _get_code_parts(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts
    return output_graph.shape_env.produce_guards_verbose(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 418, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 360, in export
    return _export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2112, in _export
    ep = _export_for_training(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1975, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 756, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).
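The difference between the three markers can be summarised with a toy model. The Marker enum, resolve_dim and SpecializationError below are illustrative stand-ins, not the torch.export API; the sketch only encodes the rules described above: STATIC is static, AUTO accepts a specialization silently, and DYNAMIC turns a specialization into an error.

```python
from enum import Enum
from typing import Optional


class Marker(Enum):
    """Toy stand-ins for Dim.AUTO / Dim.DYNAMIC / Dim.STATIC (not the torch.export classes)."""
    AUTO = "auto"
    DYNAMIC = "dynamic"
    STATIC = "static"


class SpecializationError(Exception):
    """Plays the role of the ConstraintViolationError in the log above."""


def resolve_dim(name: str, marker: Marker, inferred_constant: Optional[int]) -> str:
    """Describe how a marked dimension ends up after tracing.

    inferred_constant is None if tracing kept the dimension symbolic, or the
    constant that a guard such as Eq(s1, 5) pinned it to.
    """
    if marker is Marker.STATIC:
        return "static"
    if inferred_constant is None:
        return "dynamic"
    if marker is Marker.AUTO:
        # AUTO quietly accepts the specialization.
        return f"specialized to {inferred_constant}"
    # DYNAMIC promised a dynamic dimension, so specializing it is an error.
    raise SpecializationError(
        f"{name} was marked DYNAMIC but was inferred to be a constant ({inferred_constant})"
    )


print(resolve_dim("w.shape[0]", Marker.AUTO, None))
print(resolve_dim("w.shape[1]", Marker.AUTO, 5))
try:
    resolve_dim("w.shape[1]", Marker.DYNAMIC, 5)
except SpecializationError as err:
    print(err)
```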

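Static guards are not always inherent to the model; they can also come from user specifications. In fact, a common pitfall that leads to shape specialization is specifying conflicting markers for equivalent dimensions, one dynamic and the other static. The same type of error is raised when this is the case for x.shape[0] and y.shape[1].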
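To make "equivalent dimensions" concrete: equality guards such as Eq(s2, s4) merge dimensions into equivalence classes, and the pitfall is a class that carries both a static and a dynamic marker. The sketch below models this with a small union-find; DimUnionFind and find_conflicts are hypothetical names, and this is an illustration of the idea, not how torch.export implements the check.

```python
from typing import Dict, List, Tuple


class DimUnionFind:
    """Groups dimensions that equality guards (e.g. Eq(s2, s4)) force to be equal."""

    def __init__(self) -> None:
        self.parent: Dict[str, str] = {}

    def find(self, dim: str) -> str:
        self.parent.setdefault(dim, dim)
        while self.parent[dim] != dim:
            self.parent[dim] = self.parent[self.parent[dim]]  # path halving
            dim = self.parent[dim]
        return dim

    def union(self, a: str, b: str) -> None:
        self.parent[self.find(a)] = self.find(b)


def find_conflicts(equalities: List[Tuple[str, str]], markers: Dict[str, str]) -> List[List[str]]:
    """Return the equivalence classes that mix a "static" and a "dynamic" marker."""
    uf = DimUnionFind()
    for a, b in equalities:
        uf.union(a, b)
    classes: Dict[str, List[str]] = {}
    for dim in markers:
        classes.setdefault(uf.find(dim), []).append(dim)
    return [
        sorted(members)
        for members in classes.values()
        if {"static", "dynamic"} <= {markers[d] for d in members}
    ]


# x + y forces x.shape[0] == y.shape[1]; the markers mirror the dynamic_shapes used next
equalities = [("x.shape[0]", "y.shape[1]")]
markers = {"x.shape[0]": "static", "y.shape[0]": "auto", "y.shape[1]": "dynamic"}
print(find_conflicts(equalities, markers))
```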

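# x + y forces x.shape[0] == y.shape[1], so marking one of them
# Dim.STATIC and the other Dim.DYNAMIC is expected to fail
dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)  # reset w from the previous example
dynamic_shapes["x"] = (Dim.STATIC,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()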
I0423 16:38:38.649000 660 torch/fx/experimental/symbolic_shapes.py:3334] [14/0] create_env
I0423 16:38:38.651000 660 torch/fx/experimental/symbolic_shapes.py:4606] [14/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.652000 660 torch/fx/experimental/symbolic_shapes.py:4606] [14/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0423 16:38:38.653000 660 torch/fx/experimental/symbolic_shapes.py:7018] [14/0] runtime_assert True == True [statically known]
I0423 16:38:38.656000 660 torch/fx/experimental/symbolic_shapes.py:4606] [14/0] create_symbol s2 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.656000 660 torch/fx/experimental/symbolic_shapes.py:4606] [14/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.659000 660 torch/fx/experimental/symbolic_shapes.py:4606] [14/0] create_symbol s4 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0423 16:38:38.662000 660 torch/fx/experimental/symbolic_shapes.py:6787] [14/0] eval size_oblivious(Eq(s3, 1)) == False [statically known]
I0423 16:38:38.665000 660 torch/fx/experimental/symbolic_shapes.py:6630] [14/0] runtime_assert Eq(s3, 4) [guard added] x0 = x + y  # [8, 4]  # workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s3, 4)"
V0423 16:38:38.666000 660 torch/fx/experimental/symbolic_shapes.py:6071] [14/0] _update_var_to_range s3 = VR[4, 4] (update)
I0423 16:38:38.667000 660 torch/fx/experimental/symbolic_shapes.py:6234] [14/0] set_replacement s3 = 4 (range_refined_to_singleton) VR[4, 4]
V0423 16:38:38.669000 660 torch/fx/experimental/symbolic_shapes.py:6787] [14/0] eval size_oblivious(Ne(s2, 1)) == True [statically known]
I0423 16:38:38.674000 660 torch/fx/experimental/symbolic_shapes.py:6630] [14/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2236 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0423 16:38:38.675000 660 torch/fx/experimental/symbolic_shapes.py:6071] [14/0] _update_var_to_range s1 = VR[5, 5] (update)
I0423 16:38:38.676000 660 torch/fx/experimental/symbolic_shapes.py:6234] [14/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
V0423 16:38:38.677000 660 torch/fx/experimental/symbolic_shapes.py:6787] [14/0] eval size_oblivious(Eq(s0, 1)) == False [statically known]
V0423 16:38:38.678000 660 torch/fx/experimental/symbolic_shapes.py:7018] [14/0] runtime_assert True == True [statically known]
V0423 16:38:38.684000 660 torch/fx/experimental/symbolic_shapes.py:6787] [14/0] eval size_oblivious(Eq(s4, 1)) == False [statically known]
I0423 16:38:38.688000 660 torch/fx/experimental/symbolic_shapes.py:6630] [14/0] runtime_assert Eq(4*s2, s4) [guard added] x3 = x2 + z  # [32]  # workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(4*s2, s4)"
V0423 16:38:38.691000 660 torch/fx/experimental/symbolic_shapes.py:6071] [14/0] _update_var_to_range s4 = VR[8, int_oo] (update)
I0423 16:38:38.691000 660 torch/fx/experimental/symbolic_shapes.py:6234] [14/0] set_replacement s4 = 4*s2 (solve) VR[8, int_oo]
I0423 16:38:38.698000 660 torch/fx/experimental/symbolic_shapes.py:4734] [14/0] produce_guards
V0423 16:38:38.698000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['w'].size()[0] s0 None
V0423 16:38:38.699000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['w'].size()[1] 5 None
V0423 16:38:38.699000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['w'].stride()[0] 5 None
V0423 16:38:38.699000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['w'].stride()[1] 1 None
V0423 16:38:38.700000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['w'].storage_offset() 0 None
V0423 16:38:38.700000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['x'].size()[0] 4 None
V0423 16:38:38.700000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['x'].stride()[0] 1 None
V0423 16:38:38.700000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['x'].storage_offset() 0 None
V0423 16:38:38.701000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['y'].size()[0] s2 None
V0423 16:38:38.701000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['y'].size()[1] 4 RelaxedUnspecConstraint(warn_only=False)
V0423 16:38:38.701000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['y'].stride()[0] 4 None
V0423 16:38:38.702000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['y'].stride()[1] 1 None
V0423 16:38:38.702000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['y'].storage_offset() 0 None
V0423 16:38:38.702000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['z'].size()[0] 4*s2 None
V0423 16:38:38.702000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['z'].stride()[0] 1 None
V0423 16:38:38.703000 660 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['z'].storage_offset() 0 None
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0] Error while creating guard:
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0] Name: ''
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     Source: shape_env
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     Create Function: SHAPE_ENV
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     Guard Types: None
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     Code List: None
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     Object Weakref: None
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     Guarded Class Weakref: None
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0] Traceback (most recent call last):
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_guards.py", line 357, in create
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     return self.create_fn(builder, self)
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     python_code_parts, verbose_code_parts = _get_code_parts(
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     return output_graph.shape_env.produce_guards_verbose(
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]     raise ConstraintViolationError(
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
E0423 16:38:38.704000 660 torch/_guards.py:359] [14/0]   - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).
E0423 16:38:38.706000 660 torch/_guards.py:361] [14/0] Created at:
E0423 16:38:38.706000 660 torch/_guards.py:361] [14/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 694, in transform
E0423 16:38:38.706000 660 torch/_guards.py:361] [14/0]     tracer = InstructionTranslator(
E0423 16:38:38.706000 660 torch/_guards.py:361] [14/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3329, in __init__
E0423 16:38:38.706000 660 torch/_guards.py:361] [14/0]     output=OutputGraph(
E0423 16:38:38.706000 660 torch/_guards.py:361] [14/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 358, in __init__
E0423 16:38:38.706000 660 torch/_guards.py:361] [14/0]     self.init_ambient_guards()
E0423 16:38:38.706000 660 torch/_guards.py:361] [14/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 512, in init_ambient_guards
E0423 16:38:38.706000 660 torch/_guards.py:361] [14/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1722, in inner
    raise constraint_violation_error
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
    return self._torchdynamo_orig_callable(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
    return _compile(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 906, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 2481, in __init__
    guard.create(builder)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_guards.py", line 357, in create
    return self.create_fn(builder, self)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV
    python_code_parts, verbose_code_parts = _get_code_parts(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts
    return output_graph.shape_env.produce_guards_verbose(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 431, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 360, in export
    return _export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2112, in _export
    ep = _export_for_training(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1975, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 756, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).

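You might ask why export "specializes", that is, why we resolve this static/dynamic conflict by taking the static route. The answer lies in the symbolic shapes system described above, with its symbols and guards. When x.shape[0] is marked static, no symbol is allocated for it, and compilation treats that shape as the concrete integer 4. A symbol is allocated for y.shape[1], however, so we end up emitting the guard s3 == 4, which leads to specialization.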
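To make the specialization argument concrete, here is a toy model in plain Python. It is a sketch under stated assumptions and does not reflect how PyTorch's shape environment is implemented. A dimension is either a concrete int (static) or a symbol name (dynamic), and the hypothetical helper `apply_equality_guard` records a guard `a == b`. When one side is already a concrete int, the symbol on the other side is replaced by that int, which is the specialization described above.

```python
# Toy model (hypothetical, not PyTorch internals): a dimension is either a
# concrete int (static) or a symbol name string (dynamic).
def apply_equality_guard(dims, a, b):
    """Record the guard dims[a] == dims[b] and return the updated mapping.

    If exactly one side is a concrete int, the symbolic side is replaced
    by that int, i.e. it gets specialized.
    """
    va, vb = dims[a], dims[b]
    updated = dict(dims)
    if isinstance(va, int) and not isinstance(vb, int):
        updated[b] = va  # symbol on the right is specialized
    elif isinstance(vb, int) and not isinstance(va, int):
        updated[a] = vb  # symbol on the left is specialized
    elif not isinstance(va, int) and not isinstance(vb, int):
        updated[b] = va  # both symbolic: unify onto one symbol
    elif va != vb:
        raise ValueError(f"guard {a} == {b} fails: {va} != {vb}")
    return updated

# x.shape[0] is marked static (concrete 4); y.shape[1] gets the symbol s3.
dims = {"x.shape[0]": 4, "y.shape[1]": "s3"}
# Broadcasting in x + y aligns x.shape[0] with y.shape[1], emitting s3 == 4.
dims = apply_equality_guard(dims, "x.shape[0]", "y.shape[1]")
print(dims)  # {'x.shape[0]': 4, 'y.shape[1]': 4}
```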

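One property of export is that, during tracing, statements such as asserts, torch._check(), and if/else conditions also emit guards. See what happens when we augment the existing model with such statements: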

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
        torch._check(x.shape[0] >= 4)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
try:
    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0423 16:38:38.721000 660 torch/fx/experimental/symbolic_shapes.py:3334] [15/0] create_env
I0423 16:38:38.723000 660 torch/fx/experimental/symbolic_shapes.py:4606] [15/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.723000 660 torch/fx/experimental/symbolic_shapes.py:4606] [15/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0423 16:38:38.724000 660 torch/fx/experimental/symbolic_shapes.py:7018] [15/0] runtime_assert True == True [statically known]
I0423 16:38:38.726000 660 torch/fx/experimental/symbolic_shapes.py:4606] [15/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.728000 660 torch/fx/experimental/symbolic_shapes.py:4606] [15/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.729000 660 torch/fx/experimental/symbolic_shapes.py:4606] [15/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.731000 660 torch/fx/experimental/symbolic_shapes.py:4606] [15/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:38.737000 660 torch/fx/experimental/symbolic_shapes.py:6630] [15/0] runtime_assert s0 <= 512 [guard added] assert w.shape[0] <= 512  # workspace/intermediate_source/torch_export_tutorial.py:450 in forward (_dynamo/symbolic_convert.py:669 in inner), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s0 <= 512"
V0423 16:38:38.738000 660 torch/fx/experimental/symbolic_shapes.py:6071] [15/0] _update_var_to_range s0 = VR[2, 512] (update)
I0423 16:38:38.742000 660 torch/fx/experimental/symbolic_shapes.py:6630] [15/0] runtime_assert s2 >= 4 [guard added] torch._check(x.shape[0] >= 4)  # workspace/intermediate_source/torch_export_tutorial.py:451 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s2 >= 4"
V0423 16:38:38.743000 660 torch/fx/experimental/symbolic_shapes.py:6071] [15/0] _update_var_to_range s2 = VR[4, int_oo] (update)
I0423 16:38:38.748000 660 torch/fx/experimental/symbolic_shapes.py:6630] [15/0] eval Eq(s0, s2 + 2) [guard added] if w.shape[0] == x.shape[0] + 2:  # workspace/intermediate_source/torch_export_tutorial.py:452 in forward (_dynamo/variables/tensor.py:1245 in evaluate_expr), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2 + 2)"
V0423 16:38:38.749000 660 torch/fx/experimental/symbolic_shapes.py:6071] [15/0] _update_var_to_range s0 = VR[6, 512] (update)
V0423 16:38:38.752000 660 torch/fx/experimental/symbolic_shapes.py:6071] [15/0] _update_var_to_range s2 = VR[4, 510] (update)
I0423 16:38:38.753000 660 torch/fx/experimental/symbolic_shapes.py:6234] [15/0] set_replacement s0 = s2 + 2 (solve) VR[6, 512]
V0423 16:38:38.755000 660 torch/fx/experimental/symbolic_shapes.py:6787] [15/0] eval size_oblivious(Eq(s2, 1)) == False [statically known]
V0423 16:38:38.755000 660 torch/fx/experimental/symbolic_shapes.py:7018] [15/0] runtime_assert True == True [statically known]
V0423 16:38:38.756000 660 torch/fx/experimental/symbolic_shapes.py:6787] [15/0] eval size_oblivious(Eq(s4, 1)) == False [statically known]
I0423 16:38:38.758000 660 torch/fx/experimental/symbolic_shapes.py:6630] [15/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # workspace/intermediate_source/torch_export_tutorial.py:453 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0423 16:38:38.759000 660 torch/fx/experimental/symbolic_shapes.py:6071] [15/0] _update_var_to_range s4 = VR[4, 510] (update)
I0423 16:38:38.760000 660 torch/fx/experimental/symbolic_shapes.py:6234] [15/0] set_replacement s4 = s2 (solve) VR[4, 510]
V0423 16:38:38.762000 660 torch/fx/experimental/symbolic_shapes.py:6787] [15/0] eval size_oblivious(Ne(s2, 1)) == True [statically known]
V0423 16:38:38.762000 660 torch/fx/experimental/symbolic_shapes.py:6787] [15/0] eval size_oblivious(Ne(s3, 1)) == True [statically known]
I0423 16:38:38.768000 660 torch/fx/experimental/symbolic_shapes.py:6630] [15/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # workspace/intermediate_source/torch_export_tutorial.py:454 in forward (_meta_registrations.py:2236 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0423 16:38:38.769000 660 torch/fx/experimental/symbolic_shapes.py:6071] [15/0] _update_var_to_range s1 = VR[5, 5] (update)
I0423 16:38:38.770000 660 torch/fx/experimental/symbolic_shapes.py:6234] [15/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
V0423 16:38:38.778000 660 torch/fx/experimental/symbolic_shapes.py:6787] [15/0] eval size_oblivious(Eq(s2*s3, 1)) == False [statically known]
V0423 16:38:38.779000 660 torch/fx/experimental/symbolic_shapes.py:6787] [15/0] eval size_oblivious(Eq(s5, 1)) == False [statically known]
I0423 16:38:38.787000 660 torch/fx/experimental/symbolic_shapes.py:6630] [15/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # workspace/intermediate_source/torch_export_tutorial.py:456 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0423 16:38:38.788000 660 torch/fx/experimental/symbolic_shapes.py:6071] [15/0] _update_var_to_range s5 = VR[8, int_oo] (update)
I0423 16:38:38.789000 660 torch/fx/experimental/symbolic_shapes.py:6234] [15/0] set_replacement s5 = s2*s3 (solve) VR[8, int_oo]
V0423 16:38:38.791000 660 torch/fx/experimental/symbolic_shapes.py:6787] [15/0] eval size_oblivious(Ne(s2*s3, 1)) == True [statically known]
V0423 16:38:38.794000 660 torch/fx/experimental/symbolic_shapes.py:7018] [15/0] runtime_assert s2 >= 4 == True [statically known]
I0423 16:38:38.800000 660 torch/fx/experimental/symbolic_shapes.py:4734] [15/0] produce_guards
V0423 16:38:38.800000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['w'].size()[0] s2 + 2 None
V0423 16:38:38.800000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['w'].size()[1] 5 None
V0423 16:38:38.801000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['w'].stride()[0] 5 None
V0423 16:38:38.801000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['w'].stride()[1] 1 None
V0423 16:38:38.801000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['w'].storage_offset() 0 None
V0423 16:38:38.801000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['x'].size()[0] s2 None
V0423 16:38:38.802000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['x'].stride()[0] 1 None
V0423 16:38:38.802000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['x'].storage_offset() 0 None
V0423 16:38:38.802000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['y'].size()[0] s3 None
V0423 16:38:38.802000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['y'].size()[1] s2 None
V0423 16:38:38.803000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['y'].stride()[0] s2 None
V0423 16:38:38.803000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['y'].stride()[1] 1 None
V0423 16:38:38.803000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['y'].storage_offset() 0 None
V0423 16:38:38.804000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['z'].size()[0] s2*s3 None
V0423 16:38:38.804000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['z'].stride()[0] 1 None
V0423 16:38:38.804000 660 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['z'].storage_offset() 0 None

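Each of these statements emits an additional guard, and the exported program reflects the changes: s0 is eliminated in favor of s2 + 2, and s2 now has lower and upper bounds, which are reflected in range_constraints.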
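The range refinement visible in the logs above (s0 ending at VR[6, 512] and s2 at VR[4, 510]) can be reproduced with a small interval sketch. The dictionary `ranges` and the helpers `guard_le`, `guard_ge`, and `guard_eq_offset` are hypothetical illustrations, not PyTorch APIs. The starting range [2, inf] mirrors the [2, int_oo] printed when the symbols are created.

```python
import math

INF = math.inf

# Toy value ranges: symbol -> [lo, hi]. Input-dimension symbols start at
# [2, inf], matching the [2, int_oo] shown in the create_symbol logs.
ranges = {"s0": [2, INF], "s2": [2, INF]}

def guard_le(sym, bound):
    """Guard sym <= bound: tighten the upper bound."""
    ranges[sym][1] = min(ranges[sym][1], bound)

def guard_ge(sym, bound):
    """Guard sym >= bound: tighten the lower bound."""
    ranges[sym][0] = max(ranges[sym][0], bound)

def guard_eq_offset(a, b, offset):
    """Guard a == b + offset: intersect both ranges through the relation."""
    lo = max(ranges[a][0], ranges[b][0] + offset)
    hi = min(ranges[a][1], ranges[b][1] + offset)
    ranges[a] = [lo, hi]
    ranges[b] = [lo - offset, hi - offset]

guard_le("s0", 512)             # assert w.shape[0] <= 512
guard_ge("s2", 4)               # torch._check(x.shape[0] >= 4)
guard_eq_offset("s0", "s2", 2)  # if w.shape[0] == x.shape[0] + 2
print(ranges)  # {'s0': [6, 512], 's2': [4, 510]}
```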

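For the if/else condition, you might ask why the True branch was taken, and why the guard emitted from tracing was not w.shape[0] != x.shape[0] + 2. The answer is that export is guided by the sample inputs provided for tracing, and it specializes on the branch that is taken. If the sample input shapes made the if condition fail, export would trace and emit guards corresponding to the else branch. You might also ask why we traced only the if branch, and whether it is possible to keep control flow in the program and preserve both branches. For that, refer to rewriting your model code as described in the Control Flow Ops section above.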
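A toy tracer makes this branch specialization explicit. The condition is evaluated on the concrete sample sizes, only the taken branch is traced, and the outcome is recorded as a guard. `trace_forward` is a hypothetical helper. The `Ne(...)` string for the else branch is an illustrative guess at the guard's form, not captured output.

```python
# Toy tracer (hypothetical): evaluate the branch condition on concrete
# sample sizes, trace only the taken branch, and record the outcome as a guard.
def trace_forward(w_len, x_len):
    guards = []
    if w_len == x_len + 2:
        guards.append("Eq(s0, s2 + 2)")  # matches the guard in the logs above
        traced_branch = "if"
    else:
        guards.append("Ne(s0, s2 + 2)")  # illustrative form for the else case
        traced_branch = "else"
    return traced_branch, guards

print(trace_forward(6, 4))  # ('if', ['Eq(s0, s2 + 2)'])
print(trace_forward(7, 4))  # ('else', ['Ne(s0, s2 + 2)'])
```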

0/1 Specialization

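Since we are talking about guards and specialization, this is a good time to discuss the 0/1 specialization issue mentioned earlier. In short, export specializes on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that do not generalize to other shapes. For example, size-1 tensors can broadcast while other sizes fail; and size 0 … . This simply means that you should provide 0/1 sample inputs when you want the program to hardcode those dimensions, and non-0/1 sample inputs when dynamic behavior is desired. See what happens at runtime when we export this linear layer: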

ep = export(
    torch.nn.Linear(4, 3),
    (torch.randn(1, 4),),
    dynamic_shapes={
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
try:
    ep.module()(torch.randn(2, 4))
except Exception:
    tb.print_exc()
I0423 16:38:38.866000 660 torch/fx/experimental/symbolic_shapes.py:3334] [3/1] create_env
I0423 16:38:38.880000 660 torch/fx/experimental/symbolic_shapes.py:4734] [3/1] produce_guards
V0423 16:38:38.880000 660 torch/fx/experimental/symbolic_shapes.py:4954] [3/1] track_symint L['args'][0].size()[0] 1 None
V0423 16:38:38.880000 660 torch/fx/experimental/symbolic_shapes.py:4954] [3/1] track_symint L['args'][0].size()[1] 4 None
V0423 16:38:38.880000 660 torch/fx/experimental/symbolic_shapes.py:4954] [3/1] track_symint L['args'][0].stride()[0] 4 None
V0423 16:38:38.881000 660 torch/fx/experimental/symbolic_shapes.py:4954] [3/1] track_symint L['args'][0].stride()[1] 1 None
V0423 16:38:38.881000 660 torch/fx/experimental/symbolic_shapes.py:4954] [3/1] track_symint L['args'][0].storage_offset() 0 None
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 500, in <module>
    ep.module()(torch.randn(2, 4))
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 830, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 406, in __call__
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/graph_module.py", line 393, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
    return inner()
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_unlift.py", line 55, in _check_input_constraints_pre_hook
    _check_input_constraints_for_graph(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_export/utils.py", line 398, in _check_input_constraints_for_graph
    raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 1, but got 2
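Broadcasting is one reason export will not generalize from a size-1 sample. The sketch below implements the standard per-dimension broadcasting rule in plain Python. `broadcast_dim` is a hypothetical helper, not a PyTorch function. With a size-1 sample, tracing follows the `a == 1` path, which is not valid for a general size, so treating that dimension as dynamic would not be sound.

```python
def broadcast_dim(a, b):
    """Broadcast two sizes along one dimension (NumPy/PyTorch-style rule)."""
    if a == b:
        return a
    if a == 1:
        return b  # size-1 dimension stretches to match the other size
    if b == 1:
        return a
    raise ValueError(f"sizes {a} and {b} are not broadcastable")

# With a size-1 sample, the tracer sees the "a == 1" path ...
print(broadcast_dim(1, 3))  # 3
# ... but the same logic does not hold for a general size:
try:
    broadcast_dim(2, 3)
except ValueError as e:
    print(e)  # sizes 2 and 3 are not broadcastable
```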

Named Dims

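So far we have only discussed three ways of specifying dynamic shapes: Dim.AUTO, Dim.DYNAMIC, and Dim.STATIC. Their appeal is a low-friction user experience. All guards emitted during model tracing are respected, and dynamic behavior such as min/max ranges, relations, and static/dynamic dimensions is worked out automatically by export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears when the user has stronger expectations or beliefs about a model's dynamic behavior. Perhaps dynamism is strongly desired and specialization on particular dimensions must be avoided at all costs. Or perhaps we simply want to catch changes in dynamic behavior caused by edits to the original model code, or to underlying decompositions or meta-kernels. Such changes go undetected, and the export() call will most likely succeed, unless tests are in place that check the resulting ExportedProgram representation.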

For such cases, we recommend the "traditional" way of specifying dynamic shapes, which long-time users of export may already know: named Dims.

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

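This style lets the user specify which symbols are allocated for input dimensions and what min/max bounds those symbols have. It also restricts the dynamic behavior of the resulting ExportedProgram: if model tracing emits guards that conflict with the given relations or static/dynamic specifications, a ConstraintViolation error is raised. For example, the specification above asserts the following: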

  • x.shape[0] has range [4, 256] and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].

  • x.shape[1] is static.

  • y.shape[1] has range [2, 512] and is unrelated to any other dimension.
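As a rough illustration of what this specification promises, the hypothetical `check_shapes` function below tests concrete shapes against the same rules in plain Python. It is not part of torch.export. It also assumes a sample value of 4 for the static x.shape[1], which the original does not specify.

```python
# Toy checker (hypothetical) mirroring the named-Dim spec:
#   x: (dx, static), y: (2 * dx, dh), dx in [4, 256], dh in [2, 512].
def check_shapes(x_shape, y_shape, x1_static=4):
    """Return "ok" or a message naming the first violated rule."""
    dx = x_shape[0]
    if not 4 <= dx <= 256:
        return f"x.shape[0]={dx} outside [4, 256]"
    if y_shape[0] != 2 * dx:
        return f"y.shape[0]={y_shape[0]} != 2 * x.shape[0]={2 * dx}"
    if x_shape[1] != x1_static:  # assumed sample value for the static dim
        return f"x.shape[1]={x_shape[1]} is static and must be {x1_static}"
    if not 2 <= y_shape[1] <= 512:
        return f"y.shape[1]={y_shape[1]} outside [2, 512]"
    return "ok"

print(check_shapes((8, 4), (16, 100)))  # ok
print(check_shapes((8, 4), (15, 100)))  # y.shape[0]=15 != 2 * x.shape[0]=16
```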

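In this design, relations between dimensions can be specified with univariate linear expressions: any dimension can be specified as A * dim + B. This lets users express more complex constraints on dynamic dimensions, such as integer divisibility: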

dx = Dim("dx", min=4, max=512)
dynamic_shapes = {
    "x": (4 * dx, None)  # x.shape[0] has range [16, 2048], and is divisible by 4.
}
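The effect of a linear expression such as 4 * dx on the allowed sizes can be sketched with a small dataclass. `LinearDim` is a hypothetical stand-in for illustration, not `torch.export.Dim`, and it assumes a positive coefficient A. Under that assumption, the admissible sizes lie in [A * min + B, A * max + B] and are restricted to values of the form A * d + B.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class LinearDim:
    """Toy stand-in (hypothetical) for A * dim + B, with dim in [lo, hi].

    Assumes a > 0.
    """
    lo: int
    hi: int
    a: int = 1
    b: int = 0

    def value_range(self):
        return (self.a * self.lo + self.b, self.a * self.hi + self.b)

    def admits(self, size):
        """True if size == a * d + b for some integer d in [lo, hi]."""
        d, rem = divmod(size - self.b, self.a)
        return rem == 0 and self.lo <= d <= self.hi

x0 = LinearDim(lo=4, hi=512, a=4)  # "x": (4 * dx, None), dx in [4, 512]
print(x0.value_range())              # (16, 2048)
print(x0.admits(20), x0.admits(18))  # True False
```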

Constraint violations, suggested fixes

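A common issue with this specification style (before Dim.AUTO was introduced) is that the specification often does not match what model tracing produces. This leads to ConstraintViolation errors along with fixes suggested by export. For example, consider this model and specification: the model inherently requires dimension 0 of x and y to be equal, and requires dimension 1 to be static.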

class Foo(torch.nn.Module):
    def forward(self, x, y):
        w = x + y
        return w + torch.ones(4)

dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
try:
    ep = export(
        Foo(),
        (torch.randn(6, 4), torch.randn(6, 4)),
        dynamic_shapes={
            "x": (dx, d1),
            "y": (dy, d1),
        },
    )
except Exception:
    tb.print_exc()
I0423 16:38:39.026000 660 torch/fx/experimental/symbolic_shapes.py:3334] [16/0] create_env
I0423 16:38:39.028000 660 torch/fx/experimental/symbolic_shapes.py:4606] [16/0] create_symbol s0 = 6 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:39.029000 660 torch/fx/experimental/symbolic_shapes.py:4606] [16/0] create_symbol s1 = 4 for L['x'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0423 16:38:39.030000 660 torch/fx/experimental/symbolic_shapes.py:7018] [16/0] runtime_assert True == True [statically known]
I0423 16:38:39.032000 660 torch/fx/experimental/symbolic_shapes.py:4606] [16/0] create_symbol s2 = 6 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0423 16:38:39.033000 660 torch/fx/experimental/symbolic_shapes.py:4606] [16/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0423 16:38:39.037000 660 torch/fx/experimental/symbolic_shapes.py:6787] [16/0] eval size_oblivious(Eq(s1, 1)) == False [statically known]
V0423 16:38:39.038000 660 torch/fx/experimental/symbolic_shapes.py:7018] [16/0] runtime_assert True == True [statically known]
V0423 16:38:39.038000 660 torch/fx/experimental/symbolic_shapes.py:6787] [16/0] eval size_oblivious(Eq(s0, 1)) == False [statically known]
V0423 16:38:39.039000 660 torch/fx/experimental/symbolic_shapes.py:6787] [16/0] eval size_oblivious(Eq(s3, 1)) == False [statically known]
I0423 16:38:39.041000 660 torch/fx/experimental/symbolic_shapes.py:6630] [16/0] runtime_assert Eq(s1, s3) [guard added] w = x + y  # workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, s3)"
I0423 16:38:39.042000 660 torch/fx/experimental/symbolic_shapes.py:6234] [16/0] set_replacement s3 = s1 (solve) VR[2, int_oo]
V0423 16:38:39.043000 660 torch/fx/experimental/symbolic_shapes.py:6787] [16/0] eval size_oblivious(Eq(s2, 1)) == False [statically known]
I0423 16:38:39.045000 660 torch/fx/experimental/symbolic_shapes.py:6630] [16/0] runtime_assert Eq(s0, s2) [guard added] w = x + y  # workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2)"
I0423 16:38:39.046000 660 torch/fx/experimental/symbolic_shapes.py:6234] [16/0] set_replacement s2 = s0 (solve) VR[2, int_oo]
V0423 16:38:39.047000 660 torch/fx/experimental/symbolic_shapes.py:6787] [16/0] eval size_oblivious(Ne(s1, 1)) == True [statically known]
V0423 16:38:39.048000 660 torch/fx/experimental/symbolic_shapes.py:6787] [16/0] eval size_oblivious(Ne(s0, 1)) == True [statically known]
I0423 16:38:39.055000 660 torch/fx/experimental/symbolic_shapes.py:6630] [16/0] runtime_assert Eq(s1, 4) [guard added] return w + torch.ones(4)  # workspace/intermediate_source/torch_export_tutorial.py:553 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 4)"
V0423 16:38:39.055000 660 torch/fx/experimental/symbolic_shapes.py:6071] [16/0] _update_var_to_range s1 = VR[4, 4] (update)
I0423 16:38:39.056000 660 torch/fx/experimental/symbolic_shapes.py:6234] [16/0] set_replacement s1 = 4 (range_refined_to_singleton) VR[4, 4]
V0423 16:38:39.060000 660 torch/fx/experimental/symbolic_shapes.py:6071] [16/0] _update_var_to_range s3 = VR[4, 4] (update)
I0423 16:38:39.060000 660 torch/fx/experimental/symbolic_shapes.py:6234] [16/0] set_replacement s3 = 4 (find) VR[4, 4]
I0423 16:38:39.062000 660 torch/fx/experimental/symbolic_shapes.py:4734] [16/0] produce_guards
V0423 16:38:39.063000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['x'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0423 16:38:39.063000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['x'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0423 16:38:39.064000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['x'].stride()[0] 4 None
V0423 16:38:39.064000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['x'].stride()[1] 1 None
V0423 16:38:39.064000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['x'].storage_offset() 0 None
V0423 16:38:39.065000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['y'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0423 16:38:39.065000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['y'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0423 16:38:39.065000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['y'].stride()[0] 4 None
V0423 16:38:39.066000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['y'].stride()[1] 1 None
V0423 16:38:39.066000 660 torch/fx/experimental/symbolic_shapes.py:4954] [16/0] track_symint L['y'].storage_offset() 0 None
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0] Error while creating guard:
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0] Name: ''
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     Source: shape_env
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     Create Function: SHAPE_ENV
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     Guard Types: None
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     Code List: None
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     Object Weakref: None
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     Guarded Class Weakref: None
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0] Traceback (most recent call last):
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_guards.py", line 357, in create
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     return self.create_fn(builder, self)
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     python_code_parts, verbose_code_parts = _get_code_parts(
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     return output_graph.shape_env.produce_guards_verbose(
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]     raise ConstraintViolationError(
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]   - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]   - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
E0423 16:38:39.067000 660 torch/_guards.py:359] [16/0]   - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
E0423 16:38:39.069000 660 torch/_guards.py:361] [16/0] Created at:
E0423 16:38:39.069000 660 torch/_guards.py:361] [16/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 694, in transform
E0423 16:38:39.069000 660 torch/_guards.py:361] [16/0]     tracer = InstructionTranslator(
E0423 16:38:39.069000 660 torch/_guards.py:361] [16/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3329, in __init__
E0423 16:38:39.069000 660 torch/_guards.py:361] [16/0]     output=OutputGraph(
E0423 16:38:39.069000 660 torch/_guards.py:361] [16/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 358, in __init__
E0423 16:38:39.069000 660 torch/_guards.py:361] [16/0]     self.init_ambient_guards()
E0423 16:38:39.069000 660 torch/_guards.py:361] [16/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 512, in init_ambient_guards
E0423 16:38:39.069000 660 torch/_guards.py:361] [16/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1722, in inner
    raise constraint_violation_error
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
    return self._torchdynamo_orig_callable(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
    return _compile(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 906, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 2481, in __init__
    guard.create(builder)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_guards.py", line 357, in create
    return self.create_fn(builder, self)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV
    python_code_parts, verbose_code_parts = _get_code_parts(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts
    return output_graph.shape_env.produce_guards_verbose(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.

Suggested fixes:
  d1 = 4
  dy = dx

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 557, in <module>
    ep = export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 360, in export
    return _export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2112, in _export
    ep = _export_for_training(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1975, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 756, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.

Suggested fixes:
  d1 = 4
  dy = dx

The suggested fixes are meant to be copied and pasted into your dynamic shapes specification, after which the export should succeed.

Finally, here are a few useful points about the specification options:

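  • None is a good option for static behavior:
      - dynamic_shapes=None (the default) exports the entire model as static.
      - Specifying None at the input level exports all tensor dimensions of that input as static. It is also required for non-tensor inputs.
      - Specifying None at the dimension level specializes that dimension. This is deprecated in favor of Dim.STATIC.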

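  • Specifying per-dimension integer values also produces static behavior. It additionally checks that the provided sample input matches the specification.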

The input and dynamic shapes specification below combines these options:

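import torch
from torch.export import Dim

inputs = (
    torch.randn(4, 4),  # tensor_0
    torch.randn(3, 3),  # tensor_1
    16,                 # int_val
    False,              # bool_val
)
dynamic_shapes = {
    "tensor_0": (Dim.AUTO, None),  # dim 0 automatic; dim-level None specializes dim 1 (deprecated, prefer Dim.STATIC)
    "tensor_1": None,              # input-level None: all dimensions static
    "int_val": None,               # None is required for non-tensor inputs
    "bool_val": None,
}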

Data-dependent errors

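When exporting a model, you may run into errors like “Could not guard on data-dependent expression” or “Could not extract specialized integer from data-dependent expression”. These errors arise because torch.export() compiles programs using FakeTensors, which represent their real tensor counterparts symbolically. FakeTensors have the same symbolic properties as real tensors (sizes, strides, dtypes), but they contain no data values. This avoids unnecessary memory use and expensive computation. It also means that export cannot, out of the box, compile the parts of user code whose behavior depends on data values. In short, if the compiler needs a concrete, data-dependent value to proceed, it raises an error saying that the value is not available.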
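The following sketch makes this concrete. It relies on FakeTensorMode from torch._subclasses.fake_tensor, an internal PyTorch API that may change between releases. Tensors created under the mode carry shape, stride and dtype metadata but no values. Shape-only computations therefore work, while asking for a concrete scalar with .item() fails. In this bare mode no symbolic-shape environment is attached, so we assume the failure surfaces as an exception. During export, the compiler allocates a symbol instead, as described next.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode  # internal API

with FakeTensorMode():
    # Factory functions inside the mode produce FakeTensors instead of real tensors.
    fake = torch.randn(8, 100)
    # Metadata (size, stride, dtype) is tracked just like on a real tensor...
    print(type(fake).__name__, fake.shape, fake.stride(), fake.dtype)
    # ...and shape-only computation works without any data.
    summed = (fake + 1).sum(dim=0)
    print(summed.shape)  # torch.Size([100])
    # Reading a concrete value is impossible: there is no data to read.
    try:
        fake.sum().item()
        item_failed = False
    except Exception as e:
        item_failed = True
        print("item() failed:", type(e).__name__)
```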

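Data-dependent values appear in many places. Common sources are calls such as item(), tolist() or torch.unbind(), which extract scalar values from tensors. How are these values represented in the exported program? The Constraints/Dynamic Shapes section described allocating symbols to represent dynamic input dimensions. The same happens here: a symbol is allocated for every data-dependent value that appears in the program.

The important distinction is that these are “unbacked” symbols, as opposed to the “backed” symbols allocated for input dimensions. The backed/unbacked terminology refers to whether the symbol has a “hint”. A hint is a concrete value backing the symbol, which tells the compiler how to proceed.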

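For input shape symbols (backed symbols), the hints are simply the shapes of the provided sample inputs. This is why control-flow branching is decided by properties of the sample inputs. For data-dependent values, the symbols come from the FakeTensor “data” during tracing. The compiler therefore has no idea which actual value (hint) these symbols will take.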

Let's see how data-dependent values show up in an exported program:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.tolist()
        return b + [a]

inps = (
    torch.tensor(1),
    torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)
I0423 16:38:39.085000 660 torch/fx/experimental/symbolic_shapes.py:3334] [17/0] create_env
I0423 16:38:39.090000 660 torch/fx/experimental/symbolic_shapes.py:4276] [17/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # workspace/intermediate_source/torch_export_tutorial.py:618 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.090000 660 torch/fx/experimental/symbolic_shapes.py:1130] [17/0] compute_unbacked_bindings [u0]
I0423 16:38:39.092000 660 torch/fx/experimental/symbolic_shapes.py:4276] [17/0] create_unbacked_symint u1 [-int_oo, int_oo] b = y.tolist()  # workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.093000 660 torch/fx/experimental/symbolic_shapes.py:1130] [17/0] compute_unbacked_bindings [u1]
I0423 16:38:39.095000 660 torch/fx/experimental/symbolic_shapes.py:4276] [17/0] create_unbacked_symint u2 [-int_oo, int_oo] b = y.tolist()  # workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.095000 660 torch/fx/experimental/symbolic_shapes.py:1130] [17/0] compute_unbacked_bindings [u2]
I0423 16:38:39.099000 660 torch/fx/experimental/symbolic_shapes.py:4734] [17/0] produce_guards
V0423 16:38:39.099000 660 torch/fx/experimental/symbolic_shapes.py:4954] [17/0] track_symint L['x'].storage_offset() 0 None
V0423 16:38:39.099000 660 torch/fx/experimental/symbolic_shapes.py:4954] [17/0] track_symint L['y'].size()[0] 2 None
V0423 16:38:39.100000 660 torch/fx/experimental/symbolic_shapes.py:4954] [17/0] track_symint L['y'].stride()[0] 1 None
V0423 16:38:39.100000 660 torch/fx/experimental/symbolic_shapes.py:4954] [17/0] track_symint L['y'].storage_offset() 0 None
I0423 16:38:39.105000 660 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u3 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.107000 660 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u4 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.113000 660 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u5 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.113000 660 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u5]
I0423 16:38:39.114000 660 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u5 = u0 (rename_unbacked_to) VR[-int_oo, int_oo]
I0423 16:38:39.116000 660 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u6 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.116000 660 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u6]
I0423 16:38:39.116000 660 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u6 = u1 (rename_unbacked_to) VR[-int_oo, int_oo]
I0423 16:38:39.118000 660 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u7 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.118000 660 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u7]
I0423 16:38:39.119000 660 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u7 = u2 (rename_unbacked_to) VR[-int_oo, int_oo]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "i64[2]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward, code: b = y.tolist()
            select: "i64[]" = torch.ops.aten.select.int(y, 0, 0)
            item_1: "Sym(u1)" = torch.ops.aten.item.default(select);  select = None
            select_1: "i64[]" = torch.ops.aten.select.int(y, 0, 1);  y = None
            item_2: "Sym(u2)" = torch.ops.aten.item.default(select_1);  select_1 = None
            return (item_1, item_2, item)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item'), target=None)])
Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[-int_oo, int_oo], u3: VR[-int_oo, int_oo], u4: VR[-int_oo, int_oo], u5: VR[-int_oo, int_oo], u6: VR[-int_oo, int_oo], u7: VR[-int_oo, int_oo]}

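Three unbacked symbols are allocated and returned: 1 for the item() call, and 1 for each element of y in the tolist() call. Note that they are prefixed with “u” rather than the “s” usually used for input shape (backed) symbols.

The range constraints field gives these symbols the range [-int_oo, int_oo]. Input shape symbols get the default range [0, int_oo] instead. The difference exists because we know nothing about these values. They don't represent sizes, so they are not necessarily positive.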

Guards, torch._check()

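The case above is easy to export, because no compiler decision depends on the concrete values of these symbols. All that matters is that the return values are unbacked symbols. The data-dependent errors this section focuses on occur in cases like the following, where a data-dependent guard is encountered: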

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

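Here the compiler actually needs the “hint”, i.e. the concrete value of a, to decide whether to trace return y + 2 or return y * 5 as the output. Because tracing uses FakeTensors, the compiler does not know what a // 2 >= 5 actually evaluates to. Export therefore fails with “Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)”.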

So how can this toy model be exported? Unlike torch.compile(), export requires full graph compilation, so we cannot simply graph break here. Here are some basic options:

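  1. Manual specialization: intervene by choosing the branch to trace. You can either remove the control-flow code so that only the specialized branch remains, or use torch.compiler.is_compiling() to guard what gets traced at compile time.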

  2. torch.cond(): rewrite the control-flow code to use torch.cond(), so that tracing does not specialize on one branch.

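These options are valid, but each has drawbacks. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize it. torch.cond() is not a comprehensive system for handling data-dependent errors. As we will see, some data-dependent errors do not involve control flow at all.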

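The generally recommended approach is to start with torch._check() calls. These look like plain assert statements, but they are in fact a way of telling the compiler about properties of symbols. At runtime, a torch._check() call does act as an assertion. When traced at compile time, however, the checked expression is sent to the symbolic shapes subsystem for reasoning. Any symbol properties that follow from the expression being true are stored, provided the subsystem is smart enough to infer them.

So even though unbacked symbols have no hints, we can use torch._check() calls to state properties that generally hold for these symbols. This can let us bypass data-dependent guards without rewriting the offending model code.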

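For example, take the earlier model that branches on a // 2 >= 5. Inserting torch._check(a >= 10) tells the compiler that it can always return y + 2, while inserting torch._check(a == 4) tells it to return y * 5. Let's see what happens when we re-export this model.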

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 10)
        torch._check(a <= 60)
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

inps = (
    torch.tensor(32),
    torch.randn(4),
)
ep = export(Foo(), inps)
print(ep)
I0423 16:38:39.128000 660 torch/fx/experimental/symbolic_shapes.py:3334] [18/0] create_env
I0423 16:38:39.132000 660 torch/fx/experimental/symbolic_shapes.py:4276] [18/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # workspace/intermediate_source/torch_export_tutorial.py:672 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.133000 660 torch/fx/experimental/symbolic_shapes.py:1130] [18/0] compute_unbacked_bindings [u0]
I0423 16:38:39.135000 660 torch/fx/experimental/symbolic_shapes.py:6630] [18/0] runtime_assert u0 >= 10 [guard added] torch._check(a >= 10)  # workspace/intermediate_source/torch_export_tutorial.py:673 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 10"
V0423 16:38:39.136000 660 torch/fx/experimental/symbolic_shapes.py:6071] [18/0] _update_var_to_range u0 = VR[10, int_oo] (update)
I0423 16:38:39.140000 660 torch/fx/experimental/symbolic_shapes.py:6630] [18/0] runtime_assert u0 <= 60 [guard added] torch._check(a <= 60)  # workspace/intermediate_source/torch_export_tutorial.py:674 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 <= 60"
V0423 16:38:39.141000 660 torch/fx/experimental/symbolic_shapes.py:6071] [18/0] _update_var_to_range u0 = VR[10, 60] (update)
V0423 16:38:39.146000 660 torch/fx/experimental/symbolic_shapes.py:6787] [18/0] eval False == True [statically known]
V0423 16:38:39.149000 660 torch/fx/experimental/symbolic_shapes.py:7018] [18/0] runtime_assert u0 >= 10 == True [statically known]
V0423 16:38:39.150000 660 torch/fx/experimental/symbolic_shapes.py:7018] [18/0] runtime_assert u0 <= 60 == True [statically known]
I0423 16:38:39.153000 660 torch/fx/experimental/symbolic_shapes.py:4734] [18/0] produce_guards
V0423 16:38:39.153000 660 torch/fx/experimental/symbolic_shapes.py:4954] [18/0] track_symint L['x'].storage_offset() 0 None
V0423 16:38:39.153000 660 torch/fx/experimental/symbolic_shapes.py:4954] [18/0] track_symint L['y'].size()[0] 4 None
V0423 16:38:39.154000 660 torch/fx/experimental/symbolic_shapes.py:4954] [18/0] track_symint L['y'].stride()[0] 1 None
V0423 16:38:39.154000 660 torch/fx/experimental/symbolic_shapes.py:4954] [18/0] track_symint L['y'].storage_offset() 0 None
I0423 16:38:39.167000 660 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.167000 660 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u1]
V0423 16:38:39.168000 660 torch/fx/experimental/symbolic_shapes.py:6071] _update_var_to_range u1 = VR[10, 60] (update)
I0423 16:38:39.168000 660 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u1 = u0 (rename_unbacked_to) VR[10, 60]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[4]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 10)" = item >= 10
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 10 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 60)" = item <= 60;  item = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 60 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:676 in forward, code: return y + 2
            add: "f32[4]" = torch.ops.aten.add.Tensor(y, 2);  y = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {u0: VR[10, 60], u1: VR[10, 60]}

The export succeeds. Note from the range constraints field that u0 now takes the range [10, 60].

So what information do torch._check() calls actually convey? The answer changes as the symbolic shapes subsystem gets smarter, but the following generally holds:

  1. Equality with non-data-dependent expressions: torch._check() calls can convey equalities such as u0 == s0 + 4 or u0 == 5.

  2. Range refinement: calls that provide lower or upper bounds for symbols, like the ones shown above.

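  3. Some basic reasoning about more complicated expressions: inserting torch._check(a < 4) typically tells the compiler that a >= 4 is false. A check on a complex expression, such as torch._check(a ** 2 - 3 * a <= 10), typically gets you past a guard on that same expression.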

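As mentioned earlier, torch._check() calls are useful beyond data-dependent control flow. For example, in the following model inserting torch._check() works, while manual specialization and torch.cond() do not: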

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps)
except Exception:
    tb.print_exc()
I0423 16:38:39.181000 660 torch/fx/experimental/symbolic_shapes.py:3334] [19/0] create_env
I0423 16:38:39.185000 660 torch/fx/experimental/symbolic_shapes.py:4276] [19/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # workspace/intermediate_source/torch_export_tutorial.py:701 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.186000 660 torch/fx/experimental/symbolic_shapes.py:1130] [19/0] compute_unbacked_bindings [u0]
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0] Data dependent variable 'u0' allocated at:
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     sys.exit(main())
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return make_main(argv)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return make_mode.run_make_mode(argv[1:])
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return make.run_generic_build(args[0])
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return build_main(args + opts)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     self._init_builder()
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     self.events.emit('builder-inited')
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     results.append(listener.handler(self.app, *args))
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     ) = generate_dir_rst(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     intro, title, cost = generate_file_rst(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     output_blocks, time_elapsed = execute_script(script_blocks,
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     output_blocks.append(execute_code_block(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     is_last_expr, mem_max = _exec_and_get_memory(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     mem_max, _ = gallery_conf['call_memory'](
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return 0., func()
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     exec(self.code, self.fake_main.__dict__)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 709, in <module>
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     export(Foo(), inps)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 360, in export
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return _export(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     ep = fn(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return fn(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2112, in _export
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     ep = _export_for_training(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     ep = fn(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return fn(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1975, in _export_for_training
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     export_artifact = export_func(  # type: ignore[operator]
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     gm_torch_level = _export_to_torch_ir(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     gm_torch_level, _ = torch._dynamo.export(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     result_traced = opt_f(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return self._call_impl(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return forward_call(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return fn(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return self._call_impl(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return forward_call(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return self._torchdynamo_orig_callable(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 598, in __call__
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return _compile(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     guarded_code = compile_inner(code, one_graph, hooks, transform)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return function(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return _compile_inner(code, one_graph, hooks, transform)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     out_code = transform_code_object(code, transform)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     transformations(instructions, code_options)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 257, in _fn
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return fn(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 715, in transform
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     tracer.run()
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3500, in run
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     super().run()
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     while self.step():
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     self.dispatch_table[inst.opcode](self, inst)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return inner_fn(self, inst)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2168, in CALL_FUNCTION
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     self.call_function(fn, args, {})
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1170, in call_function
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 903, in call_function
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return self.obj.call_method(tx, self.name, args, kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 632, in call_method
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return wrap_fx_proxy(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2302, in wrap_fx_proxy
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2368, in wrap_fx_proxy_cls
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return _wrap_fx_proxy(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2464, in _wrap_fx_proxy
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 3127, in get_fake_value
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     ret_val = wrap_fake_exception(
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2641, in wrap_fake_exception
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return fn()
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 3128, in <lambda>
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 3295, in run_node
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return getattr(args[0], node.target)(*args[1:], **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/utils/_stats.py", line 27, in wrapper
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return fn(*args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1282, in __torch_dispatch__
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return self.dispatch(func, types, args, kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1823, in dispatch
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return self._cached_dispatch_impl(func, types, args, kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1393, in _cached_dispatch_impl
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     output = self._dispatch_impl(func, types, args, kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2397, in _dispatch_impl
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     op_impl_out = op_impl(self, func, *args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_impls.py", line 422, in local_scalar_dense
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     r = fake_mode.shape_env.create_unbacked_symint()
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]     return retlog(fn(*args, **kwargs))
V0423 16:38:39.189000 660 torch/fx/experimental/symbolic_shapes.py:5984] [19/0]
W0423 16:38:39.197000 660 torch/fx/experimental/symbolic_shapes.py:6679] [19/0] failed during evaluate_expr(-u0 > 60, hint=None, size_oblivious=True, forcing_spec=False
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] failed while running evaluate_expr(*(-u0 > 60, None, False, True), **{})
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] Traceback (most recent call last):
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]     return retlog(fn(*args, **kwargs))
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]     return self._evaluate_expr(
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]     raise self._make_data_dependent_error(
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] Caused by: return y[a]  # workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:5278 in meta_select)
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] For more information, run with TORCH_LOGS="dynamic"
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] User Stack (most recent call last):
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]   (snipped, see stack below for prefix)
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]     return y[a]
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0]
E0423 16:38:39.197000 660 torch/fx/experimental/recording.py:299] [19/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] failed while attempting to run meta for aten.select.int
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] Traceback (most recent call last):
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     r = func(*args, **kwargs)
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_ops.py", line 756, in __call__
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     return self._op(*args, **kwargs)
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_meta_registrations.py", line 5278, in meta_select
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 408, in guard_size_oblivious
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     return expr.node.guard_size_oblivious("", 0)
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 588, in guard_size_oblivious
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     r = self.evaluate(size_oblivious=True)
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     return self.shape_env.evaluate_sym_node(self, size_oblivious)
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6655, in evaluate_sym_node
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     return self.evaluate_expr(
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     return retlog(fn(*args, **kwargs))
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     return self._evaluate_expr(
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     raise self._make_data_dependent_error(
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] Caused by: return y[a]  # workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:5278 in meta_select)
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] For more information, run with TORCH_LOGS="dynamic"
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] User Stack (most recent call last):
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   (snipped, see stack below for prefix)
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]     return y[a]
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0]
E0423 16:38:39.199000 660 torch/_subclasses/fake_tensor.py:2431] [19/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 709, in <module>
    export(Foo(), inps)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 360, in export
    return _export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2112, in _export
    ep = _export_for_training(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1975, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
    raise e.with_traceback(None) from None
torch._dynamo.exc.UserError: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)

Caused by: return y[a]  # workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:5278 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.ac.cn/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

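Here is a scenario where a torch._check() call has to be inserted simply to keep an operation from failing. The export call fails with "Could not guard on data-dependent expression -u0 > 60", which means the compiler does not know whether this is a valid indexing operation, that is, whether the value of x is out of bounds for y. Manual specialization is too restrictive here, and torch.cond() has no role to play. Instead, telling the compiler the range of u0 is sufficient: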

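class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()  # data-dependent value, traced as the unbacked symbol u0
        torch._check(a >= 0)  # tell the compiler that u0 is non-negative
        torch._check(a < y.shape[0])  # ...and below len(y), so y[a] is in bounds
        return y[a]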

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps)
print(ep)
I0423 16:38:39.213000 660 torch/fx/experimental/symbolic_shapes.py:3334] [20/0] create_env
I0423 16:38:39.218000 660 torch/fx/experimental/symbolic_shapes.py:4276] [20/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # workspace/intermediate_source/torch_export_tutorial.py:721 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.219000 660 torch/fx/experimental/symbolic_shapes.py:1130] [20/0] compute_unbacked_bindings [u0]
I0423 16:38:39.221000 660 torch/fx/experimental/symbolic_shapes.py:6630] [20/0] runtime_assert u0 >= 0 [guard added] torch._check(a >= 0)  # workspace/intermediate_source/torch_export_tutorial.py:722 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0423 16:38:39.222000 660 torch/fx/experimental/symbolic_shapes.py:6071] [20/0] _update_var_to_range u0 = VR[0, int_oo] (update)
I0423 16:38:39.226000 660 torch/fx/experimental/symbolic_shapes.py:6630] [20/0] runtime_assert u0 < 60 [guard added] torch._check(a < y.shape[0])  # workspace/intermediate_source/torch_export_tutorial.py:723 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 < 60"
V0423 16:38:39.227000 660 torch/fx/experimental/symbolic_shapes.py:6071] [20/0] _update_var_to_range u0 = VR[0, 59] (update)
V0423 16:38:39.230000 660 torch/fx/experimental/symbolic_shapes.py:6787] [20/0] eval size_oblivious(-u0 > 60) == False [statically known]
V0423 16:38:39.231000 660 torch/fx/experimental/symbolic_shapes.py:6787] [20/0] eval size_oblivious(u0 >= 60) == False [statically known]
V0423 16:38:39.231000 660 torch/fx/experimental/symbolic_shapes.py:6787] [20/0] eval False == True [statically known]
V0423 16:38:39.234000 660 torch/fx/experimental/symbolic_shapes.py:7018] [20/0] runtime_assert u0 >= 0 == True [statically known]
V0423 16:38:39.236000 660 torch/fx/experimental/symbolic_shapes.py:7018] [20/0] runtime_assert u0 <= 59 == True [statically known]
V0423 16:38:39.237000 660 torch/fx/experimental/symbolic_shapes.py:7018] [20/0] runtime_assert u0 < 60 == True [statically known]
I0423 16:38:39.239000 660 torch/fx/experimental/symbolic_shapes.py:4734] [20/0] produce_guards
V0423 16:38:39.240000 660 torch/fx/experimental/symbolic_shapes.py:4954] [20/0] track_symint L['x'].storage_offset() 0 None
V0423 16:38:39.240000 660 torch/fx/experimental/symbolic_shapes.py:4954] [20/0] track_symint L['y'].size()[0] 60 None
V0423 16:38:39.240000 660 torch/fx/experimental/symbolic_shapes.py:4954] [20/0] track_symint L['y'].stride()[0] 1 None
V0423 16:38:39.241000 660 torch/fx/experimental/symbolic_shapes.py:4954] [20/0] track_symint L['y'].storage_offset() 0 None
I0423 16:38:39.254000 660 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.254000 660 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u1]
V0423 16:38:39.255000 660 torch/fx/experimental/symbolic_shapes.py:6071] _update_var_to_range u1 = VR[0, 59] (update)
I0423 16:38:39.255000 660 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u1 = u0 (rename_unbacked_to) VR[0, 59]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 59)" = item <= 59
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 59 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

             #
            lt_1: "Sym(u0 < 60)" = item < 60
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u0 < 60 on node 'lt_1'");  lt_1 = _assert_scalar_default_2 = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:724 in forward, code: return y[a]
            select: "f32[]" = torch.ops.aten.select.int(y, 0, item);  y = item = None
            return (select,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='select'), target=None)])
Range constraints: {u0: VR[0, 59], u1: VR[0, 59]}
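To make the range argument concrete, here is a toy model in plain Python. It is not PyTorch's implementation; the names ValueRange, statically_greater, and out_of_bounds are hypothetical. Each check narrows a closed integer interval, and the bounds test that meta_select performs, (-u0 > size) or (u0 >= size), is decided three ways: True, False, or undecidable. With no information, u0 spans [-inf, inf] and the test is undecidable, which corresponds to the guard error above. After the two checks narrow u0 to [0, 59], both halves are statically False, matching the `[statically known]` lines in the log.

```python
from dataclasses import dataclass

INF = float("inf")


@dataclass(frozen=True)
class ValueRange:
    """Toy closed integer interval [lower, upper], loosely mirroring VR[...] in the log."""
    lower: float
    upper: float

    def refine_ge(self, bound):
        # Toy analogue of torch._check(a >= bound): raise the lower bound
        return ValueRange(max(self.lower, bound), self.upper)

    def refine_lt(self, bound):
        # Toy analogue of torch._check(a < bound): for integers, a <= bound - 1
        return ValueRange(self.lower, min(self.upper, bound - 1))


def negate(r):
    # The interval of -a when a lies in r
    return ValueRange(-r.upper, -r.lower)


def statically_greater(r, value):
    """Decide 'a > value' for every a in r: True, False, or None if it depends on a."""
    if r.lower > value:
        return True
    if r.upper <= value:
        return False
    return None  # undecidable without knowing the actual value


def out_of_bounds(u0, size):
    """Toy version of the bounds test: (-u0 > size) or (u0 >= size)."""
    neg_side = statically_greater(negate(u0), size)
    pos_side = statically_greater(u0, size - 1)  # u0 >= size  <=>  u0 > size - 1
    if neg_side is True or pos_side is True:
        return True
    if neg_side is False and pos_side is False:
        return False
    return None  # a tracer would have to guard here, and cannot for unbacked u0


u0 = ValueRange(-INF, INF)
print(out_of_bounds(u0, 60))  # None: undecidable, like the guard error
checked = u0.refine_ge(0).refine_lt(60)
print(checked, out_of_bounds(checked, 60))  # [0, 59], statically False
```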

Specialized values

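Another category of data-dependent error occurs when the program attempts to extract a concrete, data-dependent integer or float value during tracing. The error looks something like "Could not extract specialized integer from data-dependent expression" and is analogous to the previous class of errors: where these errors arise from trying to evaluate a concrete integer or float value, data-dependent guard errors arise from trying to evaluate a concrete boolean value.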
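The difference between the two error classes comes down to which Python conversion is being forced. Below is a minimal sketch using hypothetical toy classes, not PyTorch's SymInt/SymBool. Comparing a symbolic integer stays symbolic. But `bool()`, which an `if` statement calls implicitly, needs a concrete truth value (a guard), and `int()` or `range()`, which calls `__index__`, needs a concrete integer (a specialization). Each raises its own toy error.

```python
class DataDependentError(Exception):
    """Hypothetical error type for this toy; not a PyTorch class."""


class ToyUnbackedBool:
    """Toy symbolic boolean produced by comparing a ToyUnbackedInt."""

    def __init__(self, expr):
        self.expr = expr

    def __bool__(self):
        # Guarding: control flow needs a concrete True/False
        raise DataDependentError(f"could not guard on data-dependent expression {self.expr}")


class ToyUnbackedInt:
    """Toy stand-in for an integer unknown at trace time (like u0 from x.item())."""

    def __init__(self, name):
        self.name = name

    def __gt__(self, other):
        # Comparisons stay symbolic and do not fail by themselves
        return ToyUnbackedBool(f"{self.name} > {other}")

    def __int__(self):
        # Specializing: the tracer would need a concrete integer
        raise DataDependentError(f"could not extract specialized integer from {self.name}")

    # range() and indexing call __index__, which specializes in the same way
    __index__ = __int__


def failure_message(fn):
    """Run fn() and return the toy error message, or None if it succeeded."""
    try:
        fn()
    except DataDependentError as exc:
        return str(exc)
    return None


u0 = ToyUnbackedInt("u0")
print(failure_message(lambda: bool(u0 > 60)))  # -> could not guard on data-dependent expression u0 > 60
print(failure_message(lambda: range(u0)))      # -> could not extract specialized integer from u0
```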

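This error typically occurs when there is an explicit or implicit int() cast on a data-dependent expression. For example, the range() call in this list comprehension implicitly casts a, which determines the length of the list, to int: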

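import traceback as tb

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()  # data-dependent value, traced as the unbacked symbol u0
        # range(a) needs a concrete integer, which forces specialization of a
        b = torch.cat([y for _ in range(a)], dim=0)
        return b + int(a)  # explicit int() cast: also needs a concrete value

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps, strict=False)
except Exception:
    tb.print_exc()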
I0423 16:38:39.271000 660 torch/fx/experimental/symbolic_shapes.py:3334] create_env
I0423 16:38:39.276000 660 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.276000 660 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u0]
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984] Data dependent variable 'u0' allocated at:
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     sys.exit(main())
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return make_main(argv)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return make_mode.run_make_mode(argv[1:])
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return make.run_generic_build(args[0])
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return build_main(args + opts)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     self._init_builder()
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     self.events.emit('builder-inited')
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     results.append(listener.handler(self.app, *args))
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     ) = generate_dir_rst(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     intro, title, cost = generate_file_rst(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     output_blocks, time_elapsed = execute_script(script_blocks,
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     output_blocks.append(execute_code_block(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     is_last_expr, mem_max = _exec_and_get_memory(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     mem_max, _ = gallery_conf['call_memory'](
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return 0., func()
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     exec(self.code, self.fake_main.__dict__)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     export(Foo(), inps, strict=False)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 360, in export
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return _export(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     ep = fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2112, in _export
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     ep = _export_for_training(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     ep = fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1975, in _export_for_training
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     export_artifact = export_func(  # type: ignore[operator]
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1910, in _non_strict_export
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     aten_export_artifact = _to_aten_func(  # type: ignore[operator]
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1696, in _export_to_aten_ir_make_fx
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     gm, graph_signature = transform(_make_fx_helper)(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1840, in _aot_export_non_strict
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1616, in _make_fx_helper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     gm = make_fx(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2240, in wrapped
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return make_fx_tracer.trace(f, *args)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2178, in trace
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return self._trace_inner(f, *args)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2149, in _trace_inner
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     t = dispatch_trace(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_compile.py", line 51, in inner
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return disable_fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1174, in dispatch_trace
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1738, in trace
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     res = super().trace(root, concrete_args)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 838, in trace
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     (self.create_arg(fn(*args)),),
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1229, in wrapped
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     out = f(*tensors)  # type:ignore[call-arg]
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "<string>", line 1, in <lambda>
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1520, in wrapped_fn
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return tuple(flat_fn(*args))
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     tree_out = fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 903, in functional_call
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     out = mod(*args[params_len:], **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 813, in module_call_wrapper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return self.call_module(mod, forward, args, kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1808, in call_module
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return Tracer.call_module(self, m, forward, args, kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 531, in call_module
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     ret_val = forward(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 806, in forward
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return _orig_module_call(mod, *args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return self._call_impl(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return forward_call(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1824, in forward
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     tree_out = mod(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 813, in module_call_wrapper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return self.call_module(mod, forward, args, kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1808, in call_module
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return Tracer.call_module(self, m, forward, args, kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 531, in call_module
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     ret_val = forward(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 806, in forward
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return _orig_module_call(mod, *args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return self._call_impl(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return forward_call(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 747, in forward
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     a = x.item()
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1277, in __torch_function__
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return func(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1324, in __torch_function__
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return func(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_export/non_strict_utils.py", line 683, in __torch_function__
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return func(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_ops.py", line 875, in handler
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return torch._library.utils.handle_dispatch_mode(
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_library/utils.py", line 296, in handle_dispatch_mode
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/utils/_stats.py", line 27, in wrapper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1379, in __torch_dispatch__
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return proxy_call(self, func, self.pre_dispatch, args, kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 914, in proxy_call
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     out = func(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_ops.py", line 756, in __call__
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return self._op(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/utils/_stats.py", line 27, in wrapper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return fn(*args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1282, in __torch_dispatch__
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return self.dispatch(func, types, args, kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1823, in dispatch
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return self._cached_dispatch_impl(func, types, args, kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1393, in _cached_dispatch_impl
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     output = self._dispatch_impl(func, types, args, kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2397, in _dispatch_impl
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     op_impl_out = op_impl(self, func, *args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_subclasses/fake_impls.py", line 422, in local_scalar_dense
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     r = fake_mode.shape_env.create_unbacked_symint()
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]     return retlog(fn(*args, **kwargs))
V0423 16:38:39.277000 660 torch/fx/experimental/symbolic_shapes.py:5984]
W0423 16:38:39.285000 660 torch/fx/experimental/symbolic_shapes.py:6679] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299] failed while running evaluate_expr(*(u0, None, False, False), **{})
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299] Traceback (most recent call last):
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299]     return retlog(fn(*args, **kwargs))
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299]     return self._evaluate_expr(
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299]   File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299]     raise self._make_data_dependent_error(
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299]
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299] Caused by: (workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299] For more information, run with TORCH_LOGS="dynamic"
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299]
E0423 16:38:39.286000 660 torch/fx/experimental/recording.py:299] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1




def forward(self, arg0_1: "i64[]", arg1_1: "f32[60]"):
     # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:747 in forward, code: a = x.item()
    item: "Sym(u0)" = torch.ops.aten.item.default(arg0_1);  arg0_1 = item = None

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
    export(Foo(), inps, strict=False)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 360, in export
    return _export(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2112, in _export
    ep = _export_for_training(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1092, in wrapper
    raise e
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1065, in wrapper
    ep = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1975, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1910, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1696, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1840, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1616, in _make_fx_helper
    gm = make_fx(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2240, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2178, in trace
    return self._trace_inner(f, *args)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2149, in _trace_inner
    t = dispatch_trace(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1174, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1738, in trace
    res = super().trace(root, concrete_args)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 838, in trace
    (self.create_arg(fn(*args)),),
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1229, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1520, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 903, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 813, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1808, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 531, in call_module
    ret_val = forward(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 806, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1824, in forward
    tree_out = mod(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 813, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1808, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 531, in call_module
    ret_val = forward(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 806, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 748, in forward
    b = torch.cat([y for y in range(a)], dim=0)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/__init__.py", line 431, in __index__
    return self.node.int_()
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 466, in int_
    return self.guard_int("", 0)  # NB: uses Python backtrace
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 516, in guard_int
    r = self.evaluate()
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6655, in evaluate_sym_node
    return self.evaluate_expr(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr
    return self._evaluate_expr(
  File "/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

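For these errors, you have a few basic options:

  1. Avoid unnecessary int() cast calls, in this case the int(a) in the return statement.

  2. Use torch._check() calls; unfortunately, all you may be able to do in this case is specialize (with torch._check(a == 60)).

  3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a repeat() op, which doesn't involve an int() cast. The following rewrite avoids data-dependent errors: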

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.unsqueeze(0).repeat(a, 1)
        return b + a

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps, strict=False)
print(ep)
I0423 16:38:39.296000 660 torch/fx/experimental/symbolic_shapes.py:3334] create_env
I0423 16:38:39.301000 660 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
I0423 16:38:39.302000 660 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u0]
I0423 16:38:39.305000 660 torch/fx/experimental/symbolic_shapes.py:6630] runtime_assert u0 >= 0 [guard added] (_refs/__init__.py:4796 in new_empty), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0423 16:38:39.306000 660 torch/fx/experimental/symbolic_shapes.py:6071] _update_var_to_range u0 = VR[0, int_oo] (update)
V0423 16:38:39.308000 660 torch/fx/experimental/symbolic_shapes.py:6787] eval size_oblivious(Eq(u0, 0)) == False [statically known]
V0423 16:38:39.311000 660 torch/fx/experimental/symbolic_shapes.py:6787] eval size_oblivious(Eq(u0, 1)) == False [statically known]
V0423 16:38:39.312000 660 torch/fx/experimental/symbolic_shapes.py:7018] runtime_assert True == True [statically known]
I0423 16:38:39.315000 660 torch/fx/experimental/symbolic_shapes.py:4734] produce_guards
V0423 16:38:39.315000 660 torch/fx/experimental/symbolic_shapes.py:4954] track_symint L['args'][0][0].storage_offset() 0 None
V0423 16:38:39.316000 660 torch/fx/experimental/symbolic_shapes.py:4954] track_symint L['args'][0][1].size()[0] 60 None
V0423 16:38:39.316000 660 torch/fx/experimental/symbolic_shapes.py:4954] track_symint L['args'][0][1].stride()[0] 1 None
V0423 16:38:39.316000 660 torch/fx/experimental/symbolic_shapes.py:4954] track_symint L['args'][0][1].storage_offset() 0 None
V0423 16:38:39.318000 660 torch/fx/experimental/symbolic_shapes.py:7018] runtime_assert u0 >= 0 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

             #
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
            ge: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge = _assert_scalar_default = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:770 in forward, code: b = y.unsqueeze(0).repeat(a, 1)
            unsqueeze: "f32[1, 60]" = torch.ops.aten.unsqueeze.default(y, 0);  y = None
            repeat: "f32[u0, 60]" = torch.ops.aten.repeat.default(unsqueeze, [item, 1]);  unsqueeze = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:771 in forward, code: return b + a
            add: "f32[u0, 60]" = torch.ops.aten.add.Tensor(repeat, item);  repeat = item = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {u0: VR[0, int_oo]}
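To check the claim that the list comprehension is "semantically a repeat()", you can compare the two forms in plain eager mode, outside of export. The fixed value `a = 4` below stands in for `x.item()`. The comparison also shows one difference: the rewrite above returns shape `[a, 60]`, while the original `torch.cat` produced a flat `[a * 60]` tensor. The flat variant `y.repeat(a)` is included for comparison. It is checked here only in eager mode, not under export.

```python
import torch

a = 4                      # stands in for the data-dependent value x.item()
y = torch.randn(60)

via_cat = torch.cat([y for _ in range(a)], dim=0)  # original list comprehension
via_repeat_2d = y.unsqueeze(0).repeat(a, 1)        # rewrite used above, shape [a, 60]
via_repeat_1d = y.repeat(a)                        # flat variant, shape [a * 60]

print(via_cat.shape, via_repeat_2d.shape, via_repeat_1d.shape)
# torch.Size([240]) torch.Size([4, 60]) torch.Size([240])
```

All three forms hold the same values in the same order. Only the layout of the 2-D rewrite differs.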

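Data-dependent errors can be much more involved, and there are many more options in your toolkit for dealing with them: torch._check_is_size(), guard_size_oblivious(), and real-tensor tracing, as starters. For more in-depth guides, please refer to the Export Programming Model or Dealing with GuardOnDataDependentSymNode errors.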

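Custom Ops

torch.export can export PyTorch programs that contain custom operators. Please refer to this page for how to author a custom operator in either C++ or Python.

The following is an example of registering a custom operator in Python for use with torch.export. The important thing to note is that the custom op must have a FakeTensor kernel.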

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(x: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)

@custom_op.register_fake
def custom_op_meta(x):
    # Returns an empty tensor with the same shape as the expected output
    return torch.empty_like(x)

Here is an example of exporting a program that uses the custom op.

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
print(exported_custom_op_example)
print(exported_custom_op_example.module()(torch.randn(3, 3)))
I0423 16:38:39.333000 660 torch/fx/experimental/symbolic_shapes.py:3334] [21/0] create_env
I0423 16:38:39.343000 660 torch/fx/experimental/symbolic_shapes.py:4734] [21/0] produce_guards
V0423 16:38:39.343000 660 torch/fx/experimental/symbolic_shapes.py:4954] [21/0] track_symint L['x'].size()[0] 3 None
V0423 16:38:39.343000 660 torch/fx/experimental/symbolic_shapes.py:4954] [21/0] track_symint L['x'].size()[1] 3 None
V0423 16:38:39.343000 660 torch/fx/experimental/symbolic_shapes.py:4954] [21/0] track_symint L['x'].stride()[0] 3 None
V0423 16:38:39.344000 660 torch/fx/experimental/symbolic_shapes.py:4954] [21/0] track_symint L['x'].stride()[1] 1 None
V0423 16:38:39.344000 660 torch/fx/experimental/symbolic_shapes.py:4954] [21/0] track_symint L['x'].storage_offset() 0 None
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:812 in forward, code: x = torch.sin(x)
            sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:813 in forward, code: x = torch.ops.my_custom_library.custom_op(x)
            custom_op: "f32[3, 3]" = torch.ops.my_custom_library.custom_op.default(sin);  sin = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:814 in forward, code: x = torch.cos(x)
            cos: "f32[3, 3]" = torch.ops.aten.cos.default(custom_op);  custom_op = None
            return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}

custom_op called!
tensor([[0.5499, 0.6889, 0.7180],
        [0.5413, 1.0000, 1.0000],
        [0.8332, 0.5524, 1.0000]])

Note that the custom operator is included in the ExportedProgram's graph.

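IR/Decompositions

The graph produced by torch.export contains only ATen operators, which are the basic unit of computation in PyTorch. Because there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, thereby creating different IRs.

By default, export produces the most generic IR, which contains all ATen operators, both functional and non-functional. A functional operator is one that does not mutate or alias any of its inputs. You can find a list of all ATen operators here, and you can check whether an operator is functional by inspecting op._schema.is_mutable, for example: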

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)
False
True

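This generic IR can be used for training in eager PyTorch Autograd. It can be produced more explicitly through the API torch.export.export_for_training, introduced in PyTorch 2.5, but as of PyTorch 2.6, calling torch.export.export should produce the same graph.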

class DecompExample(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph)
I0423 16:38:39.368000 660 torch/fx/experimental/symbolic_shapes.py:3334] [22/0] create_env
I0423 16:38:39.396000 660 torch/fx/experimental/symbolic_shapes.py:4734] [22/0] produce_guards
V0423 16:38:39.397000 660 torch/fx/experimental/symbolic_shapes.py:4954] [22/0] track_symint L['x'].size()[0] 1 None
V0423 16:38:39.397000 660 torch/fx/experimental/symbolic_shapes.py:4954] [22/0] track_symint L['x'].size()[1] 1 None
V0423 16:38:39.398000 660 torch/fx/experimental/symbolic_shapes.py:4954] [22/0] track_symint L['x'].size()[2] 3 None
V0423 16:38:39.398000 660 torch/fx/experimental/symbolic_shapes.py:4954] [22/0] track_symint L['x'].size()[3] 3 None
V0423 16:38:39.398000 660 torch/fx/experimental/symbolic_shapes.py:4954] [22/0] track_symint L['x'].stride()[0] 9 None
V0423 16:38:39.398000 660 torch/fx/experimental/symbolic_shapes.py:4954] [22/0] track_symint L['x'].stride()[1] 9 None
V0423 16:38:39.399000 660 torch/fx/experimental/symbolic_shapes.py:4954] [22/0] track_symint L['x'].stride()[2] 3 None
V0423 16:38:39.399000 660 torch/fx/experimental/symbolic_shapes.py:4954] [22/0] track_symint L['x'].stride()[3] 1 None
V0423 16:38:39.399000 660 torch/fx/experimental/symbolic_shapes.py:4954] [22/0] track_symint L['x'].storage_offset() 0 None
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
    return (batch_norm,)

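We can then lower this exported program to an operator set containing only functional ATen operators through the API run_decompositions. This API decomposes ATen operators into the ones specified in the decomposition table and functionalizes the graph. If we pass an empty decomposition table (decomp_table={}), we perform only functionalization, without any additional decompositions. The result is an IR with about 2000 operators (instead of the 3000 operators above), which is ideal for inference cases.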

graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

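As we can see, the previously mutable operator torch.ops.aten.add_.Tensor has been replaced with the functional operator torch.ops.aten.add.Tensor. Similarly, batch_norm has become _native_batch_norm_legit_functional. The graph now returns the updated running_mean and running_var (getitem_3 and getitem_4) and num_batches_tracked (add) as outputs instead of writing them in place.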

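We can lower the exported program further, to an operator set that contains only the Core ATen Operator Set, a collection of only about 180 operators. This IR is the best choice for backends that do not want to reimplement every ATen operator.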

from torch.export import default_decompositions

core_aten_decomp_table = default_decompositions()
core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
print(core_aten_ep.graph)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%convolution, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

We now see that torch.ops.aten.conv2d.default has been decomposed into torch.ops.aten.convolution.default. This is because convolution is a more "core" operator: operations like conv1d and conv2d can both be implemented using this same op.

We can also specify our own decomposition behavior:

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%mul, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

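Notice that torch.ops.aten.conv2d.default is no longer decomposed into torch.ops.aten.convolution.default alone. It is now decomposed into torch.ops.aten.convolution.default followed by torch.ops.aten.mul.Tensor, which matches our custom decomposition rule.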

ExportDB

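torch.export only ever exports a single computation graph from a PyTorch program. Because of this requirement, some Python or PyTorch features are not compatible with torch.export, and users will need to rewrite parts of their model code. We saw examples of this earlier in the tutorial, for instance rewriting if-statements using cond.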

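ExportDB is the standard reference that documents which Python/PyTorch features torch.export supports and which it does not. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with torch.export. Examples are also tagged by category so that they are easier to search.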

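For example, let's use ExportDB to get a better understanding of how the predicate works in the cond operator. We can look at the example called cond_predicate, which has a torch.cond tag. The example code looks like this: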

def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - ``torch.Tensor`` with a single element
    - boolean expression
    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

More generally, ExportDB can be used as a reference when one of the following occurs:

  1. Before attempting torch.export, you know ahead of time that your model uses some tricky Python/PyTorch features, and you want to know whether torch.export covers them.

  2. While attempting torch.export, you hit a failure and it is unclear how to work around it.

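ExportDB is not exhaustive, but it is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by torch.export.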

Running the Exported Program

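Because torch.export is only a graph-capturing mechanism, running the artifact produced by torch.export directly is equivalent to running the eager module. To optimize the execution of the exported program, we can pass this exported artifact to backends such as Inductor (through torch.compile), AOTInductor, or TensorRT.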

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
I0423 16:38:40.048000 660 torch/fx/experimental/symbolic_shapes.py:3334] [23/0] create_env
I0423 16:38:40.062000 660 torch/fx/experimental/symbolic_shapes.py:4734] [23/0] produce_guards
V0423 16:38:40.063000 660 torch/fx/experimental/symbolic_shapes.py:4954] [23/0] track_symint L['x'].size()[0] 2 None
V0423 16:38:40.063000 660 torch/fx/experimental/symbolic_shapes.py:4954] [23/0] track_symint L['x'].size()[1] 3 None
V0423 16:38:40.063000 660 torch/fx/experimental/symbolic_shapes.py:4954] [23/0] track_symint L['x'].stride()[0] 3 None
V0423 16:38:40.064000 660 torch/fx/experimental/symbolic_shapes.py:4954] [23/0] track_symint L['x'].stride()[1] 1 None
V0423 16:38:40.064000 660 torch/fx/experimental/symbolic_shapes.py:4954] [23/0] track_symint L['x'].storage_offset() 0 None
tensor([[ 0.4830, -0.5149,  0.3888],
        [-0.9247,  0.8408, -0.2184]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
I0423 16:38:40.087000 660 torch/fx/experimental/symbolic_shapes.py:3334] [24/0] create_env
I0423 16:38:40.364000 660 torch/fx/experimental/symbolic_shapes.py:4734] [24/0] produce_guards
I0423 16:38:40.374000 660 torch/fx/experimental/symbolic_shapes.py:4734] [24/0] produce_guards
V0423 16:38:40.374000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['x'].size()[0] 2 None
V0423 16:38:40.374000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['x'].size()[1] 3 None
V0423 16:38:40.375000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['x'].stride()[0] 3 None
V0423 16:38:40.375000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['x'].stride()[1] 1 None
V0423 16:38:40.375000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['x'].storage_offset() 0 None
V0423 16:38:40.376000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[0] 3 None
V0423 16:38:40.376000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[1] 3 None
V0423 16:38:40.376000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[0] 3 None
V0423 16:38:40.377000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[1] 1 None
V0423 16:38:40.377000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].storage_offset() 0 None
V0423 16:38:40.378000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['self']._modules['linear']._parameters['bias'].size()[0] 3 None
V0423 16:38:40.378000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['self']._modules['linear']._parameters['bias'].stride()[0] 1 None
V0423 16:38:40.378000 660 torch/fx/experimental/symbolic_shapes.py:4954] [24/0] track_symint L['self']._modules['linear']._parameters['bias'].storage_offset() 0 None
V0423 16:38:40.379000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['x'].size()[0] == 2
V0423 16:38:40.379000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['x'].size()[1] == 3
V0423 16:38:40.379000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['x'].stride()[0] == 3
V0423 16:38:40.380000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['x'].stride()[1] == 1
V0423 16:38:40.380000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['x'].storage_offset() == 0
V0423 16:38:40.380000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[0] == 3
V0423 16:38:40.381000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[1] == 3
V0423 16:38:40.381000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[0] == 3
V0423 16:38:40.382000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[1] == 1
V0423 16:38:40.382000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].storage_offset() == 0
V0423 16:38:40.382000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['self']._modules['linear']._parameters['bias'].size()[0] == 3
V0423 16:38:40.383000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['self']._modules['linear']._parameters['bias'].stride()[0] == 1
V0423 16:38:40.383000 660 torch/fx/experimental/symbolic_shapes.py:5156] [24/0] Skipping guard L['self']._modules['linear']._parameters['bias'].storage_offset() == 0
tensor([[ 0.4830, -0.5149,  0.3888],
        [-0.9247,  0.8408, -0.2184]], device='cuda:0',
       grad_fn=<CompiledFunctionBackward>)
import torch._inductor

# Note: these APIs are subject to change
# Compile the exported program to a PT2 archive using ``AOTInductor``
with torch.no_grad():
    pt2_path = torch._inductor.aoti_compile_and_package(ep)

# Load and run the compiled PT2 package in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.ac.cn/docs/main/torch.compiler_aot_inductor.html
aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
res = aoti_compiled(inp)

Conclusion

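We introduced torch.export, the new PyTorch 2.X way to export a single computation graph from a PyTorch program. In particular, we demonstrated several code modifications and considerations (control flow ops, constraints, etc.) that are needed to export a graph.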

© Copyright 2024, PyTorch.