torch.export Tutorial¶
Created On: Oct 02, 2023 | Last Updated: Jan 27, 2025 | Last Verified: Nov 05, 2024
Authors: William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan
Warning
torch.export and its related features are in prototype status and are subject to backward-incompatible changes. This tutorial provides a snapshot of torch.export usage as of PyTorch 2.5.
torch.export() is the PyTorch 2.X way to export PyTorch models into a standardized model representation, intended to be run on different (i.e. Python-less) environments. The official documentation can be found here.
In this tutorial, you will learn how to use torch.export() to extract an ExportedProgram (i.e. a single-graph representation) from a PyTorch program. We also detail some of the considerations/modifications that you may need to make in order to make your model compatible with torch.export.
Basic Usage¶
torch.export extracts single-graph representations from PyTorch programs by tracing the target function, given example inputs. torch.export.export() is the main entry point for torch.export.
In this tutorial, torch.export and torch.export.export() are practically synonymous, though torch.export generally refers to the PyTorch 2.X export process, and torch.export.export() generally refers to the actual function call.
The signature of torch.export.export() is:
export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram
torch.export.export() traces the tensor computation graph from calling mod(*args, **kwargs) and wraps it in an ExportedProgram, which can be serialized or executed later with different inputs. To execute the ExportedProgram, we can call .module() on it to return a torch.nn.Module that is callable, just like the original program. We will detail the dynamic_shapes argument later in the tutorial.
import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
<class 'torch.export.exported_program.ExportedProgram'>
tensor([[0.8632, 0.8407, 0.0407, 0.0000, 0.4132, 0.0000, 0.0000, 0.1538, 0.6111,
0.0000],
[0.0000, 0.0000, 0.0273, 0.8057, 0.0000, 1.0162, 0.8042, 0.0000, 0.2660,
0.0000],
[0.9481, 0.1396, 1.0225, 0.9563, 0.5832, 0.2546, 0.4095, 0.4591, 0.0000,
2.0053],
[1.1300, 0.4873, 0.0000, 0.9663, 1.2275, 1.4015, 0.0000, 0.9444, 0.0000,
0.0000],
[0.0000, 0.8724, 1.1648, 0.6867, 0.0000, 0.2833, 0.3202, 0.5848, 0.0000,
0.0833],
[1.1311, 0.1324, 0.0000, 1.7842, 0.0000, 0.3474, 0.9916, 0.3571, 0.0000,
0.0000],
[1.4348, 1.0570, 0.1771, 0.0000, 0.9510, 0.0000, 0.0000, 0.0000, 0.2618,
0.0000],
[0.8853, 0.0000, 0.0000, 0.4486, 0.0000, 0.0000, 0.5841, 0.7604, 0.0000,
0.0000]], grad_fn=<ReluBackward0>)
Let's review some attributes of the ExportedProgram that are of interest.
The graph attribute is an FX graph traced from the function we exported, that is, the computation graph of all PyTorch operations. The FX graph is in "ATen IR", meaning that it contains only "ATen-level" operations.
The graph_signature attribute gives a more detailed description of the input and output nodes in the exported graph, describing which ones are parameters, buffers, user inputs, or user outputs.
The range_constraints attribute will be covered later.
print(exported_mod)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_lin_weight: "f32[10, 100]", p_lin_bias: "f32[10]", x: "f32[8, 100]", y: "f32[8, 100]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
add: "f32[8, 100]" = torch.ops.aten.add.Tensor(x, y); x = y = None
linear: "f32[8, 10]" = torch.ops.aten.linear.default(add, p_lin_weight, p_lin_bias); add = p_lin_weight = p_lin_bias = None
relu_: "f32[8, 10]" = torch.ops.aten.relu_.default(linear); linear = None
return (relu_,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_weight'), target='lin.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_bias'), target='lin.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu_'), target=None)])
Range constraints: {}
For more details, see the torch.export documentation.
Graph Breaks¶
Although torch.export shares components with torch.compile, the key limitation of torch.export, especially when compared to torch.compile, is that it does not support graph breaks. This is because handling graph breaks involves interpreting the unsupported operation with default Python evaluation, which is incompatible with the export use case. Therefore, in order to make your model code compatible with torch.export, you will need to modify your code to remove graph breaks.
A graph break is necessary in cases such as:
data-dependent control flow
class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 122, in <module>
export(Bad1(), (torch.randn(3, 3),))
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
return _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
result_traced = opt_f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 640, in inner
raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.ac.cn/docs/main/generated/exportdb/index.html#cond-operands
from user code:
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
if x.sum() > 0:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
accessing tensor data with .data
class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
calling unsupported functions (such as many built-in functions)
class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 148, in <module>
export(Bad3(), (torch.randn(3, 3),))
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
return _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
result_traced = opt_f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1004, in call_function
return handler(tx, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 843, in builtin_dispatch
rv = handler(tx, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 772, in call_self_handler
result = self_handler(tx, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1936, in call_id
return tensor_variable.call_id(tx)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/tensor.py", line 469, in call_id
unimplemented("call_id not supported for sourceless TensorVariable")
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 317, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_id not supported for sourceless TensorVariable
from user code:
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 145, in forward
return x + id(x)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Non-Strict Export¶
To trace the program, torch.export uses TorchDynamo by default, a bytecode analysis engine, to symbolically analyze the Python code and build a graph based on the results. This analysis allows torch.export to provide stronger safety guarantees, but not all Python code is supported, causing these graph breaks.
To address this issue, in PyTorch 2.3 we introduced a new mode of exporting called non-strict mode, where we trace through the program using the Python interpreter executing it exactly as it would in eager mode, allowing us to skip over unsupported Python features. This is done through adding a strict=False flag.
Looking at some of the previous examples which resulted in graph breaks:
Calling unsupported functions (such as many built-in functions) traces through, but in this case, id(x) gets specialized as a constant integer in the graph. This is because id(x) is not a tensor operation, so the operation is not recorded in the graph.
class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:179 in forward, code: x = x + 1
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, 1); x = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:180 in forward, code: return x + id(x)
add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, 139872432249072); add = None
return (add_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {}
tensor([[1.3987e+14, 1.3987e+14, 1.3987e+14],
[1.3987e+14, 1.3987e+14, 1.3987e+14],
[1.3987e+14, 1.3987e+14, 1.3987e+14]])
However, there are still some features that require rewrites to the original module:
Control Flow Ops¶
torch.export actually does support data-dependent control flow. But it needs to be expressed using control flow ops. For example, we can fix the control flow example above using the cond op, like so:
class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed)
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:205 in forward, code: return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
# File: /usr/local/lib/python3.10/dist-packages/torch/_higher_order_ops/cond.py:144 in cond, code: return cond_op(pred, true_fn, false_fn, operands)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x]); gt = true_graph_0 = false_graph_0 = x = None
getitem: "f32[3, 3]" = cond[0]; cond = None
return (getitem,)
class true_graph_0(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:202 in true_fn, code: return torch.sin(x)
sin: "f32[3, 3]" = torch.ops.aten.sin.default(x); x = None
return (sin,)
class false_graph_0(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:204 in false_fn, code: return torch.cos(x)
cos: "f32[3, 3]" = torch.ops.aten.cos.default(x); x = None
return (cos,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
tensor([[0.8415, 0.8415, 0.8415],
[0.8415, 0.8415, 0.8415],
[0.8415, 0.8415, 0.8415]])
tensor([[0.5403, 0.5403, 0.5403],
[0.5403, 0.5403, 0.5403],
[0.5403, 0.5403, 0.5403]])
There are some limitations of cond one should be aware of:
- The predicate (i.e. x.sum() > 0) must result in a boolean or a single-element tensor.
- The operands (i.e. [x]) must be tensors.
- The branch function signatures (i.e. true_fn and false_fn) must match the operands, and both branches must return a single tensor with the same metadata (for example, dtype, shape, etc.).
- Branch functions cannot mutate inputs or global variables.
- Branch functions cannot access closure variables, except for self when the function is defined in the scope of a method.
For more details about cond, check out the cond documentation.
We can also use map, which applies a function across the first dimension of the first tensor argument.
from torch._higher_order_ops.map import map as torch_map

class MapModule(torch.nn.Module):
    def forward(self, xs, y, z):
        def body(x, y, z):
            return x + y + z
        return torch_map(body, xs, y, z)

inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
exported_map_example = export(MapModule(), inps)
print(exported_map_example)
print(exported_map_example.module()(*inps))
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, xs: "f32[6, 4]", y: "i64[]", z: "i64[]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:236 in forward, code: return torch_map(body, xs, y, z)
body_graph_0 = self.body_graph_0
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y, z]); body_graph_0 = xs = y = z = None
getitem: "f32[6, 4]" = map_impl[0]; map_impl = None
return (getitem,)
class body_graph_0(torch.nn.Module):
def forward(self, xs: "f32[4]", y: "i64[]", z: "i64[]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:234 in body, code: return x + y + z
add: "f32[4]" = torch.ops.aten.add.Tensor(xs, y); xs = y = None
add_1: "f32[4]" = torch.ops.aten.add.Tensor(add, z); add = z = None
return (add_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='xs'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
tensor([[10., 10., 10., 10.],
[10., 10., 10., 10.],
[10., 10., 10., 10.],
[10., 10., 10., 10.],
[10., 10., 10., 10.],
[10., 10., 10., 10.]])
Other control flow ops include while_loop, associative_scan, and scan. For more documentation on each operator, please refer to this page.
Constraints/Dynamic Shapes¶
This section covers dynamic behavior and the representation of exported programs. Dynamic behavior is subjective to the particular model being exported, so for most of this tutorial we'll focus on this particular toy model (with the resulting tensor shapes annotated):
class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3
By default, torch.export produces a static program. One consequence of this is that at runtime, the program won't work on inputs with different shapes, even if they're valid in eager mode.
w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)

model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))
try:
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 286, in <module>
ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 387, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1772, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py", line 49, in _check_input_constraints_pre_hook
_check_input_constraints_for_graph(
File "/usr/local/lib/python3.10/dist-packages/torch/_export/utils.py", line 360, in _check_input_constraints_for_graph
raise RuntimeError(
RuntimeError: Expected input at *args[2].shape[0] to be equal to 8, but got 3
Basic concepts: symbols and guards¶
To enable dynamism, export() provides a dynamic_shapes argument. The easiest way to work with dynamic shapes is using Dim.AUTO and looking at the program that's returned. Dynamic behavior is specified at the level of input dimensions; for each input we can specify a tuple of values:
Before we look at the program that's produced, let's understand what specifying dynamic_shapes entails and how it interacts with export. For every input dimension where a Dim object is specified, a symbol is allocated, taking on a range of [2, inf] (why not [0, inf] or [1, inf]? We'll explain later, in the section on 0/1 specialization).
Export then runs model tracing, looking at each operation the model performs. Each individual operation can emit what are called "guards": basically boolean conditions that must hold for the program to be valid. When guards involve the symbols allocated for input dimensions, the program contains restrictions on what input shapes are valid; i.e., the program's dynamic behavior. The symbolic shapes subsystem is responsible for taking in all the emitted guards and producing a final program representation that adheres to all of them. Before we see this "final representation" in an ExportedProgram, let's look at the guards emitted by the toy model we're tracing.
Here, each forward input tensor is annotated with the symbol allocated at the start of tracing:
class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()  # no guard added here
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3
Let's understand each of the operations and the emitted guards:
- x0 = x + y: This is an element-wise add with broadcasting, since x is a 1-d tensor and y a 2-d tensor. x is broadcast along the last dimension of y, emitting the guard s2 == s4.
- x1 = self.l(w): Calling nn.Linear() performs a matrix multiplication with the model parameters. In export, parameters, buffers, and constants are considered program state, and program state is treated as static, so this is a matmul between a dynamic input (w: [s0, s1]) and a statically-shaped tensor. This emits the guard s1 == 5.
- x2 = x0.flatten(): This call actually doesn't emit any guards! (at least none relevant to the input shapes)
- x3 = x2 + z: x2 has shape [s3*s4] after the flatten, and this element-wise add emits s3 * s4 == s5.
Writing down all of these guards and summarizing them is almost like a mathematical proof, which is exactly what the symbolic shapes subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid:
w: [s0, 5]
x: [s2]
y: [s3, s2]
z: [s2*s3]
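These constraints can be checked mechanically. Here is a plain-Python sketch (the helper name is hypothetical) that validates candidate input shapes against the three guards derived above:

```python
def shapes_satisfy_guards(w, x, y, z):
    """Check the emitted guards: s1 == 5, s2 == s4, and s3 * s4 == s5."""
    (s0, s1), (s2,), (s3, s4), (s5,) = w, x, y, z
    return s1 == 5 and s2 == s4 and s3 * s4 == s5


# the original example shapes satisfy every guard...
assert shapes_satisfy_guards((6, 5), (4,), (8, 4), (32,))
# ...but a linear-layer mismatch (second dim of w != 5) violates s1 == 5
assert not shapes_satisfy_guards((6, 4), (4,), (8, 4), (32,))
```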
When we finally print out the exported program to see our result, those shapes are annotated on the corresponding inputs:
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_l_weight: "f32[3, 5]", p_l_bias: "f32[3]", w: "f32[s0, 5]", x: "f32[s2]", y: "f32[s3, s2]", z: "f32[s2*s3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward, code: x0 = x + y # [8, 4]
add: "f32[s3, s2]" = torch.ops.aten.add.Tensor(x, y); x = y = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward, code: x1 = self.l(w) # [6, 3]
linear: "f32[s0, 3]" = torch.ops.aten.linear.default(w, p_l_weight, p_l_bias); w = p_l_weight = p_l_bias = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:270 in forward, code: x2 = x0.flatten() # [32]
flatten: "f32[s2*s3]" = torch.ops.aten.flatten.using_ints(add); add = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward, code: x3 = x2 + z # [32]
add_1: "f32[s2*s3]" = torch.ops.aten.add.Tensor(flatten, z); flatten = z = None
return (linear, add_1)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_weight'), target='l.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_bias'), target='l.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='w'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {s0: VR[2, int_oo], s2: VR[2, int_oo], s3: VR[2, int_oo], s2*s3: VR[4, int_oo]}
One additional feature to note is the range_constraints field above, which contains a valid range for each symbol. This isn't so interesting currently, since this export call doesn't emit any guards related to symbol bounds and each base symbol has a generic bound, but it will come up later.
So far, because we've been exporting this toy model, this experience has not been representative of how hard it typically is to debug dynamic-shapes guards and issues. In most cases it isn't obvious which guards are being emitted, or which operations and parts of user code are responsible. For this toy model we pinpoint the exact lines, and the guards are rather intuitive.
In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment variable TORCH_LOGS="+dynamic", or interactively with torch._logging.set_logs(dynamic=10):
I0203 16:55:11.092000 634 torch/fx/experimental/symbolic_shapes.py:3192] [12/0] create_env
I0203 16:55:11.094000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.094000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.095000 634 torch/fx/experimental/symbolic_shapes.py:6614] [12/0] runtime_assert True == True [statically known]
I0203 16:55:11.098000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.100000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.100000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.103000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.105000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s2, 1) == False [statically known]
V0203 16:55:11.106000 634 torch/fx/experimental/symbolic_shapes.py:6614] [12/0] runtime_assert True == True [statically known]
V0203 16:55:11.107000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s4, 1) == False [statically known]
I0203 16:55:11.108000 634 torch/fx/experimental/symbolic_shapes.py:5963] [12/0] set_replacement s4 = s2 (solve) VR[2, int_oo]
I0203 16:55:11.109000 634 torch/fx/experimental/symbolic_shapes.py:6281] [12/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # [8, 4] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0203 16:55:11.110000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Ne(s2, 1) == True [statically known]
V0203 16:55:11.111000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Ne(s3, 1) == True [statically known]
V0203 16:55:11.118000 634 torch/fx/experimental/symbolic_shapes.py:5802] [12/0] _update_var_to_range s1 = VR[5, 5] (update)
I0203 16:55:11.118000 634 torch/fx/experimental/symbolic_shapes.py:5963] [12/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0203 16:55:11.119000 634 torch/fx/experimental/symbolic_shapes.py:6281] [12/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0203 16:55:11.120000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s0, 1) == False [statically known]
V0203 16:55:11.126000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s2*s3, 1) == False [statically known]
V0203 16:55:11.127000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s5, 1) == False [statically known]
V0203 16:55:11.128000 634 torch/fx/experimental/symbolic_shapes.py:5802] [12/0] _update_var_to_range s5 = VR[4, int_oo] (update)
I0203 16:55:11.129000 634 torch/fx/experimental/symbolic_shapes.py:5963] [12/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo]
I0203 16:55:11.130000 634 torch/fx/experimental/symbolic_shapes.py:6281] [12/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0203 16:55:11.131000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Ne(s2*s3, 1) == True [statically known]
I0203 16:55:11.138000 634 torch/fx/experimental/symbolic_shapes.py:4547] [12/0] produce_guards
V0203 16:55:11.138000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].size()[0] s0 None
V0203 16:55:11.139000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].size()[1] 5 None
V0203 16:55:11.139000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].stride()[0] 5 None
V0203 16:55:11.139000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].stride()[1] 1 None
V0203 16:55:11.140000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].storage_offset() 0 None
V0203 16:55:11.140000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['x'].size()[0] s2 None
V0203 16:55:11.141000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['x'].stride()[0] 1 None
V0203 16:55:11.141000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.141000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].size()[0] s3 None
V0203 16:55:11.142000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].size()[1] s2 None
V0203 16:55:11.142000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].stride()[0] s2 None
V0203 16:55:11.142000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.143000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].storage_offset() 0 None
V0203 16:55:11.143000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['z'].size()[0] s2*s3 None
V0203 16:55:11.143000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['z'].stride()[0] 1 None
V0203 16:55:11.144000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['z'].storage_offset() 0 None
V0203 16:55:11.179000 634 torch/fx/experimental/symbolic_shapes.py:6412] eval Ne(s0, 1) == True [statically known]
Even with this simple toy model, this spits out quite a lot. The log lines here have been truncated at the start and end to omit unnecessary information, but looking through the logs we can spot the lines relevant to what we described above; e.g. the allocation of symbols:
"""
create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert True == True [statically known]
create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
"""
The lines with create_symbol show where a new symbol was allocated, and the logs also identify the tensor variable names and dimensions they have been allocated for. In other lines we can also see the guards emitted:
"""
runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""
Next to the [guard added] messages, we also see the responsible user lines of code - luckily the model here is simple enough. In many real-world cases it's not so straightforward: higher-level torch operations can have complicated fake-kernel implementations or operator decompositions, which complicate where and what guards are emitted. In such cases the best way to dig deeper and investigate is to follow the logs' suggestion and re-run with the environment variable TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="...", to further attribute the guard of interest.
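Following that advice can be scripted; a minimal sketch, where the guard expression "Eq(s2*s3, s5)" is copied from the example log above and both variables must be set before the tracing run in the same process:

```python
import os

# Hypothetical re-run setup, following the log's suggestion: attribute one
# specific guard with an extended stack trace. The expression string must
# exactly match the guard reported in the log.
os.environ["TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED"] = "Eq(s2*s3, s5)"
# Also surface the symbolic-shapes logs (read when torch logging is set up):
os.environ["TORCH_LOGS"] = "+dynamic"

# ... then call torch.export.export(...) as before, in this same process.
```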
Dim.AUTO is just one of the available options for interacting with dynamic_shapes; as of writing, 2 other options are available: Dim.DYNAMIC and Dim.STATIC. Dim.STATIC simply marks a dimension static, while Dim.DYNAMIC is similar to Dim.AUTO in all respects except one: it raises an error when specializing to a constant; this is designed to maintain dynamism. See for example what happens when a static guard is emitted on a dynamically-marked dimension:
I0203 16:55:11.200000 634 torch/fx/experimental/symbolic_shapes.py:3192] [13/0] create_env
I0203 16:55:11.202000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.203000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.204000 634 torch/fx/experimental/symbolic_shapes.py:6614] [13/0] runtime_assert True == True [statically known]
I0203 16:55:11.206000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.208000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.208000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.211000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.213000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s2, 1) == False [statically known]
V0203 16:55:11.214000 634 torch/fx/experimental/symbolic_shapes.py:6614] [13/0] runtime_assert True == True [statically known]
V0203 16:55:11.214000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s4, 1) == False [statically known]
I0203 16:55:11.216000 634 torch/fx/experimental/symbolic_shapes.py:5963] [13/0] set_replacement s4 = s2 (solve) VR[2, int_oo]
I0203 16:55:11.217000 634 torch/fx/experimental/symbolic_shapes.py:6281] [13/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # [8, 4] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0203 16:55:11.218000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Ne(s2, 1) == True [statically known]
V0203 16:55:11.219000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Ne(s3, 1) == True [statically known]
V0203 16:55:11.226000 634 torch/fx/experimental/symbolic_shapes.py:5802] [13/0] _update_var_to_range s1 = VR[5, 5] (update)
I0203 16:55:11.226000 634 torch/fx/experimental/symbolic_shapes.py:5963] [13/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0203 16:55:11.227000 634 torch/fx/experimental/symbolic_shapes.py:6281] [13/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0203 16:55:11.228000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s0, 1) == False [statically known]
V0203 16:55:11.234000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s2*s3, 1) == False [statically known]
V0203 16:55:11.235000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s5, 1) == False [statically known]
V0203 16:55:11.236000 634 torch/fx/experimental/symbolic_shapes.py:5802] [13/0] _update_var_to_range s5 = VR[4, int_oo] (update)
I0203 16:55:11.237000 634 torch/fx/experimental/symbolic_shapes.py:5963] [13/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo]
I0203 16:55:11.238000 634 torch/fx/experimental/symbolic_shapes.py:6281] [13/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0203 16:55:11.239000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Ne(s2*s3, 1) == True [statically known]
I0203 16:55:11.246000 634 torch/fx/experimental/symbolic_shapes.py:4547] [13/0] produce_guards
V0203 16:55:11.246000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].size()[0] s0 None
V0203 16:55:11.246000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].size()[1] 5 RelaxedUnspecConstraint(warn_only=False)
V0203 16:55:11.247000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].stride()[0] 5 None
V0203 16:55:11.247000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].stride()[1] 1 None
V0203 16:55:11.247000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].storage_offset() 0 None
V0203 16:55:11.248000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['x'].size()[0] s2 None
V0203 16:55:11.248000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['x'].stride()[0] 1 None
V0203 16:55:11.248000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.249000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].size()[0] s3 None
V0203 16:55:11.249000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].size()[1] s2 None
V0203 16:55:11.249000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].stride()[0] s2 None
V0203 16:55:11.250000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.250000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].storage_offset() 0 None
V0203 16:55:11.250000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['z'].size()[0] s2*s3 None
V0203 16:55:11.251000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['z'].stride()[0] 1 None
V0203 16:55:11.251000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['z'].storage_offset() 0 None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Error while creating guard:
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Name: ''
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Source: shape_env
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Create Function: SHAPE_ENV
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Guard Types: None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Code List: None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Object Weakref: None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Guarded Class Weakref: None
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] Traceback (most recent call last):
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] return self.create_fn(builder, self)
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] raise ConstraintViolationError(
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
E0203 16:55:11.253000 634 torch/_guards.py:295] [13/0] - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] Created at:
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 642, in transform
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] tracer = InstructionTranslator(
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2711, in __init__
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] output=OutputGraph(
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 336, in __init__
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] self.init_ambient_guards()
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 485, in init_ambient_guards
E0203 16:55:11.255000 634 torch/_guards.py:297] [13/0] self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1614, in inner
raise constraint_violation_error
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
result_traced = opt_f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 852, in _compile_inner
check_fn = CheckFunctionManager(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2303, in __init__
guard.create(builder)
File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
return self.create_fn(builder, self)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 418, in <module>
export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
return _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 679, in _export_to_torch_ir
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).
Static guards also aren't always inherent to the model; they can come from user specifications as well. In fact, a common pitfall leading to shape specialization is when the user specifies conflicting markers for equivalent dimensions: one dynamic and the other static. The same error type is raised when this is the case for x.shape[0] and y.shape[1]:
I0203 16:55:11.273000 634 torch/fx/experimental/symbolic_shapes.py:3192] [14/0] create_env
I0203 16:55:11.276000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.276000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.277000 634 torch/fx/experimental/symbolic_shapes.py:6614] [14/0] runtime_assert True == True [statically known]
I0203 16:55:11.280000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s2 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.280000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.283000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s4 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.286000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Eq(s3, 1) == False [statically known]
V0203 16:55:11.290000 634 torch/fx/experimental/symbolic_shapes.py:5802] [14/0] _update_var_to_range s3 = VR[4, 4] (update)
I0203 16:55:11.291000 634 torch/fx/experimental/symbolic_shapes.py:5963] [14/0] set_replacement s3 = 4 (range_refined_to_singleton) VR[4, 4]
I0203 16:55:11.291000 634 torch/fx/experimental/symbolic_shapes.py:6281] [14/0] runtime_assert Eq(s3, 4) [guard added] x0 = x + y # [8, 4] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s3, 4)"
V0203 16:55:11.293000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Ne(s2, 1) == True [statically known]
V0203 16:55:11.299000 634 torch/fx/experimental/symbolic_shapes.py:5802] [14/0] _update_var_to_range s1 = VR[5, 5] (update)
I0203 16:55:11.300000 634 torch/fx/experimental/symbolic_shapes.py:5963] [14/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0203 16:55:11.300000 634 torch/fx/experimental/symbolic_shapes.py:6281] [14/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0203 16:55:11.302000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Eq(s0, 1) == False [statically known]
V0203 16:55:11.302000 634 torch/fx/experimental/symbolic_shapes.py:6614] [14/0] runtime_assert True == True [statically known]
V0203 16:55:11.309000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Eq(s4, 1) == False [statically known]
V0203 16:55:11.314000 634 torch/fx/experimental/symbolic_shapes.py:5802] [14/0] _update_var_to_range s4 = VR[8, int_oo] (update)
I0203 16:55:11.317000 634 torch/fx/experimental/symbolic_shapes.py:5963] [14/0] set_replacement s4 = 4*s2 (solve) VR[8, int_oo]
I0203 16:55:11.317000 634 torch/fx/experimental/symbolic_shapes.py:6281] [14/0] runtime_assert Eq(4*s2, s4) [guard added] x3 = x2 + z # [32] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(4*s2, s4)"
I0203 16:55:11.324000 634 torch/fx/experimental/symbolic_shapes.py:4547] [14/0] produce_guards
V0203 16:55:11.324000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].size()[0] s0 None
V0203 16:55:11.325000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].size()[1] 5 None
V0203 16:55:11.325000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].stride()[0] 5 None
V0203 16:55:11.325000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].stride()[1] 1 None
V0203 16:55:11.326000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].storage_offset() 0 None
V0203 16:55:11.326000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['x'].size()[0] 4 None
V0203 16:55:11.326000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['x'].stride()[0] 1 None
V0203 16:55:11.327000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.327000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].size()[0] s2 None
V0203 16:55:11.327000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].size()[1] 4 RelaxedUnspecConstraint(warn_only=False)
V0203 16:55:11.327000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].stride()[0] 4 None
V0203 16:55:11.328000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.328000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].storage_offset() 0 None
V0203 16:55:11.328000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['z'].size()[0] 4*s2 None
V0203 16:55:11.329000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['z'].stride()[0] 1 None
V0203 16:55:11.329000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['z'].storage_offset() 0 None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Error while creating guard:
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Name: ''
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Source: shape_env
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Create Function: SHAPE_ENV
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Guard Types: None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Code List: None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Object Weakref: None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Guarded Class Weakref: None
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] Traceback (most recent call last):
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] return self.create_fn(builder, self)
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] raise ConstraintViolationError(
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
E0203 16:55:11.330000 634 torch/_guards.py:295] [14/0] - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] Created at:
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 642, in transform
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] tracer = InstructionTranslator(
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2711, in __init__
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] output=OutputGraph(
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 336, in __init__
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] self.init_ambient_guards()
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 485, in init_ambient_guards
E0203 16:55:11.332000 634 torch/_guards.py:297] [14/0] self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1614, in inner
raise constraint_violation_error
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
result_traced = opt_f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 852, in _compile_inner
check_fn = CheckFunctionManager(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2303, in __init__
guard.create(builder)
File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
return self.create_fn(builder, self)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 431, in <module>
export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
return _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 679, in _export_to_torch_ir
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).
Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer is the symbolic shapes system described above, of symbols and guards. When x.shape[0] is marked static, we don't allocate a symbol, and compile treating this shape as a concrete integer 4. A symbol is allocated for y.shape[1], and so we finally emit the guard s3 == 4, leading to specialization.
One feature of export is that during tracing, statements like asserts, torch._check(), and if/else conditions will also emit guards. See what happens when we augment the existing model with such statements:
class DynamicModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.l = torch.nn.Linear(5, 3)
def forward(self, w, x, y, z):
assert w.shape[0] <= 512
torch._check(x.shape[0] >= 4)
if w.shape[0] == x.shape[0] + 2:
x0 = x + y
x1 = self.l(w)
x2 = x0.flatten()
x3 = x2 + z
return x1, x3
else:
return w
dynamic_shapes = {
"w": (Dim.AUTO, Dim.AUTO),
"x": (Dim.AUTO,),
"y": (Dim.AUTO, Dim.AUTO),
"z": (Dim.AUTO,),
}
try:
ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
tb.print_exc()
I0203 16:55:11.350000 634 torch/fx/experimental/symbolic_shapes.py:3192] [15/0] create_env
I0203 16:55:11.352000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.353000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.354000 634 torch/fx/experimental/symbolic_shapes.py:6614] [15/0] runtime_assert True == True [statically known]
I0203 16:55:11.356000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.358000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.358000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.361000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.367000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s0 = VR[2, 512] (update)
I0203 16:55:11.368000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert s0 <= 512 [guard added] assert w.shape[0] <= 512 # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:450 in forward (_dynamo/symbolic_convert.py:522 in inner), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s0 <= 512"
V0203 16:55:11.372000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s2 = VR[4, int_oo] (update)
I0203 16:55:11.373000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert s2 >= 4 [guard added] torch._check(x.shape[0] >= 4) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:451 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s2 >= 4"
V0203 16:55:11.379000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s0 = VR[6, 512] (update)
V0203 16:55:11.382000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s2 = VR[4, 510] (update)
I0203 16:55:11.382000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s0 = s2 + 2 (solve) VR[6, 512]
I0203 16:55:11.383000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] eval Eq(s0, s2 + 2) [guard added] if w.shape[0] == x.shape[0] + 2: # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:452 in forward (_dynamo/variables/tensor.py:1201 in evaluate_expr), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2 + 2)"
V0203 16:55:11.384000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s2, 1) == False [statically known]
V0203 16:55:11.385000 634 torch/fx/experimental/symbolic_shapes.py:6614] [15/0] runtime_assert True == True [statically known]
V0203 16:55:11.386000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s4, 1) == False [statically known]
V0203 16:55:11.388000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s4 = VR[4, 510] (update)
I0203 16:55:11.389000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s4 = s2 (solve) VR[4, 510]
I0203 16:55:11.390000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:453 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0203 16:55:11.391000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Ne(s2, 1) == True [statically known]
V0203 16:55:11.392000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Ne(s3, 1) == True [statically known]
V0203 16:55:11.399000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s1 = VR[5, 5] (update)
I0203 16:55:11.399000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0203 16:55:11.400000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:454 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0203 16:55:11.409000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s2*s3, 1) == False [statically known]
V0203 16:55:11.410000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s5, 1) == False [statically known]
V0203 16:55:11.419000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s5 = VR[8, int_oo] (update)
I0203 16:55:11.420000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s5 = s2*s3 (solve) VR[8, int_oo]
I0203 16:55:11.421000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:456 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0203 16:55:11.422000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Ne(s2*s3, 1) == True [statically known]
V0203 16:55:11.426000 634 torch/fx/experimental/symbolic_shapes.py:6614] [15/0] runtime_assert s2 >= 4 == True [statically known]
I0203 16:55:11.432000 634 torch/fx/experimental/symbolic_shapes.py:4547] [15/0] produce_guards
V0203 16:55:11.432000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].size()[0] s2 + 2 None
V0203 16:55:11.433000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].size()[1] 5 None
V0203 16:55:11.433000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].stride()[0] 5 None
V0203 16:55:11.433000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].stride()[1] 1 None
V0203 16:55:11.434000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].storage_offset() 0 None
V0203 16:55:11.434000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['x'].size()[0] s2 None
V0203 16:55:11.434000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['x'].stride()[0] 1 None
V0203 16:55:11.434000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.435000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].size()[0] s3 None
V0203 16:55:11.435000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].size()[1] s2 None
V0203 16:55:11.435000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].stride()[0] s2 None
V0203 16:55:11.436000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.436000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].storage_offset() 0 None
V0203 16:55:11.436000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['z'].size()[0] s2*s3 None
V0203 16:55:11.436000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['z'].stride()[0] 1 None
V0203 16:55:11.437000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['z'].storage_offset() 0 None
Each of these statements emits an additional guard, and the exported program shows the changes: s0 is eliminated in favor of s2 + 2, and s2 now contains lower and upper bounds, reflected in range_constraints.
For the if/else condition, you might ask why the True branch was taken, and why a w.shape[0] != x.shape[0] + 2 guard wasn't emitted from tracing. The answer is that export is guided by the sample inputs provided to tracing, and specializes on the branch taken. If different sample input shapes were provided that fail the if condition, export would trace and emit guards corresponding to the else branch. Additionally, you might ask why we traced only the if branch, and if it's possible to maintain control flow in your program and keep both branches alive. For that, refer to rewriting your model code following the Control Flow Ops section above.
0/1 specialization¶
Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier. The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that don't generalize to other shapes. For example, size-1 tensors can broadcast while other sizes fail; and size-0 tensors ... . This just means that you should specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens at runtime when we export this linear layer:
ep = export(
torch.nn.Linear(4, 3),
(torch.randn(1, 4),),
dynamic_shapes={
"input": (Dim.AUTO, Dim.STATIC),
},
)
try:
ep.module()(torch.randn(2, 4))
except Exception:
tb.print_exc()
I0203 16:55:11.502000 634 torch/fx/experimental/symbolic_shapes.py:3192] [3/1] create_env
I0203 16:55:11.516000 634 torch/fx/experimental/symbolic_shapes.py:4547] [3/1] produce_guards
V0203 16:55:11.516000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].size()[0] 1 None
V0203 16:55:11.517000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].size()[1] 4 None
V0203 16:55:11.517000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].stride()[0] 4 None
V0203 16:55:11.517000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].stride()[1] 1 None
V0203 16:55:11.517000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].storage_offset() 0 None
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 500, in <module>
ep.module()(torch.randn(2, 4))
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 387, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1772, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py", line 49, in _check_input_constraints_pre_hook
_check_input_constraints_for_graph(
File "/usr/local/lib/python3.10/dist-packages/torch/_export/utils.py", line 360, in _check_input_constraints_for_graph
raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 1, but got 2
Named Dims¶
So far we've only been talking about 3 ways to specify dynamic shapes: Dim.AUTO, Dim.DYNAMIC, and Dim.STATIC. The attraction of these is the low-friction user experience: all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic dimensions are automatically figured out underneath export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears once the user has stronger expectations or beliefs about the dynamic behavior of these models - perhaps there is a strong desire for dynamism, and specializations on particular dimensions are to be avoided at all costs; or perhaps we just want to catch changes in dynamic behavior resulting from changes to the original model code, or possibly to underlying decompositions or meta-kernels. Such changes won't be detected, and the export() call will most likely succeed, unless tests are in place that check the resulting ExportedProgram representation.
For such cases, our stance is to recommend the "traditional" way of specifying dynamic shapes, which longer-term users of export might be familiar with: named Dims.
This style of dynamic shapes allows the user to specify what symbols are allocated for input dimensions, the min/max bounds on those symbols, and places restrictions on the dynamic behavior of the resulting ExportedProgram; ConstraintViolation errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic specifications given. For example, a specification might assert the following:
- x.shape[0] has range [4, 256], and is related to y.shape[0] via y.shape[0] == 2 * x.shape[0].
- x.shape[1] is static.
- y.shape[1] has range [2, 512], and is unrelated to any other dimension.
In this design, we allow relations between dimensions to be specified with univariate linear expressions: A * dim + B can be specified for any dimension. This allows users to specify more complex constraints, like integer divisibility, for dynamic dimensions:
dx = Dim("dx", min=4, max=512)
dynamic_shapes = {
"x": (4 * dx, None) # x.shape[0] has range [16, 2048], and is divisible by 4.
}
Constraint violations, suggested fixes¶
One common issue with this specification style (before Dim.AUTO was introduced) was that the specification would often mismatch what was produced by model tracing. That would lead to ConstraintViolation errors and export suggesting fixes - for example, with this model & specification, where the model inherently requires equality between dimensions 0 of x and y, and requires dimension 1 to be static.
class Foo(torch.nn.Module):
def forward(self, x, y):
w = x + y
return w + torch.ones(4)
dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
try:
ep = export(
Foo(),
(torch.randn(6, 4), torch.randn(6, 4)),
dynamic_shapes={
"x": (dx, d1),
"y": (dy, d1),
},
)
except Exception:
tb.print_exc()
I0203 16:55:11.662000 634 torch/fx/experimental/symbolic_shapes.py:3192] [16/0] create_env
I0203 16:55:11.665000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s0 = 6 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.665000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s1 = 4 for L['x'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.666000 634 torch/fx/experimental/symbolic_shapes.py:6614] [16/0] runtime_assert True == True [statically known]
I0203 16:55:11.668000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s2 = 6 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0203 16:55:11.669000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0203 16:55:11.673000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s1, 1) == False [statically known]
V0203 16:55:11.673000 634 torch/fx/experimental/symbolic_shapes.py:6614] [16/0] runtime_assert True == True [statically known]
V0203 16:55:11.674000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s0, 1) == False [statically known]
V0203 16:55:11.675000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s3, 1) == False [statically known]
I0203 16:55:11.677000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s3 = s1 (solve) VR[2, int_oo]
I0203 16:55:11.678000 634 torch/fx/experimental/symbolic_shapes.py:6281] [16/0] runtime_assert Eq(s1, s3) [guard added] w = x + y # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, s3)"
V0203 16:55:11.679000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s2, 1) == False [statically known]
I0203 16:55:11.681000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s2 = s0 (solve) VR[2, int_oo]
I0203 16:55:11.681000 634 torch/fx/experimental/symbolic_shapes.py:6281] [16/0] runtime_assert Eq(s0, s2) [guard added] w = x + y # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2)"
V0203 16:55:11.683000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Ne(s1, 1) == True [statically known]
V0203 16:55:11.683000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Ne(s0, 1) == True [statically known]
V0203 16:55:11.690000 634 torch/fx/experimental/symbolic_shapes.py:5802] [16/0] _update_var_to_range s1 = VR[4, 4] (update)
I0203 16:55:11.691000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s1 = 4 (range_refined_to_singleton) VR[4, 4]
I0203 16:55:11.691000 634 torch/fx/experimental/symbolic_shapes.py:6281] [16/0] runtime_assert Eq(s1, 4) [guard added] return w + torch.ones(4) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:553 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 4)"
V0203 16:55:11.694000 634 torch/fx/experimental/symbolic_shapes.py:5802] [16/0] _update_var_to_range s3 = VR[4, 4] (update)
I0203 16:55:11.695000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s3 = 4 (find) VR[4, 4]
I0203 16:55:11.698000 634 torch/fx/experimental/symbolic_shapes.py:4547] [16/0] produce_guards
V0203 16:55:11.698000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0203 16:55:11.698000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0203 16:55:11.699000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].stride()[0] 4 None
V0203 16:55:11.699000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].stride()[1] 1 None
V0203 16:55:11.699000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.699000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0203 16:55:11.700000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0203 16:55:11.700000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].stride()[0] 4 None
V0203 16:55:11.700000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].stride()[1] 1 None
V0203 16:55:11.700000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].storage_offset() 0 None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Error while creating guard:
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Name: ''
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Source: shape_env
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Create Function: SHAPE_ENV
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Guard Types: None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Code List: None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Object Weakref: None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Guarded Class Weakref: None
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] Traceback (most recent call last):
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] return self.create_fn(builder, self)
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] raise ConstraintViolationError(
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
E0203 16:55:11.702000 634 torch/_guards.py:295] [16/0] - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] Created at:
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 642, in transform
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] tracer = InstructionTranslator(
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2711, in __init__
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] output=OutputGraph(
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 336, in __init__
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] self.init_ambient_guards()
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 485, in init_ambient_guards
E0203 16:55:11.705000 634 torch/_guards.py:297] [16/0] self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1614, in inner
raise constraint_violation_error
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
result_traced = opt_f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 852, in _compile_inner
check_fn = CheckFunctionManager(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2303, in __init__
guard.create(builder)
File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
return self.create_fn(builder, self)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
- Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
- The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
Suggested fixes:
d1 = 4
dy = dx
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 557, in <module>
ep = export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
return _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 679, in _export_to_torch_ir
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
- Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
- The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
Suggested fixes:
d1 = 4
dy = dx
The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
Lastly, here are some quick notes on the specification options:

None is a good option for static behavior:
- dynamic_shapes=None (default) exports with the entire model being static.
- Specifying None at the input level exports with all tensor dimensions of that input static, and is also required for non-tensor inputs.
- Specifying None at the dimension level specializes that dimension, though this is deprecated in favor of Dim.STATIC.

Specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification.
These options are combined in the inputs & dynamic shapes specification below:
inputs = (
torch.randn(4, 4),
torch.randn(3, 3),
16,
False,
)
dynamic_shapes = {
"tensor_0": (Dim.AUTO, None),
"tensor_1": None,
"int_val": None,
"bool_val": None,
}
Data-dependent errors¶
While trying to export models, you may have encountered errors like "Could not guard on data-dependent expression" or "Could not extract specialized integer from data-dependent expression". These errors exist because torch.export() compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. While these have equivalent symbolic properties (e.g. sizes, strides, dtypes), they diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may be unable to out-of-the-box compile parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that the value is not available.

Data-dependent values appear in many places, and common sources are calls like item(), tolist(), or torch.unbind() that extract scalar values from tensors. How are these values represented in the exported program? In the Constraints/Dynamic Shapes section, we talked about allocating symbols to represent dynamic input dimensions. The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols, in contrast to the "backed" symbols allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence/absence of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler on how to proceed.

In the input shape symbol case (backed symbols), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by sample input properties. For data-dependent values, the symbols are taken from FakeTensor "data" during tracing, and so the compiler doesn't know the actual values (hints) that these symbols would take on.
Let's see how these show up in exported programs:
class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
b = y.tolist()
return b + [a]
inps = (
torch.tensor(1),
torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)
I0203 16:55:11.720000 634 torch/fx/experimental/symbolic_shapes.py:3192] [17/0] create_env
I0203 16:55:11.724000 634 torch/fx/experimental/symbolic_shapes.py:4103] [17/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.724000 634 torch/fx/experimental/symbolic_shapes.py:970] [17/0] compute_unbacked_bindings [u0]
I0203 16:55:11.727000 634 torch/fx/experimental/symbolic_shapes.py:4103] [17/0] create_unbacked_symint u1 [-int_oo, int_oo] b = y.tolist() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.727000 634 torch/fx/experimental/symbolic_shapes.py:970] [17/0] compute_unbacked_bindings [u1]
I0203 16:55:11.729000 634 torch/fx/experimental/symbolic_shapes.py:4103] [17/0] create_unbacked_symint u2 [-int_oo, int_oo] b = y.tolist() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.729000 634 torch/fx/experimental/symbolic_shapes.py:970] [17/0] compute_unbacked_bindings [u2]
I0203 16:55:11.733000 634 torch/fx/experimental/symbolic_shapes.py:4547] [17/0] produce_guards
V0203 16:55:11.734000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.734000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['y'].size()[0] 2 None
V0203 16:55:11.734000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['y'].stride()[0] 1 None
V0203 16:55:11.734000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['y'].storage_offset() 0 None
I0203 16:55:11.741000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u3 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.742000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u4 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.750000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u5 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.750000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u5]
I0203 16:55:11.750000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u5 = u0 (rename_unbacked_to) VR[-int_oo, int_oo]
I0203 16:55:11.752000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u6 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.753000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u6]
I0203 16:55:11.753000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u6 = u1 (rename_unbacked_to) VR[-int_oo, int_oo]
I0203 16:55:11.755000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u7 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.755000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u7]
I0203 16:55:11.755000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u7 = u2 (rename_unbacked_to) VR[-int_oo, int_oo]
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "i64[]", y: "i64[2]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward, code: a = x.item()
item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward, code: b = y.tolist()
select: "i64[]" = torch.ops.aten.select.int(y, 0, 0)
item_1: "Sym(u1)" = torch.ops.aten.item.default(select); select = None
select_1: "i64[]" = torch.ops.aten.select.int(y, 0, 1); y = None
item_2: "Sym(u2)" = torch.ops.aten.item.default(select_1); select_1 = None
return (item_1, item_2, item)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item'), target=None)])
Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[-int_oo, int_oo], u3: VR[-int_oo, int_oo], u4: VR[-int_oo, int_oo], u5: VR[-int_oo, int_oo], u6: VR[-int_oo, int_oo], u7: VR[-int_oo, int_oo]}
The result is that 3 unbacked symbols (notice they are prefixed with "u", instead of the usual "s" for input shape/backed symbols) are allocated and returned: 1 for the item() call, and 1 for each element of y in the tolist() call. Note from the range constraints field that these take on ranges of [-int_oo, int_oo], not the default [0, int_oo] range allocated to input shape symbols, since we have no information about what these values are; they don't represent sizes, so they don't necessarily take positive values.
Guards, torch._check()¶
The case above is easy to export, though, because the concrete values of these symbols are not used in any compiler decision-making; all that matters is that the return values are unbacked symbols. The data-dependent errors highlighted in this section are cases like the following, where a data-dependent guard is encountered:
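The toy model in question can be reconstructed from the fixed version shown later in this section (the same model, minus the torch._check() calls):

```python
import torch

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # branching on `a` requires a concrete (hinted) value at trace time
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

inps = (
    torch.tensor(32),
    torch.randn(4),
)
# export(Foo(), inps) fails with a data-dependent guard error on u0 // 2 >= 5
```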
Here we actually need the "hint" (the concrete value of a) for the compiler to decide whether to trace return y + 2 or return y * 5 as the output. Because we trace with FakeTensors, we don't know what a // 2 >= 5 actually evaluates to, and export errors out with "Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)".
So how do we export this toy model? Unlike torch.compile(), export requires full-graph compilation, so we can't just graph break here. Here are some basic options:

- Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or by using torch.compiler.is_compiling() to guard what is traced at compile time.
- torch.cond(): we could rewrite the control-flow code to use torch.cond(), so that we don't specialize on a branch.

While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code in order to specialize, and torch.cond() is not a comprehensive system for handling data-dependent errors. As we will see, there are data-dependent errors that do not involve control flow.
The generally recommended approach is to start with torch._check() calls. While these give the impression of being purely assert statements, they are in fact a system for informing the compiler of properties of symbols. While a torch._check() call does act as an assertion at runtime, when traced at compile time the checked expression is sent to the symbolic shapes subsystem for reasoning, and any symbol properties that follow from the expression being true are stored as symbol properties (provided the subsystem is smart enough to infer them). So even though unbacked symbols have no hints, if we can communicate properties that are generally true for these symbols via torch._check() calls, we can potentially bypass data-dependent guards without rewriting the offending model code.

For example, in the model above, inserting torch._check(a >= 10) would tell the compiler that y + 2 can always be returned, while torch._check(a == 4) would tell it to return y * 5. See what happens when we re-export this model:
class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
torch._check(a >= 10)
torch._check(a <= 60)
if a // 2 >= 5:
return y + 2
else:
return y * 5
inps = (
torch.tensor(32),
torch.randn(4),
)
ep = export(Foo(), inps)
print(ep)
I0203 16:55:11.765000 634 torch/fx/experimental/symbolic_shapes.py:3192] [18/0] create_env
I0203 16:55:11.769000 634 torch/fx/experimental/symbolic_shapes.py:4103] [18/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.769000 634 torch/fx/experimental/symbolic_shapes.py:970] [18/0] compute_unbacked_bindings [u0]
V0203 16:55:11.772000 634 torch/fx/experimental/symbolic_shapes.py:5802] [18/0] _update_var_to_range u0 = VR[10, int_oo] (update)
I0203 16:55:11.773000 634 torch/fx/experimental/symbolic_shapes.py:6281] [18/0] runtime_assert u0 >= 10 [guard added] torch._check(a >= 10) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:673 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 10"
V0203 16:55:11.777000 634 torch/fx/experimental/symbolic_shapes.py:5802] [18/0] _update_var_to_range u0 = VR[10, 60] (update)
I0203 16:55:11.779000 634 torch/fx/experimental/symbolic_shapes.py:6281] [18/0] runtime_assert u0 <= 60 [guard added] torch._check(a <= 60) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:674 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 <= 60"
V0203 16:55:11.783000 634 torch/fx/experimental/symbolic_shapes.py:6412] [18/0] eval ((u0//2)) >= 5 == True [statically known]
V0203 16:55:11.786000 634 torch/fx/experimental/symbolic_shapes.py:6614] [18/0] runtime_assert u0 >= 10 == True [statically known]
V0203 16:55:11.787000 634 torch/fx/experimental/symbolic_shapes.py:6614] [18/0] runtime_assert u0 <= 60 == True [statically known]
I0203 16:55:11.790000 634 torch/fx/experimental/symbolic_shapes.py:4547] [18/0] produce_guards
V0203 16:55:11.790000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.791000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['y'].size()[0] 4 None
V0203 16:55:11.791000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['y'].stride()[0] 1 None
V0203 16:55:11.791000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['y'].storage_offset() 0 None
I0203 16:55:11.806000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.807000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u1]
V0203 16:55:11.807000 634 torch/fx/experimental/symbolic_shapes.py:5802] _update_var_to_range u1 = VR[10, 60] (update)
I0203 16:55:11.807000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u1 = u0 (rename_unbacked_to) VR[10, 60]
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "i64[]", y: "f32[4]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward, code: a = x.item()
item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
ge_1: "Sym(u0 >= 10)" = item >= 10
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 10 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
le_1: "Sym(u0 <= 60)" = item <= 60; item = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 60 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:676 in forward, code: return y + 2
add: "f32[4]" = torch.ops.aten.add.Tensor(y, 2); y = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {u0: VR[10, 60], u1: VR[10, 60]}
The export succeeds, and note from the range constraints field that u0 takes on a range of [10, 60].
So what information do torch._check() calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level the following are generally true:

- Equality with non-data-dependent expressions: torch._check() calls that communicate equalities, such as u0 == s0 + 4 or u0 == 5.
- Range refinement: calls that provide lower or upper bounds for a symbol, as above.
- Some basic reasoning around more complicated expressions: inserting torch._check(a < 4) will typically tell the compiler that a >= 4 is false, and checks on complex expressions such as torch._check(a ** 2 - 3 * a <= 10) will typically get you past identical guards.
As mentioned previously, torch._check() calls are also applicable outside of data-dependent control flow. For example, here is a model where torch._check() insertion prevails, while manual specialization and torch.cond() do not:
class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
return y[a]
inps = (
torch.tensor(32),
torch.randn(60),
)
import traceback as tb  # for printing the expected export failure

try:
    export(Foo(), inps)
except Exception:
    tb.print_exc()
I0203 16:55:11.821000 634 torch/fx/experimental/symbolic_shapes.py:3192] [19/0] create_env
I0203 16:55:11.825000 634 torch/fx/experimental/symbolic_shapes.py:4103] [19/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:701 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.826000 634 torch/fx/experimental/symbolic_shapes.py:970] [19/0] compute_unbacked_bindings [u0]
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] Data dependent variable 'u0' allocated at:
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/bin/sphinx-build", line 8, in <module>
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] sys.exit(main())
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return make_main(argv)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return make_mode.run_make_mode(argv[1:])
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return make.run_generic_build(args[0])
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return build_main(args + opts)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] self._init_builder()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] self.events.emit('builder-inited')
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] results.append(listener.handler(self.app, *args))
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] ) = generate_dir_rst(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] intro, title, cost = generate_file_rst(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] output_blocks, time_elapsed = execute_script(script_blocks,
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] output_blocks.append(execute_code_block(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] is_last_expr, mem_max = _exec_and_get_memory(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] mem_max, _ = gallery_conf['call_memory'](
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return 0., func()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] exec(self.code, self.fake_main.__dict__)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 709, in <module>
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] export(Foo(), inps)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return _export(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] ep = fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return _export_for_training(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] ep = fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] export_artifact = export_func( # type: ignore[operator]
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] gm_torch_level = _export_to_torch_ir(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] gm_torch_level, _ = torch._dynamo.export(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] result_traced = opt_f(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return self._call_impl(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return forward_call(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return self._call_impl(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return forward_call(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return self._torchdynamo_orig_callable(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return _compile(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] guarded_code = compile_inner(code, one_graph, hooks, transform)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return _compile_inner(code, one_graph, hooks, transform)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return function(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] out_code = transform_code_object(code, transform)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] transformations(instructions, code_options)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] tracer.run()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] super().run()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] while self.step():
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] self.dispatch_table[inst.opcode](self, inst)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return inner_fn(self, inst)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] self.call_function(fn, args, {})
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py", line 1022, in call_function
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return self.obj.call_method(tx, self.name, args, kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/tensor.py", line 591, in call_method
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return wrap_fx_proxy(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return _wrap_fx_proxy(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] ret_val = wrap_fake_exception(
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return fn()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] lambda: run_node(tx.output, node, args, kwargs, nnmodule)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2588, in run_node
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return getattr(args[0], node.target)(*args[1:], **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return fn(*args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return self.dispatch(func, types, args, kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return self._cached_dispatch_impl(func, types, args, kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] output = self._dispatch_impl(func, types, args, kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] op_impl_out = op_impl(self, func, *args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 403, in local_scalar_dense
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] r = fake_mode.shape_env.create_unbacked_symint()
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] return retlog(fn(*args, **kwargs))
V0203 16:55:11.829000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]
W0203 16:55:11.836000 634 torch/fx/experimental/symbolic_shapes.py:6307] [19/0] failed during evaluate_expr(-u0 > 60, hint=None, size_oblivious=True, forcing_spec=False
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] failed while running evaluate_expr(*(-u0 > 60, None), **{'fx_node': False, 'size_oblivious': True})
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] Traceback (most recent call last):
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] return retlog(fn(*args, **kwargs))
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] return self._evaluate_expr(
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] raise self._make_data_dependent_error(
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60). (Size-like symbols: none)
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] Caused by: return y[a] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] For more information, run with TORCH_LOGS="dynamic"
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] User Stack (most recent call last):
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] (snipped, see stack below for prefix)
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] return y[a]
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0]
E0203 16:55:11.837000 634 torch/fx/experimental/recording.py:299] [19/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] failed while attempting to run meta for aten.select.int
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] Traceback (most recent call last):
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] r = func(*args, **kwargs)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] return self._op(*args, **kwargs)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 4874, in meta_select
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 407, in guard_size_oblivious
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] return expr.node.guard_size_oblivious("", 0)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] r = self.shape_env.evaluate_expr(
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] return retlog(fn(*args, **kwargs))
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] return self._evaluate_expr(
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] raise self._make_data_dependent_error(
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60). (Size-like symbols: none)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] Caused by: return y[a] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] For more information, run with TORCH_LOGS="dynamic"
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] User Stack (most recent call last):
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] (snipped, see stack below for prefix)
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] return y[a]
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0]
E0203 16:55:11.838000 634 torch/_subclasses/fake_tensor.py:2388] [19/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2586, in run_node
return node.target(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 4874, in meta_select
guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 407, in guard_size_oblivious
return expr.node.guard_size_oblivious("", 0)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
r = self.shape_env.evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
return retlog(fn(*args, **kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
return self._evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60). (Size-like symbols: none)
Caused by: return y[a] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
return y[a]
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
ret_val = wrap_fake_exception(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
return fn()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2604, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2586, in run_node
return node.target(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 4874, in meta_select
guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 407, in guard_size_oblivious
return expr.node.guard_size_oblivious("", 0)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
r = self.shape_env.evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
return retlog(fn(*args, **kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
return self._evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
raise self._make_data_dependent_error(
RuntimeError: Failed running call_function <built-in method select of type object at 0x7f3973a1fec0>(*(FakeTensor(..., size=(60,)), 0, u0), **{}):
Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60). (Size-like symbols: none)
Caused by: return y[a] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
return y[a]
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 709, in <module>
export(Foo(), inps)
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
return _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
result_traced = opt_f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 314, in impl
self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1004, in call_function
return handler(tx, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 980, in _handle_insert_op_in_graph
return wrap_fx_proxy(tx, proxy)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
return _wrap_fx_proxy(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2526, in get_fake_value
raise UserError( # noqa: B904
torch._dynamo.exc.UserError: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60). (Size-like symbols: none)
Caused by: return y[a] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
(snipped, see stack below for prefix)
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
return y[a]
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.ac.cn/docs/main/generated/exportdb/index.html#constrain-as-size-example
from user code:
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
return y[a]
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
This is a scenario where `torch._check()` calls must be inserted simply to keep the operation from failing. The export call fails with "Could not guard on data-dependent expression -u0 > 60", meaning the compiler does not know whether this is a valid indexing operation; that is, whether the value of `x` is within the bounds of `y`. Here, manual specialization would be far too tedious, and `torch.cond()` has no place. Instead, it is enough to inform the compiler of the range of `u0`:
class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
torch._check(a >= 0)
torch._check(a < y.shape[0])
return y[a]
inps = (
torch.tensor(32),
torch.randn(60),
)
ep = export(Foo(), inps)
print(ep)
I0203 16:55:11.859000 634 torch/fx/experimental/symbolic_shapes.py:3192] [20/0] create_env
I0203 16:55:11.863000 634 torch/fx/experimental/symbolic_shapes.py:4103] [20/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.864000 634 torch/fx/experimental/symbolic_shapes.py:970] [20/0] compute_unbacked_bindings [u0]
V0203 16:55:11.866000 634 torch/fx/experimental/symbolic_shapes.py:5802] [20/0] _update_var_to_range u0 = VR[0, int_oo] (update)
I0203 16:55:11.866000 634 torch/fx/experimental/symbolic_shapes.py:6281] [20/0] runtime_assert u0 >= 0 [guard added] torch._check(a >= 0) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:722 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0203 16:55:11.870000 634 torch/fx/experimental/symbolic_shapes.py:5802] [20/0] _update_var_to_range u0 = VR[0, 59] (update)
I0203 16:55:11.871000 634 torch/fx/experimental/symbolic_shapes.py:6281] [20/0] runtime_assert u0 < 60 [guard added] torch._check(a < y.shape[0]) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:723 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 < 60"
V0203 16:55:11.873000 634 torch/fx/experimental/symbolic_shapes.py:6412] [20/0] eval -u0 > 60 == False [statically known]
V0203 16:55:11.873000 634 torch/fx/experimental/symbolic_shapes.py:6412] [20/0] eval u0 >= 60 == False [statically known]
V0203 16:55:11.874000 634 torch/fx/experimental/symbolic_shapes.py:6412] [20/0] eval u0 >= 0 == True [statically known]
V0203 16:55:11.877000 634 torch/fx/experimental/symbolic_shapes.py:6614] [20/0] runtime_assert u0 >= 0 == True [statically known]
V0203 16:55:11.878000 634 torch/fx/experimental/symbolic_shapes.py:6614] [20/0] runtime_assert u0 <= 59 == True [statically known]
V0203 16:55:11.879000 634 torch/fx/experimental/symbolic_shapes.py:6614] [20/0] runtime_assert u0 < 60 == True [statically known]
I0203 16:55:11.882000 634 torch/fx/experimental/symbolic_shapes.py:4547] [20/0] produce_guards
V0203 16:55:11.883000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:11.883000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['y'].size()[0] 60 None
V0203 16:55:11.883000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['y'].stride()[0] 1 None
V0203 16:55:11.883000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['y'].storage_offset() 0 None
I0203 16:55:11.901000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.902000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u1]
V0203 16:55:11.902000 634 torch/fx/experimental/symbolic_shapes.py:5802] _update_var_to_range u1 = VR[0, 59] (update)
I0203 16:55:11.903000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u1 = u0 (rename_unbacked_to) VR[0, 59]
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "i64[]", y: "f32[60]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward, code: a = x.item()
item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
ge_1: "Sym(u0 >= 0)" = item >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
le_1: "Sym(u0 <= 59)" = item <= 59
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 59 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
#
lt_1: "Sym(u0 < 60)" = item < 60
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u0 < 60 on node 'lt_1'"); lt_1 = _assert_scalar_default_2 = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:724 in forward, code: return y[a]
select: "f32[]" = torch.ops.aten.select.int(y, 0, item); y = item = None
return (select,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='select'), target=None)])
Range constraints: {u0: VR[0, 59], u1: VR[0, 59]}
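Conceptually, the two `torch._check` calls refine the compiler's value range for `u0` from `[-int_oo, int_oo]` down to `VR[0, 59]`, which is enough to statically disprove guards like `-u0 > 60` without a runtime decision. A minimal pure-Python sketch of this interval refinement (the class and method names here are illustrative, not torch internals):

```python
class ValueRange:
    """Tracks the known [lo, hi] range of one unbacked integer symbol."""

    def __init__(self):
        self.lo, self.hi = float("-inf"), float("inf")

    def check_ge(self, bound):  # models torch._check(a >= bound)
        self.lo = max(self.lo, bound)

    def check_lt(self, bound):  # models torch._check(a < bound), ints only
        self.hi = min(self.hi, bound - 1)


u0 = ValueRange()   # create_unbacked_symint u0 [-inf, inf]
u0.check_ge(0)      # torch._check(a >= 0)         -> VR[0, inf]
u0.check_lt(60)     # torch._check(a < y.shape[0]) -> VR[0, 59]
print(u0.lo, u0.hi)  # 0 59

# With u0 confined to [0, 59], the largest value -u0 can take is 0,
# so the guard "-u0 > 60" is statically False and indexing is safe.
assert -u0.lo <= 60
```

This mirrors the `_update_var_to_range` and "statically known" lines in the log above: once the range is narrow enough, each guard is resolved at trace time instead of raising a data-dependent error.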
Specialized values¶
Another category of data-dependent errors occurs when the program attempts to extract a concrete, data-dependent integer/float value while tracing. These look like "Could not extract specialized integer from data-dependent expression" and are analogous to the previous class of errors: where data-dependent guard errors arise from evaluating a concrete boolean value, these arise from attempting to evaluate a concrete integer/float value.
This error typically occurs when there is an explicit or implicit `int()`
cast on a data-dependent expression. For example, this list comprehension contains a `range()` call that implicitly performs an `int()`
cast on the size of the list:
class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
b = torch.cat([y for y in range(a)], dim=0)
return b + int(a)
inps = (
torch.tensor(32),
torch.randn(60),
)
try:
export(Foo(), inps, strict=False)
except Exception:
tb.print_exc()
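The coercion the comprehension triggers can be seen in plain Python: `range()` accepts only values convertible to a concrete `int` via `__index__`, and that conversion is exactly the specialization that fails under tracing. The `TracedInt` class below is purely illustrative, not the real `torch.SymInt`:

```python
# Plain-Python illustration of the implicit int() coercion (no torch).
class TracedInt:
    """Illustrative stand-in for a data-dependent traced integer."""

    def __init__(self, hint):
        self.hint = hint

    def __index__(self):
        # At this point torch.export raises "Could not extract specialized
        # integer from data-dependent expression" rather than specializing.
        raise TypeError("cannot extract a concrete int from a traced value")


try:
    range(TracedInt(32))  # the implicit int() coercion happens right here
except TypeError as e:
    print(e)
```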
I0203 16:55:11.920000 634 torch/fx/experimental/symbolic_shapes.py:3192] create_env
I0203 16:55:11.926000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.927000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u0]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] Data dependent variable 'u0' allocated at:
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/bin/sphinx-build", line 8, in <module>
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] sys.exit(main())
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return make_main(argv)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return make_mode.run_make_mode(argv[1:])
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return make.run_generic_build(args[0])
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return build_main(args + opts)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] self._init_builder()
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] self.events.emit('builder-inited')
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] results.append(listener.handler(self.app, *args))
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] ) = generate_dir_rst(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] intro, title, cost = generate_file_rst(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] output_blocks, time_elapsed = execute_script(script_blocks,
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] output_blocks.append(execute_code_block(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] is_last_expr, mem_max = _exec_and_get_memory(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] mem_max, _ = gallery_conf['call_memory'](
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return 0., func()
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] exec(self.code, self.fake_main.__dict__)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] export(Foo(), inps, strict=False)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return _export(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] ep = fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return _export_for_training(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] ep = fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] export_artifact = export_func( # type: ignore[operator]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1772, in _non_strict_export
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] aten_export_artifact = _to_aten_func( # type: ignore[operator]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1564, in _export_to_aten_ir_make_fx
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] gm, graph_signature = transform(_make_fx_helper)(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1702, in _aot_export_non_strict
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1485, in _make_fx_helper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] gm = make_fx(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2196, in wrapped
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return make_fx_tracer.trace(f, *args)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2134, in trace
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return self._trace_inner(f, *args)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2105, in _trace_inner
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] t = dispatch_trace(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return disable_fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1138, in dispatch_trace
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1694, in trace
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] res = super().trace(root, concrete_args)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in trace
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] (self.create_arg(fn(*args)),),
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1193, in wrapped
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] out = f(*tensors) # type:ignore[call-arg]
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "<string>", line 1, in <lambda>
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1469, in wrapped_fn
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return tuple(flat_fn(*args))
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] tree_out = fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 879, in functional_call
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] out = mod(*args[params_len:], **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return self.call_module(mod, forward, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return Tracer.call_module(self, m, forward, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] ret_val = forward(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return _orig_module_call(mod, *args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return self._call_impl(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return forward_call(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1689, in forward
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] tree_out = mod(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return self.call_module(mod, forward, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return Tracer.call_module(self, m, forward, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] ret_val = forward(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return _orig_module_call(mod, *args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return self._call_impl(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return forward_call(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 747, in forward
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] a = x.item()
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1241, in __torch_function__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return func(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1288, in __torch_function__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return func(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 557, in __torch_function__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return func(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 840, in handler
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return torch._library.utils.handle_dispatch_mode(
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_library/utils.py", line 295, in handle_dispatch_mode
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1343, in __torch_dispatch__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return proxy_call(self, func, self.pre_dispatch, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 912, in proxy_call
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] out = func(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return self._op(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return fn(*args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return self.dispatch(func, types, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return self._cached_dispatch_impl(func, types, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] output = self._dispatch_impl(func, types, args, kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] op_impl_out = op_impl(self, func, *args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 403, in local_scalar_dense
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] r = fake_mode.shape_env.create_unbacked_symint()
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727] return retlog(fn(*args, **kwargs))
V0203 16:55:11.928000 634 torch/fx/experimental/symbolic_shapes.py:5727]
W0203 16:55:11.936000 634 torch/fx/experimental/symbolic_shapes.py:6307] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] Traceback (most recent call last):
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] return retlog(fn(*args, **kwargs))
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] return self._evaluate_expr(
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] raise self._make_data_dependent_error(
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: none)
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] Caused by: (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] For more information, run with TORCH_LOGS="dynamic"
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299]
E0203 16:55:11.936000 634 torch/fx/experimental/recording.py:299] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
export(Foo(), inps, strict=False)
File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
return _export(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
return _export_for_training(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1772, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1564, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1702, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1485, in _make_fx_helper
gm = make_fx(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2196, in wrapped
return make_fx_tracer.trace(f, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2134, in trace
return self._trace_inner(f, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2105, in _trace_inner
t = dispatch_trace(
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
return disable_fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1138, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1694, in trace
res = super().trace(root, concrete_args)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in trace
(self.create_arg(fn(*args)),),
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1193, in wrapped
out = f(*tensors) # type:ignore[call-arg]
File "<string>", line 1, in <lambda>
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1469, in wrapped_fn
return tuple(flat_fn(*args))
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 879, in functional_call
out = mod(*args[params_len:], **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
ret_val = forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1689, in forward
tree_out = mod(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
ret_val = forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
return _orig_module_call(mod, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 748, in forward
b = torch.cat([y for y in range(a)], dim=0)
File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 427, in __index__
return self.node.int_()
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 445, in int_
return self.guard_int("", 0) # NB: uses Python backtrace
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 492, in guard_int
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
return retlog(fn(*args, **kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
return self._evaluate_expr(
File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: none)
Caused by: (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For these errors, here are some basic options:

- Avoid unnecessary `int()` cast calls, in this case the `int(a)` in the return statement.
- Use `torch._check()` calls; unfortunately, in this case you may only be able to specialize (with `torch._check(a == 60)`).
- Rewrite the problematic code at a higher level. For example, the list comprehension is semantically a `repeat()` operation, which does not involve an `int()` cast. The following rewrite avoids the data-dependent error:
class Foo(torch.nn.Module):
def forward(self, x, y):
a = x.item()
b = y.unsqueeze(0).repeat(a, 1)
return b + a
inps = (
torch.tensor(32),
torch.randn(60),
)
ep = export(Foo(), inps, strict=False)
print(ep)
I0203 16:55:11.946000 634 torch/fx/experimental/symbolic_shapes.py:3192] create_env
I0203 16:55:11.951000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0203 16:55:11.952000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u0]
V0203 16:55:11.956000 634 torch/fx/experimental/symbolic_shapes.py:5802] _update_var_to_range u0 = VR[0, int_oo] (update)
I0203 16:55:11.957000 634 torch/fx/experimental/symbolic_shapes.py:6281] runtime_assert u0 >= 0 [guard added] (_refs/__init__.py:4800 in new_empty), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0203 16:55:11.959000 634 torch/fx/experimental/symbolic_shapes.py:6412] eval Eq(u0, 0) == False [statically known]
V0203 16:55:11.962000 634 torch/fx/experimental/symbolic_shapes.py:6412] eval Eq(u0, 1) == False [statically known]
V0203 16:55:11.962000 634 torch/fx/experimental/symbolic_shapes.py:6614] runtime_assert True == True [statically known]
I0203 16:55:11.966000 634 torch/fx/experimental/symbolic_shapes.py:4547] produce_guards
V0203 16:55:11.966000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][0].storage_offset() 0 None
V0203 16:55:11.967000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][1].size()[0] 60 None
V0203 16:55:11.967000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][1].stride()[0] 1 None
V0203 16:55:11.967000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][1].storage_offset() 0 None
V0203 16:55:11.969000 634 torch/fx/experimental/symbolic_shapes.py:6614] runtime_assert u0 >= 0 == True [statically known]
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "i64[]", y: "f32[60]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
#
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
ge: "Sym(u0 >= 0)" = item >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:770 in forward, code: b = y.unsqueeze(0).repeat(a, 1)
unsqueeze: "f32[1, 60]" = torch.ops.aten.unsqueeze.default(y, 0); y = None
repeat: "f32[u0, 60]" = torch.ops.aten.repeat.default(unsqueeze, [item, 1]); unsqueeze = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:771 in forward, code: return b + a
add: "f32[u0, 60]" = torch.ops.aten.add.Tensor(repeat, item); repeat = item = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {u0: VR[0, int_oo]}
Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: `torch._check_is_size()`, `guard_size_oblivious()`, or real-tensor tracing, as starters. For more in-depth guides, please refer to the Export Programming Model, or Dealing with GuardOnDataDependentSymNode errors.
Custom Ops¶
`torch.export` can export PyTorch programs with custom operators. See this page for how to write a custom operator in either C++ or Python.
The following is an example of registering a custom operator in Python to be used by `torch.export`. The important thing to note is that the custom op must have a FakeTensor kernel.
@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(x: torch.Tensor) -> torch.Tensor:
print("custom_op called!")
return torch.relu(x)
@custom_op.register_fake
def custom_op_meta(x):
# Returns an empty tensor with the same shape as the expected output
return torch.empty_like(x)
Here is an example of exporting a program with the custom op.
class CustomOpExample(torch.nn.Module):
def forward(self, x):
x = torch.sin(x)
x = torch.ops.my_custom_library.custom_op(x)
x = torch.cos(x)
return x
exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
print(exported_custom_op_example)
print(exported_custom_op_example.module()(torch.randn(3, 3)))
I0203 16:55:11.985000 634 torch/fx/experimental/symbolic_shapes.py:3192] [21/0] create_env
I0203 16:55:11.995000 634 torch/fx/experimental/symbolic_shapes.py:4547] [21/0] produce_guards
V0203 16:55:11.996000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].size()[0] 3 None
V0203 16:55:11.996000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].size()[1] 3 None
V0203 16:55:11.996000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].stride()[0] 3 None
V0203 16:55:11.997000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].stride()[1] 1 None
V0203 16:55:11.997000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].storage_offset() 0 None
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 3]"):
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:812 in forward, code: x = torch.sin(x)
sin: "f32[3, 3]" = torch.ops.aten.sin.default(x); x = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:813 in forward, code: x = torch.ops.my_custom_library.custom_op(x)
custom_op: "f32[3, 3]" = torch.ops.my_custom_library.custom_op.default(sin); sin = None
# File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:814 in forward, code: x = torch.cos(x)
cos: "f32[3, 3]" = torch.ops.aten.cos.default(custom_op); custom_op = None
return (cos,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}
custom_op called!
tensor([[0.5499, 0.6889, 0.7180],
[0.5413, 1.0000, 1.0000],
[0.8332, 0.5524, 1.0000]])
Note that in the `ExportedProgram`, the custom operator is included in the graph.
IR/Decompositions¶
The graph produced by `torch.export` returns a graph containing only ATen operators, which are the basic unit of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.
By default, export produces the most generic IR which contains all ATen operators, including both functional and non-functional operators. A functional operator is one that does not contain any mutations or aliasing of the inputs. You can find a list of all ATen operators here, and you can inspect if an operator is functional by checking `op._schema.is_mutable`, for example:
print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)
False
True
This generic IR can be used to train in eager PyTorch Autograd. This IR can be reached more explicitly through the API torch.export.export_for_training, which was introduced in PyTorch 2.5, but calling torch.export.export should produce the same graph as of PyTorch 2.6.
class DecompExample(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return (x,)
ep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph)
I0203 16:55:12.023000 634 torch/fx/experimental/symbolic_shapes.py:3192] [22/0] create_env
I0203 16:55:12.054000 634 torch/fx/experimental/symbolic_shapes.py:4547] [22/0] produce_guards
V0203 16:55:12.054000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[0] 1 None
V0203 16:55:12.054000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[1] 1 None
V0203 16:55:12.055000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[2] 3 None
V0203 16:55:12.055000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[3] 3 None
V0203 16:55:12.055000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[0] 9 None
V0203 16:55:12.056000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[1] 9 None
V0203 16:55:12.056000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[2] 3 None
V0203 16:55:12.056000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[3] 1 None
V0203 16:55:12.056000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].storage_offset() 0 None
graph():
%p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
%p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
%p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
%p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
%b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
%b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
%b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
%x : [num_users=1] = placeholder[target=x]
%conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
%add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
%batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
return (batch_norm,)
We can then lower this exported program to an operator set which only contains functional ATen operators through the API run_decompositions, which decomposes the ATen operators into the ones specified in the decomposition table and functionalizes the graph. By specifying an empty set, we only perform functionalization without any additional decompositions. This results in an IR which contains ~2000 operators (instead of the 3000 operators above) and is ideal for inference cases.
ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph)
graph():
%p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
%p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
%p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
%p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
%b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
%b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
%b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
%x : [num_users=1] = placeholder[target=x]
%conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
%_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
return (getitem_3, getitem_4, add, getitem)
We can see that the previously mutable operator torch.ops.aten.add_.Tensor has now been replaced with torch.ops.aten.add.Tensor, a functional operator.
We can also further lower this exported program to an operator set which only contains the Core ATen Operator Set, a collection of only ~180 operators. This IR is optimal for backends who do not want to reimplement all ATen operators.
from torch.export import default_decompositions
core_aten_decomp_table = default_decompositions()
core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
print(core_aten_ep.graph)
graph():
%p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
%p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
%p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
%p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
%b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
%b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
%b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
%x : [num_users=1] = placeholder[target=x]
%convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
%_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%convolution, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
return (getitem_3, getitem_4, add, getitem)
We now see that torch.ops.aten.conv2d.default has been decomposed into torch.ops.aten.convolution.default. This is because convolution is a more "core" operator, as operations like conv1d and conv2d can be implemented using the same op.
We can also specify our own decomposition behaviors:
my_decomp_table = torch.export.default_decompositions()
def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)
my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph)
graph():
%p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
%p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
%p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
%p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
%b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
%b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
%b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
%x : [num_users=1] = placeholder[target=x]
%convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, 2), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
%_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%mul, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
return (getitem_3, getitem_4, add, getitem)
Notice that instead of torch.ops.aten.conv2d.default being decomposed into torch.ops.aten.convolution.default alone, it is now decomposed into torch.ops.aten.convolution.default and torch.ops.aten.mul.Tensor, which matches our custom decomposition rule.
ExportDB¶
torch.export will only ever export a single computation graph from a PyTorch program. Because of this requirement, there will be Python or PyTorch features that are not compatible with torch.export, which will require users to rewrite parts of their model code. We have seen examples of this earlier in the tutorial -- for example, rewriting if-statements using cond.
ExportDB is the standard reference that documents which Python/PyTorch features torch.export supports and which it does not. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with torch.export. Examples are also tagged by category so that they can be more easily searched.
For example, let's use ExportDB to get a better understanding of how the predicate works in the cond operator. We can look at the example named cond_predicate, which has a torch.cond tag. The example code looks like:
def cond_predicate(x):
"""
The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
- ``torch.Tensor`` with a single element
- boolean expression
    NOTE: If the ``pred`` is tested on a dim with batch size < 2, it will be specialized.
"""
pred = x.dim() > 2 and x.shape[2] > 10
return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
More generally, ExportDB can be used as a reference when one of the following occurs:
Before attempting torch.export, you know ahead of time that your model uses some tricky Python/PyTorch features, and you want to know whether torch.export covers that feature.
When attempting torch.export, there is a failure and it's unclear how to work around it.
ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by torch.export.
Running the Exported Program¶
As torch.export is only a graph capturing mechanism, calling the artifact produced by torch.export eagerly will be equivalent to running the eager module. To optimize the execution of the Exported Program, we can pass this exported artifact to backends such as Inductor through torch.compile, AOTInductor, or TensorRT.
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
x = self.linear(x)
return x
inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))
# Run it eagerly
res = ep.module()(inp)
print(res)
# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
I0203 16:55:12.677000 634 torch/fx/experimental/symbolic_shapes.py:3192] [23/0] create_env
I0203 16:55:12.692000 634 torch/fx/experimental/symbolic_shapes.py:4547] [23/0] produce_guards
V0203 16:55:12.692000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].size()[0] 2 None
V0203 16:55:12.692000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].size()[1] 3 None
V0203 16:55:12.693000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].stride()[0] 3 None
V0203 16:55:12.693000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].stride()[1] 1 None
V0203 16:55:12.693000 634 torch/fx/experimental/symbolic_shapes.py:4755] [23/0] track_symint L['x'].storage_offset() 0 None
tensor([[ 0.4830, -0.5149, 0.3888],
[-0.9247, 0.8408, -0.2184]], device='cuda:0',
grad_fn=<AddmmBackward0>)
I0203 16:55:12.720000 634 torch/fx/experimental/symbolic_shapes.py:3192] [24/0] create_env
I0203 16:55:13.317000 634 torch/fx/experimental/symbolic_shapes.py:4547] [24/0] produce_guards
I0203 16:55:13.341000 634 torch/fx/experimental/symbolic_shapes.py:4547] [24/0] produce_guards
V0203 16:55:13.342000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].size()[0] 2 None
V0203 16:55:13.342000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].size()[1] 3 None
V0203 16:55:13.342000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].stride()[0] 3 None
V0203 16:55:13.343000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].stride()[1] 1 None
V0203 16:55:13.343000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['x'].storage_offset() 0 None
V0203 16:55:13.343000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[0] 3 None
V0203 16:55:13.343000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[1] 3 None
V0203 16:55:13.344000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[0] 3 None
V0203 16:55:13.344000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[1] 1 None
V0203 16:55:13.344000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['weight'].storage_offset() 0 None
V0203 16:55:13.345000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['bias'].size()[0] 3 None
V0203 16:55:13.345000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['bias'].stride()[0] 1 None
V0203 16:55:13.345000 634 torch/fx/experimental/symbolic_shapes.py:4755] [24/0] track_symint L['self']._modules['linear']._parameters['bias'].storage_offset() 0 None
V0203 16:55:13.346000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].size()[0] == 2
V0203 16:55:13.346000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].size()[1] == 3
V0203 16:55:13.346000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].stride()[0] == 3
V0203 16:55:13.347000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].stride()[1] == 1
V0203 16:55:13.347000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['x'].storage_offset() == 0
V0203 16:55:13.347000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[0] == 3
V0203 16:55:13.347000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[1] == 3
V0203 16:55:13.348000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[0] == 3
V0203 16:55:13.348000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[1] == 1
V0203 16:55:13.348000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['weight'].storage_offset() == 0
V0203 16:55:13.349000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['bias'].size()[0] == 3
V0203 16:55:13.349000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['bias'].stride()[0] == 1
V0203 16:55:13.349000 634 torch/fx/experimental/symbolic_shapes.py:4958] [24/0] Skipping guard L['self']._modules['linear']._parameters['bias'].storage_offset() == 0
tensor([[ 0.4830, -0.5149, 0.3888],
[-0.9247, 0.8408, -0.2184]], device='cuda:0',
grad_fn=<CompiledFunctionBackward>)
import torch._inductor
# Note: these APIs are subject to change
# Compile the exported program to a PT2 archive using ``AOTInductor``
with torch.no_grad():
pt2_path = torch._inductor.aoti_compile_and_package(ep)
# Load and run the .so file in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.ac.cn/docs/main/torch.compiler_aot_inductor.html
aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
res = aoti_compiled(inp)