• 教程 >
  • torch.export 教程
快捷方式

torch.export 教程

作者:William Wen,Zhengxu Chen,Angela Yi

警告

torch.export 及其相关功能处于原型阶段,可能会发生向后兼容性中断的更改。本教程提供了截至 PyTorch 2.3 的 torch.export 用法快照。

torch.export() 是 PyTorch 2.X 中将 PyTorch 模型导出为标准化模型表示的方法,旨在在不同的(即无 Python 的)环境中运行。官方文档可以在这里找到 此处

在本教程中,您将学习如何使用 torch.export() 从 PyTorch 程序中提取 ExportedProgram(即单图表示)。我们还详细介绍了一些您可能需要进行的考虑/修改,以便使您的模型与 torch.export 兼容。

内容

基本用法

torch.export 通过跟踪目标函数(给定示例输入)来从 PyTorch 程序中提取单图表示。torch.export.export()torch.export 的主要入口点。

在本教程中,torch.exporttorch.export.export() 在实践中是同义词,尽管 torch.export 通常指的是 PyTorch 2.X 导出过程,而 torch.export.export() 通常指的是实际的函数调用。

torch.export.export() 的签名是

export(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram

torch.export.export() 从调用 f(*args, **kwargs) 开始跟踪张量计算图,并将其包装在 ExportedProgram 中,该程序可以被序列化或稍后使用不同的输入执行。请注意,虽然输出 ExportedGraph 是可调用的,并且可以像原始输入可调用对象一样调用,但它不是 torch.nn.Module。我们将在本教程的后面详细介绍 dynamic_shapes 参数。

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))

让我们回顾一下 ExportedProgram 中一些我们感兴趣的属性。

graph 属性是从我们导出的函数跟踪的 FX 图,即所有 PyTorch 操作的计算图。FX 图具有一些重要的属性

  • 操作是“ATen 级”操作。

  • 该图是“函数化的”,这意味着没有操作是突变的。

graph_module 属性是包装 graph 属性的 GraphModule,以便它可以作为 torch.nn.Module 运行。

print(exported_mod)
print(exported_mod.graph_module)

打印的代码显示 FX 图仅包含 ATen 级操作(例如 torch.ops.aten),并且突变已被删除。例如,突变操作 torch.nn.functional.relu(..., inplace=True) 在打印的代码中由 torch.ops.aten.relu.default 表示,它不会发生突变。原始突变 relu 操作的输入的未来使用将被替换为替换的非突变 relu 操作的额外新输出。

ExportedProgram 中其他感兴趣的属性包括

  • graph_signature – 导出图的输入、输出、参数、缓冲区等。

  • range_constraints – 约束,稍后介绍

print(exported_mod.graph_signature)

有关更多详细信息,请参阅 torch.export文档

图中断

尽管 torch.exporttorch.compile 共享组件,但 torch.export 的主要限制(尤其是在与 torch.compile 相比时)在于它不支持图中断。这是因为处理图中断涉及使用默认的 Python 评估来解释不受支持的操作,这与导出用例不兼容。因此,为了使您的模型代码与 torch.export 兼容,您需要修改代码以删除图中断。

在以下情况下需要图中断:

  • 数据相关的控制流

class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • 使用 .data 访问张量数据

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • 调用不受支持的函数(例如许多内置函数)

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • 不受支持的 Python 语言特性(例如抛出异常、匹配语句)

class Bad4(torch.nn.Module):
    def forward(self, x):
        try:
            x = x + 1
            raise RuntimeError("bad")
        except:
            x = x + 2
        return x

try:
    export(Bad4(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

非严格导出

为了跟踪程序,torch.export 使用 TorchDynamo(一个字节码分析引擎)对 Python 代码进行符号分析,并根据结果构建图形。此分析允许 torch.export 提供更强的安全保证,但并非所有 Python 代码都受支持,从而导致这些图中断。

为了解决这个问题,在 PyTorch 2.3 中,我们引入了一种新的导出模式,称为非严格模式,在这种模式下,我们使用 Python 解释器跟踪程序,该解释器完全按照在急切模式下的执行方式执行程序,从而允许我们跳过不受支持的 Python 特性。这是通过添加 strict=False 标志来完成的。

查看之前导致图中断的一些示例

  • 使用 .data 访问张量数据现在可以正常工作了

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
print(bad2_nonstrict.module()(torch.ones(3, 3)))
  • 调用不受支持的函数(例如许多内置函数)会跟踪

通过,但在这种情况下,id(x) 在图中被专门化为一个常数整数。这是因为 id(x) 不是张量操作,因此该操作不会记录在图中。

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
  • 不受支持的 Python 语言特性(例如抛出异常、匹配

语句)现在也可以跟踪通过了。

class Bad4(torch.nn.Module):
    def forward(self, x):
        try:
            x = x + 1
            raise RuntimeError("bad")
        except:
            x = x + 2
        return x

bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
print(bad4_nonstrict.module()(torch.ones(3, 3)))

但是,仍然有一些特性需要对原始模块进行重写

控制流操作

torch.export 实际上确实支持数据相关的控制流。但这些需要使用控制流操作来表达。例如,我们可以使用 cond 操作来修复上面的控制流示例,如下所示

from functorch.experimental.control_flow import cond

class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))

cond 有一些限制需要注意

  • 谓词(即 x.sum() > 0)必须产生布尔值或单元素张量。

  • 操作数(即 [x])必须是张量。

  • 分支函数(即 true_fnfalse_fn)签名必须与操作数匹配,并且它们都必须返回具有相同元数据(例如,dtypeshape 等)的单个张量。

  • 分支函数不能修改输入或全局变量。

  • 分支函数不能访问闭包变量,除非函数是在方法的范围内定义的,则可以访问 self

有关 cond 的更多详细信息,请查看 cond 文档

约束/动态形状

操作对于不同的张量形状可以有不同的专门化/行为,因此默认情况下,torch.export 要求 ExportedProgram 的输入具有与初始 torch.export.export() 调用提供的相应示例输入相同的形状。如果我们尝试使用具有不同形状的张量在下面的示例中运行 ExportedProgram,我们会收到错误消息

class MyModule2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod2 = MyModule2()
exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))

try:
    exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
    tb.print_exc()

我们可以使用 torch.export.export()dynamic_shapes 参数来放松此约束,该参数允许我们使用 torch.export.Dim文档)指定输入张量的哪些维度是动态的。

对于输入可调用的每个张量参数,我们可以指定一个从维度到 torch.export.Dim 的映射。torch.export.Dim 本质上是一个命名的符号整数,具有可选的最小和最大边界。

然后,torch.export.export()dynamic_shapes 参数的格式是从输入可调用的张量参数名称到上面描述的维度到 dim 映射的映射。如果没有为张量参数的维度提供 torch.export.Dim,则该维度被假定为静态的。

torch.export.Dim 的第一个参数是符号整数的名称,用于调试。然后我们可以指定一个可选的最小和最大边界(包含)。下面,我们展示了一个用法示例。

在下面的示例中,我们的输入 inp1 的第一维不受约束,但第二维的大小必须在 [4, 18] 的区间内。

from torch.export import Dim

inp1 = torch.randn(10, 10, 2)

class DynamicShapesExample1(torch.nn.Module):
    def forward(self, x):
        x = x[:, 2:]
        return torch.relu(x)

inp1_dim0 = Dim("inp1_dim0")
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
    "x": {0: inp1_dim0, 1: inp1_dim1},
}

exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)

print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))

try:
    exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
except Exception:
    tb.print_exc()

请注意,如果我们提供的 torch.export 的示例输入不满足 dynamic_shapes 给出的约束,那么我们会收到错误消息。

inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)
dynamic_shapes1_bad = {
    "x": {0: inp1_dim0, 1: inp1_dim1_bad},
}

try:
    export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
except Exception:
    tb.print_exc()

我们可以通过使用相同的 torch.export.Dim 对象来强制不同张量维度之间的相等性,例如,在矩阵乘法中

inp2 = torch.randn(4, 8)
inp3 = torch.randn(8, 2)

class DynamicShapesExample2(torch.nn.Module):
    def forward(self, x, y):
        return x @ y

inp2_dim0 = Dim("inp2_dim0")
inner_dim = Dim("inner_dim")
inp3_dim1 = Dim("inp3_dim1")

dynamic_shapes2 = {
    "x": {0: inp2_dim0, 1: inner_dim},
    "y": {0: inner_dim, 1: inp3_dim1},
}

exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)

print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))

try:
    exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
except Exception:
    tb.print_exc()

我们还可以用其他维度来描述一个维度。关于我们如何用其他维度来详细指定一个维度有一些限制,但通常,形式为 A * Dim + B 的表达式应该可以工作。

class DerivedDimExample1(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

foo = DerivedDimExample1()

x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
derived_dynamic_shapes1 = ({0: dimx}, {0: dimy})

derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1)

print(derived_dim_example1.module()(torch.randn(4), torch.randn(5)))

try:
    derived_dim_example1.module()(torch.randn(4), torch.randn(6))
except Exception:
    tb.print_exc()


class DerivedDimExample2(torch.nn.Module):
    def forward(self, z, y):
        return z[1:] + y[1::3]

foo = DerivedDimExample2()

z, y = torch.randn(4), torch.randn(10)
dx = torch.export.Dim("dx", min=3, max=6)
dz = dx + 1
dy = dx * 3 + 1
derived_dynamic_shapes2 = ({0: dz}, {0: dy})

derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2)
print(derived_dim_example2.module()(torch.randn(7), torch.randn(19)))

实际上,我们可以使用 torch.export 来指导我们哪些 dynamic_shapes 约束是必要的。我们可以通过放松所有约束来做到这一点(回想一下,如果我们不为某个维度提供约束,则默认行为是约束到示例输入的精确形状值),并让 torch.export 发生错误。

inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)

class DynamicShapesExample3(torch.nn.Module):
    def forward(self, x, y):
        if x.shape[0] <= 16:
            return x @ y[:, :16]
        return y

dynamic_shapes3 = {
    "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
    "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
}

try:
    export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
except Exception:
    tb.print_exc()

我们可以看到,错误消息为我们提供了对动态形状约束的建议修复方案。让我们遵循这些建议(确切的建议可能略有不同)。

def suggested_fixes():
    inp4_dim1 = Dim('shared_dim')
    # suggested fixes below
    inp4_dim0 = Dim('inp4_dim0', max=16)
    inp5_dim1 = Dim('inp5_dim1', min=17)
    inp5_dim0 = inp4_dim1
    # end of suggested fixes
    return {
        "x": {0: inp4_dim0, 1: inp4_dim1},
        "y": {0: inp5_dim0, 1: inp5_dim1},
    }

dynamic_shapes3_fixed = suggested_fixes()
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))

请注意,在上面的示例中,因为我们在 dynamic_shapes_example3 中约束了 x.shape[0] 的值,所以即使存在原始的 if 语句,导出的程序也是健全的。

如果您想了解 torch.export 生成这些约束的原因,您可以使用环境变量 TORCH_LOGS=dynamic,dynamo 重新运行脚本,或使用 torch._logging.set_logs

import logging
torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)

# reset to previous values
torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)

我们可以使用 range_constraints 字段查看 ExportedProgram 的符号形状范围。

print(exported_dynamic_shapes_example3.range_constraints)

自定义操作

torch.export 可以导出包含自定义操作符的 PyTorch 程序。

目前,为 torch.export 注册自定义操作符的步骤如下:

  • 使用 torch.library 定义自定义操作符(参考),就像任何其他自定义操作符一样。

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(input: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)
  • 定义自定义操作符的 "Meta" 实现,该实现返回一个与预期输出形状相同的空张量。

@custom_op.register_fake
def custom_op_meta(x):
    return torch.empty_like(x)
  • 使用 torch.ops 从您要导出的代码中调用自定义操作符。

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x
  • 像以前一样导出代码。

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
exported_custom_op_example.graph_module.print_readable()
print(exported_custom_op_example.module()(torch.randn(3, 3)))

请注意,在以上输出中,自定义操作符包含在导出的图中。当我们将导出的图作为函数调用时,会调用原始的自定义操作符,这可以通过 print 调用来证明。

如果您有一个用 C++ 实现的自定义操作符,请参阅 此文档 以使其与 torch.export 兼容。

分解

默认情况下,由 torch.export 生成的图返回一个仅包含函数式 ATen 操作符的图。此函数式 ATen 操作符集(或“opset”)包含大约 2000 个操作符,所有这些操作符都是函数式的,也就是说,它们不会修改或混淆输入。您可以在 此处 找到所有 ATen 操作符的列表,并且可以通过检查 op._schema.is_mutable 来检查操作符是否为函数式,例如:

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)

默认情况下,您想要在其中运行导出图的环境应支持所有这大约 2000 个操作符。但是,如果您的特定环境只能支持大约 2000 个操作符中的一个子集,则可以在导出的程序上使用以下 API。

def run_decompositions(
    self: ExportedProgram,
    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
) -> ExportedProgram

run_decompositions 接收一个分解表,该表是操作符到函数的映射,用于指定如何减少或将该操作符分解成其他 ATen 操作符的等效序列。

run_decompositions 的默认分解表是 核心 ATen 分解表,它将所有 ATen 操作符分解为 核心 ATen 操作符集,该集合仅包含约 180 个操作符。

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)

core_ir_ep = ep.run_decompositions()
print(core_ir_ep.graph)

请注意,在运行 run_decompositions 后,torch.ops.aten.t.default 操作符(不是核心 ATen Opset 的一部分)已被替换为 torch.ops.aten.permute.default,后者是核心 ATen Opset 的一部分。

大多数 ATen 操作符已经具有分解,这些分解位于 此处。如果您想使用其中一些现有的分解函数,可以将要分解的操作符列表传递给 get_decompositions 函数,该函数将返回一个使用现有分解实现的分解表。

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)

from torch._decomp import get_decompositions
decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
core_ir_ep = ep.run_decompositions(decomp_table)
print(core_ir_ep.graph)

如果对于您想要分解的 ATen 操作符不存在现有的分解函数,请随时向 PyTorch 发送一个拉取请求来实现分解!

ExportDB

torch.export 永远只会从 PyTorch 程序导出单个计算图。由于此要求,某些 Python 或 PyTorch 功能与 torch.export 不兼容,这需要用户重写其模型代码的部分内容。我们在本教程的前面已经看到了这方面的例子——例如,使用 cond 重写 if 语句。

ExportDB 是记录 torch.export 支持和不支持的 Python/PyTorch 功能的标准参考。它本质上是一个程序样本列表,每个样本都代表一个特定 Python/PyTorch 功能的使用及其与 torch.export 的交互。示例也按类别进行标记,以便更容易搜索。

例如,让我们使用 ExportDB 来更好地理解 cond 操作符中的谓词如何工作。我们可以查看名为 cond_predicate 的示例,它具有 torch.cond 标记。示例代码如下所示:

def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - ``torch.Tensor`` with a single element
    - boolean expression
    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

更一般地,当以下情况之一发生时,可以使用 ExportDB 作为参考:

  1. 在尝试 torch.export 之前,您提前知道您的模型使用了一些棘手的 Python/PyTorch 功能,并且您想知道 torch.export 是否涵盖了该功能。

  2. 在尝试 torch.export 时,出现故障,并且不清楚如何解决。

ExportDB 并不详尽,但旨在涵盖典型 PyTorch 代码中发现的所有用例。如果您有任何重要的 Python/PyTorch 功能应添加到 ExportDB 或由 torch.export 支持,请随时与我们联系。

运行导出的程序

由于 torch.export 只是一个图捕获机制,因此急切地调用由 torch.export 生成的工件将等效于运行急切模块。为了优化导出程序的执行,我们可以通过 torch.compile 将此导出工件传递给后端,例如 Inductor、AOTInductorTensorRT

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
import torch._export
import torch._inductor

# Note: these APIs are subject to change
# Compile the exported program to a .so using ``AOTInductor``
with torch.no_grad():
so_path = torch._inductor.aot_compile(ep.module(), [inp])

# Load and run the .so file in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.ac.cn/docs/main/torch.compiler_aot_inductor.html
res = torch._export.aot_load(so_path, device="cuda")(inp)

结论

我们介绍了 torch.export,这是 PyTorch 2.X 中从 PyTorch 程序导出单个计算图的新方法。特别是,我们演示了为了导出图而需要进行的几个代码修改和注意事项(控制流操作、约束等)。

脚本的总运行时间:(0 分 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源