注意
点击 此处 下载完整的示例代码
torch.export 教程¶
作者:William Wen,Zhengxu Chen,Angela Yi
警告
torch.export
及其相关功能处于原型阶段,可能会发生向后兼容性中断的更改。本教程提供了截至 PyTorch 2.3 的 torch.export
用法快照。
torch.export()
是 PyTorch 2.X 中将 PyTorch 模型导出为标准化模型表示的方法,旨在在不同的(即无 Python 的)环境中运行。官方文档可以在这里找到 此处。
在本教程中,您将学习如何使用 torch.export()
从 PyTorch 程序中提取 ExportedProgram
(即单图表示)。我们还详细介绍了一些您可能需要进行的考虑/修改,以便使您的模型与 torch.export
兼容。
内容
基本用法¶
torch.export
通过跟踪目标函数(给定示例输入)来从 PyTorch 程序中提取单图表示。torch.export.export()
是 torch.export
的主要入口点。
在本教程中,torch.export
和 torch.export.export()
在实践中是同义词,尽管 torch.export
通常指的是 PyTorch 2.X 导出过程,而 torch.export.export()
通常指的是实际的函数调用。
torch.export.export()
的签名是
export(
f: Callable,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram
torch.export.export()
从调用 f(*args, **kwargs)
开始跟踪张量计算图,并将其包装在 ExportedProgram
中,该程序可以被序列化或稍后使用不同的输入执行。请注意,虽然输出 ExportedGraph
是可调用的,并且可以像原始输入可调用对象一样调用,但它不是 torch.nn.Module
。我们将在本教程的后面详细介绍 dynamic_shapes
参数。
import torch
from torch.export import export
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(100, 10)
def forward(self, x, y):
return torch.nn.functional.relu(self.lin(x + y), inplace=True)
mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
让我们回顾一下 ExportedProgram
中一些我们感兴趣的属性。
graph
属性是从我们导出的函数跟踪的 FX 图,即所有 PyTorch 操作的计算图。FX 图具有一些重要的属性
操作是“ATen 级”操作。
该图是“函数化的”,这意味着没有操作是突变的。
graph_module
属性是包装 graph
属性的 GraphModule
,以便它可以作为 torch.nn.Module
运行。
print(exported_mod)
print(exported_mod.graph_module)
打印的代码显示 FX 图仅包含 ATen 级操作(例如 torch.ops.aten
),并且突变已被删除。例如,突变操作 torch.nn.functional.relu(..., inplace=True)
在打印的代码中由 torch.ops.aten.relu.default
表示,它不会发生突变。原始突变 relu
操作的输入的未来使用将被替换为替换的非突变 relu
操作的额外新输出。
ExportedProgram
中其他感兴趣的属性包括
graph_signature
– 导出图的输入、输出、参数、缓冲区等。range_constraints
– 约束,稍后介绍
print(exported_mod.graph_signature)
有关更多详细信息,请参阅 torch.export
的 文档。
图中断¶
尽管 torch.export
与 torch.compile
共享组件,但 torch.export
的主要限制(尤其是在与 torch.compile
相比时)在于它不支持图中断。这是因为处理图中断涉及使用默认的 Python 评估来解释不受支持的操作,这与导出用例不兼容。因此,为了使您的模型代码与 torch.export
兼容,您需要修改代码以删除图中断。
在以下情况下需要图中断:
数据相关的控制流
class Bad1(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return torch.sin(x)
return torch.cos(x)
import traceback as tb
try:
export(Bad1(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()
使用
.data
访问张量数据
class Bad2(torch.nn.Module):
def forward(self, x):
x.data[0, 0] = 3
return x
try:
export(Bad2(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()
调用不受支持的函数(例如许多内置函数)
class Bad3(torch.nn.Module):
def forward(self, x):
x = x + 1
return x + id(x)
try:
export(Bad3(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()
不受支持的 Python 语言特性(例如抛出异常、匹配语句)
class Bad4(torch.nn.Module):
def forward(self, x):
try:
x = x + 1
raise RuntimeError("bad")
except:
x = x + 2
return x
try:
export(Bad4(), (torch.randn(3, 3),))
except Exception:
tb.print_exc()
非严格导出¶
为了跟踪程序,torch.export
使用 TorchDynamo(一个字节码分析引擎)对 Python 代码进行符号分析,并根据结果构建图形。此分析允许 torch.export
提供更强的安全保证,但并非所有 Python 代码都受支持,从而导致这些图中断。
为了解决这个问题,在 PyTorch 2.3 中,我们引入了一种新的导出模式,称为非严格模式,在这种模式下,我们使用 Python 解释器跟踪程序,该解释器完全按照在急切模式下的执行方式执行程序,从而允许我们跳过不受支持的 Python 特性。这是通过添加 strict=False
标志来完成的。
查看之前导致图中断的一些示例
使用
.data
访问张量数据现在可以正常工作了
class Bad2(torch.nn.Module):
def forward(self, x):
x.data[0, 0] = 3
return x
bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
print(bad2_nonstrict.module()(torch.ones(3, 3)))
调用不受支持的函数(例如许多内置函数)会跟踪
通过,但在这种情况下,id(x)
在图中被专门化为一个常数整数。这是因为 id(x)
不是张量操作,因此该操作不会记录在图中。
class Bad3(torch.nn.Module):
def forward(self, x):
x = x + 1
return x + id(x)
bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
不受支持的 Python 语言特性(例如抛出异常、匹配
语句)现在也可以跟踪通过了。
class Bad4(torch.nn.Module):
def forward(self, x):
try:
x = x + 1
raise RuntimeError("bad")
except:
x = x + 2
return x
bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
print(bad4_nonstrict.module()(torch.ones(3, 3)))
但是,仍然有一些特性需要对原始模块进行重写
控制流操作¶
torch.export
实际上确实支持数据相关的控制流。但这些需要使用控制流操作来表达。例如,我们可以使用 cond
操作来修复上面的控制流示例,如下所示
from functorch.experimental.control_flow import cond
class Bad1Fixed(torch.nn.Module):
def forward(self, x):
def true_fn(x):
return torch.sin(x)
def false_fn(x):
return torch.cos(x)
return cond(x.sum() > 0, true_fn, false_fn, [x])
exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
对 cond
有一些限制需要注意
谓词(即
x.sum() > 0
)必须产生布尔值或单元素张量。操作数(即
[x]
)必须是张量。分支函数(即
true_fn
和false_fn
)签名必须与操作数匹配,并且它们都必须返回具有相同元数据(例如,dtype
、shape
等)的单个张量。分支函数不能修改输入或全局变量。
分支函数不能访问闭包变量,除非函数是在方法的范围内定义的,则可以访问
self
。
有关 cond
的更多详细信息,请查看 cond 文档。
约束/动态形状¶
操作对于不同的张量形状可以有不同的专门化/行为,因此默认情况下,torch.export
要求 ExportedProgram
的输入具有与初始 torch.export.export()
调用提供的相应示例输入相同的形状。如果我们尝试使用具有不同形状的张量在下面的示例中运行 ExportedProgram
,我们会收到错误消息
class MyModule2(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(100, 10)
def forward(self, x, y):
return torch.nn.functional.relu(self.lin(x + y), inplace=True)
mod2 = MyModule2()
exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))
try:
exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
tb.print_exc()
我们可以使用 torch.export.export()
的 dynamic_shapes
参数来放松此约束,该参数允许我们使用 torch.export.Dim
(文档)指定输入张量的哪些维度是动态的。
对于输入可调用的每个张量参数,我们可以指定一个从维度到 torch.export.Dim
的映射。torch.export.Dim
本质上是一个命名的符号整数,具有可选的最小和最大边界。
然后,torch.export.export()
的 dynamic_shapes
参数的格式是从输入可调用的张量参数名称到上面描述的维度到 dim 映射的映射。如果没有为张量参数的维度提供 torch.export.Dim
,则该维度被假定为静态的。
torch.export.Dim
的第一个参数是符号整数的名称,用于调试。然后我们可以指定一个可选的最小和最大边界(包含)。下面,我们展示了一个用法示例。
在下面的示例中,我们的输入 inp1
的第一维不受约束,但第二维的大小必须在 [4, 18] 的区间内。
from torch.export import Dim
inp1 = torch.randn(10, 10, 2)
class DynamicShapesExample1(torch.nn.Module):
def forward(self, x):
x = x[:, 2:]
return torch.relu(x)
inp1_dim0 = Dim("inp1_dim0")
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
"x": {0: inp1_dim0, 1: inp1_dim1},
}
exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)
print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))
try:
exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
except Exception:
tb.print_exc()
try:
exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
except Exception:
tb.print_exc()
try:
exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
except Exception:
tb.print_exc()
请注意,如果我们提供的 torch.export
的示例输入不满足 dynamic_shapes
给出的约束,那么我们会收到错误消息。
inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)
dynamic_shapes1_bad = {
"x": {0: inp1_dim0, 1: inp1_dim1_bad},
}
try:
export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
except Exception:
tb.print_exc()
我们可以通过使用相同的 torch.export.Dim
对象来强制不同张量维度之间的相等性,例如,在矩阵乘法中
inp2 = torch.randn(4, 8)
inp3 = torch.randn(8, 2)
class DynamicShapesExample2(torch.nn.Module):
def forward(self, x, y):
return x @ y
inp2_dim0 = Dim("inp2_dim0")
inner_dim = Dim("inner_dim")
inp3_dim1 = Dim("inp3_dim1")
dynamic_shapes2 = {
"x": {0: inp2_dim0, 1: inner_dim},
"y": {0: inner_dim, 1: inp3_dim1},
}
exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)
print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))
try:
exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
except Exception:
tb.print_exc()
我们还可以用其他维度来描述一个维度。关于我们如何用其他维度来详细指定一个维度有一些限制,但通常,形式为 A * Dim + B
的表达式应该可以工作。
class DerivedDimExample1(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
foo = DerivedDimExample1()
x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
derived_dynamic_shapes1 = ({0: dimx}, {0: dimy})
derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1)
print(derived_dim_example1.module()(torch.randn(4), torch.randn(5)))
try:
derived_dim_example1.module()(torch.randn(4), torch.randn(6))
except Exception:
tb.print_exc()
class DerivedDimExample2(torch.nn.Module):
def forward(self, z, y):
return z[1:] + y[1::3]
foo = DerivedDimExample2()
z, y = torch.randn(4), torch.randn(10)
dx = torch.export.Dim("dx", min=3, max=6)
dz = dx + 1
dy = dx * 3 + 1
derived_dynamic_shapes2 = ({0: dz}, {0: dy})
derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2)
print(derived_dim_example2.module()(torch.randn(7), torch.randn(19)))
实际上,我们可以使用 torch.export
来指导我们哪些 dynamic_shapes
约束是必要的。我们可以通过放松所有约束来做到这一点(回想一下,如果我们不为某个维度提供约束,则默认行为是约束到示例输入的精确形状值),并让 torch.export
发生错误。
inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)
class DynamicShapesExample3(torch.nn.Module):
def forward(self, x, y):
if x.shape[0] <= 16:
return x @ y[:, :16]
return y
dynamic_shapes3 = {
"x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
"y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
}
try:
export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
except Exception:
tb.print_exc()
我们可以看到,错误消息为我们提供了对动态形状约束的建议修复方案。让我们遵循这些建议(确切的建议可能略有不同)。
def suggested_fixes():
inp4_dim1 = Dim('shared_dim')
# suggested fixes below
inp4_dim0 = Dim('inp4_dim0', max=16)
inp5_dim1 = Dim('inp5_dim1', min=17)
inp5_dim0 = inp4_dim1
# end of suggested fixes
return {
"x": {0: inp4_dim0, 1: inp4_dim1},
"y": {0: inp5_dim0, 1: inp5_dim1},
}
dynamic_shapes3_fixed = suggested_fixes()
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))
请注意,在上面的示例中,因为我们在 dynamic_shapes_example3
中约束了 x.shape[0]
的值,所以即使存在原始的 if
语句,导出的程序也是健全的。
如果您想了解 torch.export
生成这些约束的原因,您可以使用环境变量 TORCH_LOGS=dynamic,dynamo
重新运行脚本,或使用 torch._logging.set_logs
。
import logging
torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
# reset to previous values
torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)
我们可以使用 range_constraints
字段查看 ExportedProgram
的符号形状范围。
print(exported_dynamic_shapes_example3.range_constraints)
自定义操作¶
torch.export
可以导出包含自定义操作符的 PyTorch 程序。
目前,为 torch.export
注册自定义操作符的步骤如下:
使用
torch.library
定义自定义操作符(参考),就像任何其他自定义操作符一样。
@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(input: torch.Tensor) -> torch.Tensor:
print("custom_op called!")
return torch.relu(x)
定义自定义操作符的
"Meta"
实现,该实现返回一个与预期输出形状相同的空张量。
@custom_op.register_fake
def custom_op_meta(x):
return torch.empty_like(x)
使用
torch.ops
从您要导出的代码中调用自定义操作符。
class CustomOpExample(torch.nn.Module):
def forward(self, x):
x = torch.sin(x)
x = torch.ops.my_custom_library.custom_op(x)
x = torch.cos(x)
return x
像以前一样导出代码。
exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
exported_custom_op_example.graph_module.print_readable()
print(exported_custom_op_example.module()(torch.randn(3, 3)))
请注意,在以上输出中,自定义操作符包含在导出的图中。当我们将导出的图作为函数调用时,会调用原始的自定义操作符,这可以通过 print
调用来证明。
如果您有一个用 C++ 实现的自定义操作符,请参阅 此文档 以使其与 torch.export
兼容。
分解¶
默认情况下,由 torch.export
生成的图返回一个仅包含函数式 ATen 操作符的图。此函数式 ATen 操作符集(或“opset”)包含大约 2000 个操作符,所有这些操作符都是函数式的,也就是说,它们不会修改或混淆输入。您可以在 此处 找到所有 ATen 操作符的列表,并且可以通过检查 op._schema.is_mutable
来检查操作符是否为函数式,例如:
print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)
默认情况下,您想要在其中运行导出图的环境应支持所有这大约 2000 个操作符。但是,如果您的特定环境只能支持大约 2000 个操作符中的一个子集,则可以在导出的程序上使用以下 API。
def run_decompositions(
self: ExportedProgram,
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
) -> ExportedProgram
run_decompositions
接收一个分解表,该表是操作符到函数的映射,用于指定如何减少或将该操作符分解成其他 ATen 操作符的等效序列。
run_decompositions
的默认分解表是 核心 ATen 分解表,它将所有 ATen 操作符分解为 核心 ATen 操作符集,该集合仅包含约 180 个操作符。
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
def forward(self, x):
return self.linear(x)
ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)
core_ir_ep = ep.run_decompositions()
print(core_ir_ep.graph)
请注意,在运行 run_decompositions
后,torch.ops.aten.t.default
操作符(不是核心 ATen Opset 的一部分)已被替换为 torch.ops.aten.permute.default
,后者是核心 ATen Opset 的一部分。
大多数 ATen 操作符已经具有分解,这些分解位于 此处。如果您想使用其中一些现有的分解函数,可以将要分解的操作符列表传递给 get_decompositions 函数,该函数将返回一个使用现有分解实现的分解表。
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
def forward(self, x):
return self.linear(x)
ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)
from torch._decomp import get_decompositions
decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
core_ir_ep = ep.run_decompositions(decomp_table)
print(core_ir_ep.graph)
如果对于您想要分解的 ATen 操作符不存在现有的分解函数,请随时向 PyTorch 发送一个拉取请求来实现分解!
ExportDB¶
torch.export
永远只会从 PyTorch 程序导出单个计算图。由于此要求,某些 Python 或 PyTorch 功能与 torch.export
不兼容,这需要用户重写其模型代码的部分内容。我们在本教程的前面已经看到了这方面的例子——例如,使用 cond
重写 if 语句。
ExportDB 是记录 torch.export
支持和不支持的 Python/PyTorch 功能的标准参考。它本质上是一个程序样本列表,每个样本都代表一个特定 Python/PyTorch 功能的使用及其与 torch.export
的交互。示例也按类别进行标记,以便更容易搜索。
例如,让我们使用 ExportDB 来更好地理解 cond
操作符中的谓词如何工作。我们可以查看名为 cond_predicate
的示例,它具有 torch.cond
标记。示例代码如下所示:
def cond_predicate(x):
"""
The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
- ``torch.Tensor`` with a single element
- boolean expression
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
"""
pred = x.dim() > 2 and x.shape[2] > 10
return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
更一般地,当以下情况之一发生时,可以使用 ExportDB 作为参考:
在尝试
torch.export
之前,您提前知道您的模型使用了一些棘手的 Python/PyTorch 功能,并且您想知道torch.export
是否涵盖了该功能。在尝试
torch.export
时,出现故障,并且不清楚如何解决。
ExportDB 并不详尽,但旨在涵盖典型 PyTorch 代码中发现的所有用例。如果您有任何重要的 Python/PyTorch 功能应添加到 ExportDB 或由 torch.export
支持,请随时与我们联系。
运行导出的程序¶
由于 torch.export
只是一个图捕获机制,因此急切地调用由 torch.export
生成的工件将等效于运行急切模块。为了优化导出程序的执行,我们可以通过 torch.compile
将此导出工件传递给后端,例如 Inductor、AOTInductor 或 TensorRT。
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
x = self.linear(x)
return x
inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))
# Run it eagerly
res = ep.module()(inp)
print(res)
# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
import torch._export
import torch._inductor
# Note: these APIs are subject to change
# Compile the exported program to a .so using ``AOTInductor``
with torch.no_grad():
so_path = torch._inductor.aot_compile(ep.module(), [inp])
# Load and run the .so file in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.ac.cn/docs/main/torch.compiler_aot_inductor.html
res = torch._export.aot_load(so_path, device="cuda")(inp)