快捷方式

torch.export

警告

此功能是一个正在积极开发的原型,未来将会有重大更改。

概述

torch.export.export() 采用任意 Python 可调用对象(torch.nn.Module、函数或方法),并生成一个仅表示函数的张量计算的跟踪图,采用即时 (AOT) 方式,该图随后可以使用不同的输出执行或序列化。

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
            # code: a = torch.sin(x)
            sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);

            # code: b = torch.cos(y)
            cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);

            # code: return a + b
            add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
            return (add,)

    Graph signature: ExportGraphSignature(
        parameters=[],
        buffers=[],
        user_inputs=['arg0_1', 'arg1_1'],
        user_outputs=['add'],
        inputs_to_parameters={},
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}

torch.export 生成一个具有以下不变量的干净中间表示 (IR)。有关 IR 的更多规范,请参见 此处

  • 健全性:保证它是原始程序的健全表示,并保持原始程序的相同调用约定。

  • 规范化:图中没有 Python 语义。原始程序中的子模块被内联以形成一个完全扁平的计算图。

  • 图属性:该图是纯函数的,这意味着它不包含具有副作用的操作,例如突变或别名。它不会改变任何中间值、参数或缓冲区。

  • 元数据:该图包含在跟踪期间捕获的元数据,例如来自用户代码的堆栈跟踪。

在底层,torch.export 利用了以下最新技术

  • TorchDynamo (torch._dynamo) 是一种内部 API,它使用称为 Frame Evaluation API 的 CPython 特性来安全地跟踪 PyTorch 图。这提供了大幅改进的图捕获体验,只需更少的重写即可完全跟踪 PyTorch 代码。

  • AOT Autograd 提供了一个功能化的 PyTorch 图,并确保该图分解/降低到 ATen 运算符集。

  • Torch FX (torch.fx) 是该图的基础表示,允许灵活的基于 Python 的转换。

现有框架

torch.compile() 也利用了与 torch.export 相同的 PT2 堆栈,但略有不同

  • JIT 与 AOTtorch.compile() 是一个 JIT 编译器,而它不打算用于在部署之外生成已编译的工件。

  • 部分与完全图捕获:当 torch.compile() 遇到模型中无法跟踪的部分时,它将“图中断”,并回退到在急切 Python 运行时中运行程序。相比之下,torch.export 旨在获得 PyTorch 模型的完整图表示,因此在遇到无法跟踪的内容时会出错。由于 torch.export 会生成与任何 Python 特性或运行时分离的完整图,因此该图可以保存在不同的环境和语言中并运行。

  • 可用性权衡:由于 torch.compile() 能够在遇到无法跟踪的内容时回退到 Python 运行时,因此它更加灵活。torch.export 相反,要求用户提供更多信息或重写其代码以使其可跟踪。

torch.fx.symbolic_trace() 相比,torch.export 使用 TorchDynamo 进行跟踪,它在 Python 字节码级别运行,这使其能够跟踪不受 Python 运算符重载支持限制的任意 Python 构造。此外,torch.export 会细粒度地跟踪张量元数据,以便对张量形状等内容的条件不会导致跟踪失败。一般来说,torch.export 预计可以在更多用户程序上运行,并生成更低级别的图(在 torch.ops.aten 运算符级别)。请注意,用户仍然可以使用 torch.fx.symbolic_trace() 作为 torch.export 之前的预处理步骤。

torch.jit.script() 相比,torch.export 不会捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为它更容易对 Python 字节码进行全面覆盖)。生成的图更简单,并且只有直线控制流(显式控制流操作符除外)。

torch.jit.trace() 相比,torch.export 是健全的:它能够跟踪对大小执行整数计算的代码,并记录显示特定跟踪对其他输入有效的必要的所有附加条件。

导出 PyTorch 模型

示例

主要入口点是通过 torch.export.export(),它采用一个可调用对象(torch.nn.Module、函数或方法)和示例输入,并将计算图捕获到 torch.export.ExportedProgram 中。一个示例

import torch
from torch.export import export

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):

            # code: a = self.conv(x)
            convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
                arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
            );

            # code: a.add_(constant)
            add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);

            # code: return self.maxpool(self.relu(a))
            relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
                relu, [3, 3], [3, 3]
            );
            getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
            return (getitem,)

    Graph signature: ExportGraphSignature(
        parameters=['L__self___conv.weight', 'L__self___conv.bias'],
        buffers=[],
        user_inputs=['arg2_1', 'arg3_1'],
        user_outputs=['getitem'],
        inputs_to_parameters={
            'arg0_1': 'L__self___conv.weight',
            'arg1_1': 'L__self___conv.bias',
        },
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}

检查 ExportedProgram,我们可以注意到以下内容

  • torch.fx.Graph 包含原始程序的计算图,以及原始代码的记录,以便于调试。

  • 该图仅包含 此处 找到的 torch.ops.aten 操作符和自定义操作符,并且是完全可用的,没有任何就地操作符,例如 torch.add_

  • 参数(权重和偏置到卷积)被提升为图表的输入,导致图中没有 get_attr 节点,这些节点之前存在于 torch.fx.symbolic_trace() 的结果中。

  • torch.export.ExportGraphSignature 对输入和输出签名进行建模,并指定哪些输入是参数。

  • 图中每个节点产生的张量的结果形状和数据类型会被记录下来。例如, convolution 节点将产生一个数据类型为 torch.float32 且形状为 (1, 16, 256, 256) 的张量。

非严格导出

在 PyTorch 2.3 中,我们引入了一种称为非严格模式的新跟踪模式。它仍在进行强化,因此如果你遇到任何问题,请使用“oncall: export”标签将它们提交到 Github。

非严格模式下,我们使用 Python 解释器跟踪程序。你的代码将完全按照热切模式执行;唯一的区别是所有张量对象都将被 ProxyTensors 替换,它会将所有操作记录到一个图中。

严格模式下(当前为默认模式),我们首先使用字节码分析引擎 TorchDynamo 跟踪程序。TorchDynamo 实际上不会执行你的 Python 代码。相反,它会对其进行符号分析,并根据结果构建一个图。此分析允许 torch.export 提供更强的安全性保证,但并非所有 Python 代码都受支持。

可能想要使用非严格模式的一个案例是,如果你遇到了一个可能无法轻松解决的 TorchDynamo 不支持的功能,并且你知道 python 代码对于计算并不是必需的。例如

import contextlib
import torch

class ContextManager():
    def __init__(self):
        self.count = 0
    def __enter__(self):
        self.count += 1
    def __exit__(self, exc_type, exc_value, traceback):
        self.count -= 1

class M(torch.nn.Module):
    def forward(self, x):
        with ContextManager():
            return x.sin() + x.cos()

export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
export(M(), (torch.ones(3, 3),))  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager

在此示例中,使用非严格模式(通过 strict=False 标志)的第一个调用成功跟踪,而使用严格模式(默认)的第二个调用失败,其中 TorchDynamo 无法支持上下文管理器。一种选择是重写代码(参见 torch.expot 的限制),但由于上下文管理器不影响模型中的张量计算,因此我们可以使用非严格模式的结果。

表达动态性

默认情况下,torch.export 将跟踪程序,假设所有输入形状都是静态的,并将导出的程序专门用于这些维度。但是,某些维度(例如批处理维度)可能是动态的,并且会因运行而异。此类维度必须使用 torch.export.Dim() API 创建并通过 torch.export.export() 将其传递到 dynamic_shapes 参数中。一个示例

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):

            # code: out1 = self.branch1(x1)
            permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
            addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
            relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);

            # code: out2 = self.branch2(x2)
            permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
            addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
            relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1);  addmm_1 = None

            # code: return (out1 + self.buffer, out2)
            add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
            return (add, relu_1)

    Graph signature: ExportGraphSignature(
        parameters=[
            'branch1.0.weight',
            'branch1.0.bias',
            'branch2.0.weight',
            'branch2.0.bias',
        ],
        buffers=['L__self___buffer'],
        user_inputs=['arg5_1', 'arg6_1'],
        user_outputs=['add', 'relu_1'],
        inputs_to_parameters={
            'arg0_1': 'branch1.0.weight',
            'arg1_1': 'branch1.0.bias',
            'arg2_1': 'branch2.0.weight',
            'arg3_1': 'branch2.0.bias',
        },
        inputs_to_buffers={'arg4_1': 'L__self___buffer'},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}

需要注意的其他一些事项

  • 通过 torch.export.Dim() API 和 dynamic_shapes 参数,我们指定每个输入的第一个维度为动态。查看输入 arg5_1arg6_1,它们的符号形状为 (s0, 64) 和 (s0, 128),而不是我们作为示例输入传入的 (32, 64) 和 (32, 128) 形状的张量。 s0 是表示此维度可以是值范围的符号。

  • exported_program.range_constraints 描述图中出现的每个符号的范围。在这种情况下,我们看到 s0 的范围是 [2, inf]。由于此处难以解释的技术原因,因此假定它们不是 0 或 1。这不是错误,也不一定意味着导出的程序不适用于维度 0 或 1。请参阅 0/1 特化问题 以深入讨论此主题。

我们还可以指定输入形状之间更具表现力的关系,例如一对形状可能相差一个,一个形状可能是另一个形状的两倍,或者一个形状是偶数。一个示例

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1

exported_program = torch.export.export(
    M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
        # code: return x + y[1:]
        slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807);  arg1_1 = None
        add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1);  arg0_1 = slice_1 = None
        return (add,)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
)
Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}

需要注意的事项

  • 通过为第一个输入指定 {0: dimx},我们看到第一个输入的最终形状现在是动态的,为 [s0]。现在通过为第二个输入指定 {0: dimy},我们看到第二个输入的最终形状也是动态的。但是,由于我们表示 dimy = dimx + 1,而不是 arg1_1 的形状包含一个新符号,我们看到它现在使用 arg0_1 中使用的相同符号 s0 来表示。我们可以看到关系 dimy = dimx + 1 通过 s0 + 1 显示出来。

  • 查看范围约束,我们看到 s0 具有最初指定的范围 [3, 6],我们可以看到 s0 + 1 的已解决范围为 [4, 7]。

序列化

为了保存 ExportedProgram,用户可以使用 torch.export.save()torch.export.load() API。惯例是使用 .pt2 文件扩展名保存 ExportedProgram

一个示例

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), torch.randn(5))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

专业化

理解 torch.export 行为的一个关键概念是静态值和动态值之间的差异。

动态值是可以随运行而改变的值。它们的行为就像 Python 函数的普通参数一样,您可以为参数传递不同的值,并期望您的函数执行正确操作。张量数据被视为动态。

静态值是在导出时固定的值,并且不能在导出程序的执行之间改变。当在跟踪过程中遇到该值时,导出器会将其视为常量并将其硬编码到图中。

当执行操作(例如 x + y)并且所有输入均为静态时,操作的输出将直接硬编码到图中,并且操作将不会显示(即它将被常量折叠)。

当某个值已硬编码到图中时,我们说该图已专门针对该值。

以下值是静态值

输入张量形状

默认情况下,torch.export 将跟踪针对输入张量形状进行专门化的程序,除非通过 torch.exportdynamic_shapes 参数将维度指定为动态。这意味着如果存在与形状相关的控制流,torch.export 将针对使用给定示例输入所采用的分支进行专门化。例如

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x - 1

example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 2]):
            add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            return (add,)

条件 (x.shape[0] > 5) 未出现在 ExportedProgram 中,因为示例输入具有 (10, 2) 的静态形状。由于 torch.export 针对输入的静态形状进行专门化,因此永远不会达到 else 分支 (x - 1)。要根据跟踪图中张量的形状保留动态分支行为,需要使用 torch.export.dynamic_dim() 指定输入张量 (x.shape[0]) 的维度为动态,并且需要重写源代码。

请注意,作为模块状态一部分的张量(例如参数和缓冲区)始终具有静态形状。

Python 原语

torch.export 还会针对 Python 原语进行专门化,例如 intfloatboolstr。但是它们确实具有动态变体,例如 SymIntSymFloatSymBool

例如

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, const: int, times: int):
        for i in range(times):
            x = x + const
        return x

example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
            add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
            add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
            return (add_2,)

由于整数是专门化的,torch.ops.aten.add.Tensor 操作全部使用硬编码常量 1 计算,而不是 arg1_1。如果用户在运行时为 arg1_1 传递一个不同的值(例如 2),而不是导出时使用的 1,这将导致错误。此外,for 循环中使用的 times 迭代器也通过 3 次重复的 torch.ops.aten.add.Tensor 调用“内联”到图中,并且输入 arg2_1 从未使用过。

Python 容器

Python 容器(ListDictNamedTuple 等)被认为具有静态结构。

torch.export 的限制

图中断

由于 torch.export 是一个从 PyTorch 程序捕获计算图的一次性过程,因此最终可能会遇到程序中无法跟踪的部分,因为几乎不可能支持跟踪所有 PyTorch 和 Python 特性。在 torch.compile 的情况下,不受支持的操作将导致“图中断”,并且不受支持的操作将使用默认 Python 评估运行。相反,torch.export 将要求用户提供附加信息或重写其代码的某些部分以使其可跟踪。由于跟踪基于 TorchDynamo(在 Python 字节码级别进行评估),因此与以前的跟踪框架相比,所需的重写将大大减少。

遇到图中断时,ExportDB 是了解受支持和不受支持的程序类型以及重写程序以使其可跟踪的方法的宝贵资源。

一种解决此图中断问题的选择是使用非严格导出

数据/形状相关的控制流

当形状未专门化时,在数据相关的控制流(if x.shape[0] > 2)中也可能遇到图中断,因为跟踪编译器不可能处理,除非为组合爆炸数量的路径生成代码。在这种情况下,用户需要使用特殊控制流运算符重写其代码。目前,我们支持 torch.cond 来表示类似 if-else 的控制流(更多内容即将推出!)。

操作符的元内核缺失

在追踪时,所有操作符都需要一个 META 实现(或“元内核”)。这用于推理此操作符的输入/输出形状。

要为 C++ 自定义操作符注册元内核,请参阅此文档

用于为用 python 实现的自定义操作注册自定义元内核的官方 API 目前正在开发中。在最终 API 正在完善时,你可以参阅此处的文档。

在不幸的情况下,如果你的模型使用尚未有元内核实现的 ATen 操作符,请提交问题。

API 参考

torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[source]

export()采用任意 Python 可调用对象(nn.Module、函数或方法)以及示例输入,并生成一个追踪图,该图仅表示函数的张量计算,采用提前(AOT)的方式,随后可以使用不同的输入或序列化执行该图。追踪图 (1) 在函数 ATen 操作符集中生成标准化操作符(以及任何用户指定的自定义操作符),(2) 消除了所有 Python 控制流和数据结构(有某些例外),以及 (3) 记录了形状约束集,以表明此标准化和控制流消除对于未来输入是合理的。

合理性保证

在跟踪时,export() 会注意用户程序和底层 PyTorch 算子内核做出的与形状相关的假设。只有在这些假设成立时,输出的 ExportedProgram 才被认为有效。

跟踪会对输入张量的形状(而非值)做出假设。此类假设必须在图捕获时间进行验证,以便 export() 能够成功。具体而言

  • 对输入张量静态形状的假设会自动验证,无需额外工作。

  • 对输入张量动态形状的假设需要使用 Dim() API 显式指定,以构建动态维度,并通过 dynamic_shapes 参数将其与示例输入关联起来。

如果无法验证任何假设,则会引发致命错误。发生这种情况时,错误消息将包括验证假设所需的建议修复。例如,export() 可能建议对动态维度 dim0_x 的定义进行以下修复,比如出现在与输入 x 关联的形状中,该形状之前定义为 Dim("dim0_x")

dim = Dim("dim0_x", max=5)

此示例表示,生成的代码要求输入 x 的维度 0 小于或等于 5 才有效。你可以检查动态维度定义的建议修复,然后将它们逐字复制到你的代码中,而无需更改 export() 调用的 dynamic_shapes 参数。

参数
  • mod (Module) – 我们将跟踪此模块的前向方法。

  • args (元组[任何, ...]) – 示例位置输入。

  • kwargs (可选[字典[字符串, 任何]]) – 可选示例关键字输入。

  • dynamic_shapes (可选[联合[字典[字符串, 任何], 元组[任何], 列表[任何]]]) –

    一个可选参数,其中类型应为:1) 从参数名称 f 到其动态形状规范的字典,2) 指定原始顺序中每个输入的动态形状规范的元组。如果您正在指定关键字参数的动态性,则需要按原始函数签名中定义的顺序传递它们。

    张量参数的动态形状可以指定为 (1) 从动态维度索引到 Dim() 类型的字典,其中不需要在此字典中包含静态维度索引,但当它们存在时,它们应映射到 None;或 (2) Dim() 类型或 None 的元组/列表,其中 Dim() 类型对应于动态维度,而静态维度由 None 表示。是字典或张量元组/列表的参数通过使用包含规范的映射或序列进行递归指定。

  • strict (bool) – 当启用(默认)时,导出函数将通过 TorchDynamo 跟踪程序,这将确保结果图的健全性。否则,导出的程序不会验证烘焙到图中的隐式假设,并可能导致原始模型和导出模型之间的行为差异。当用户需要解决跟踪器中的错误,或只是希望逐步启用模型中的安全性时,这很有用。请注意,这不会影响结果 IR 规范的不同,并且无论此处传递什么值,模型都将以相同的方式序列化。警告:此选项为实验性选项,请自行承担风险使用。

返回

包含已跟踪可调用对象的 ExportedProgram

返回类型

ExportedProgram

可接受的输入/输出类型

可接受的输入类型(对于 argskwargs)和输出包括

  • 基本类型,即 torch.Tensorintfloatboolstr

  • 数据类,但必须先调用 register_dataclass() 进行注册。

  • 包含 dictlisttuplenamedtupleOrderedDict 的(嵌套)数据结构,其中包含以上所有类型。

torch.export.dynamic_shapes.dynamic_dim(t, index, debug_name=None)[source]

警告

(此功能已弃用。请改用 Dim()。)

dynamic_dim() 构造一个 _Constraint 对象,用于描述张量 t 的维度 index 的动态性。 _Constraint 对象应传递给 export()constraints 参数。

参数
  • t (torch.Tensor) – 具有动态维度大小的示例输入张量

  • index (int) – 动态维度的索引

返回

一个描述形状动态性的 _Constraint 对象。它可以传递给 export(),以便 export() 不假定指定张量的静态大小,即将其保持为符号大小的动态大小,而不是根据示例追踪输入的大小进行专门化。

具体来说,dynamic_dim() 可用于表示以下类型的动态性。

  • 维度的尺寸是动态的且无界的

    t0 = torch.rand(2, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size rather than always being static size 2
    constraints = [dynamic_dim(t0, 0)]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 维度的尺寸是动态的,且有下界

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
    # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
    constraints = [
        dynamic_dim(t0, 0) >= 5,
        dynamic_dim(t1, 1) > 2,
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 维度的尺寸是动态的,且有上界

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size with a upper bound of 16 (inclusive)
    # Second dimension of t1 can be dynamic size with a upper bound of 8 (exclusive)
    constraints = [
        dynamic_dim(t0, 0) <= 16,
        dynamic_dim(t1, 1) < 8,
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 维度的尺寸是动态的,且始终等于另一个动态维度的尺寸

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # Sizes of second dimension of t0 and first dimension are always equal
    constraints = [
        dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 混合匹配以上所有类型,只要它们不表示冲突的要求即可

torch.export.save(ep, f, *, extra_files=None, opset_version=None)[source]

警告

处于积极开发中,保存的文件在 PyTorch 的较新版本中可能无法使用。

ExportedProgram 保存到类似文件的对象中。然后可以使用 Python API torch.export.load 加载它。

参数
  • ep (ExportedProgram) – 要保存的导出程序。

  • f (Union[str, os.PathLike, io.BytesIO) – 一个类文件对象(必须实现写入和刷新)或包含文件名的一个字符串。

  • extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,该内容将存储为 f 的一部分。

  • opset_version (Optional[Dict[str, int]]) – 一个 opset 名称到此 opset 版本的映射

示例

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, 'exported_program.pt2')

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]

警告

处于积极开发中,保存的文件在 PyTorch 的较新版本中可能无法使用。

加载之前使用 torch.export.save 保存的 ExportedProgram

参数
  • ep (ExportedProgram) – 要保存的导出程序。

  • f (Union[str, os.PathLike, io.BytesIO) – 一个类文件对象(必须实现写入和刷新)或包含文件名的一个字符串。

  • extra_files (Optional[Dict[str, Any]]) – 此映射中给出的额外文件名将被加载,并且其内容将存储在提供的映射中。

  • expected_opset_version (Optional[Dict[str, int]]) – 一个 opset 名称到预期 opset 版本的映射

返回

一个 ExportedProgram 对象

返回类型

ExportedProgram

示例

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')

# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
print(ep(torch.randn(5)))
torch.export.register_dataclass(cls, *, serialized_type_name=None)[源代码]

将数据类注册为 torch.export.export() 的有效输入/输出类型。

参数
  • cls (Type[Any]) – 要注册的数据类类型

  • serialized_type_name (Optional[str]) – 数据类的序列化名称。这是

  • this (如果要序列化包含) –

  • dataclass 的 pytree TreeSpec,则需要此项。

示例

@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int

class OutputDataClass:
    res: torch.Tensor

torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)

def fn(o: InputDataClass) -> torch.Tensor:
    res = res=o.feature + o.bias
    return OutputDataClass(res=res)

ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
print(ep)
torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[源代码]

Dim() 构造一个类型,类似于带范围的命名符号整数。它可用于描述动态张量维度多个可能的值。请注意,同一张量或不同张量的不同动态维度可以用同一种类型描述。

参数
  • name (str) – 用于调试的人类可读名称。

  • min (Optional[int]) – 给定符号的最小可能值(包含)

  • max (Optional[int]) – 给定符号的最大可能值(包含)

返回

一种可用于张量动态形状规范的类型。

torch.export.dims(*names, min=None, max=None)[source]

用于创建多个 Dim() 类型的实用工具。

torch.export.Constraint

Union[_Constraint, _DerivedConstraint] 的别名

torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=, verifier=, tensor_constants=, constants=)[源代码]

来自 export() 的程序包。它包含一个 torch.fx.Graph,表示张量计算、包含所有提升参数和缓冲区的张量值的状态字典以及各种元数据。

您可以使用与 export() 追踪的原始可调用对象相同的调用约定,调用 ExportedProgram。

若要对图执行转换,请使用 .module 属性访问 torch.fx.GraphModule。然后,您可以使用 FX 转换 重写图。之后,您可以简单地再次使用 export() 构建正确的 ExportedProgram。

module()[源代码]

返回一个自包含的 GraphModule,其中包含所有内联的参数/缓冲区。

返回类型

模块

buffers()[源代码]

返回原始模块缓冲区的迭代器。

警告

此 API 为实验性 API,向后兼容。

返回类型

迭代器[张量]

named_buffers()[源代码]

返回原始模块缓冲区的迭代器,既生成缓冲区的名称,也生成缓冲区本身。

警告

此 API 为实验性 API,向后兼容。

返回类型

迭代器[元组[字符串, 张量]]

parameters()[源代码]

返回原始模块参数的迭代器。

警告

此 API 为实验性 API,向后兼容。

返回类型

迭代器[参数]

named_parameters()[源代码]

返回原始模块参数的迭代器,既生成参数的名称,也生成参数本身。

警告

此 API 为实验性 API,向后兼容。

返回类型

迭代器[元组[字符串, 参数]]

run_decompositions(decomp_table=)[源代码]

对导出的程序运行一组分解,并返回一个新的导出程序。默认情况下,我们将运行 Core ATen 分解,以获取 Core ATen 算子集 中的算子。

目前,我们不分解联合图。

返回类型

ExportedProgram

torch.export.ExportBackwardSignature(gradients_to_parameters: 字典[str, str], gradients_to_user_inputs: 字典[str, str], loss_output: str)[源代码]
torch.export.ExportGraphSignature(input_specs, output_specs)[源代码]

ExportGraphSignature 对导出图的输入/输出签名进行建模,导出图是具有更强不变性保证的 fx.Graph。

导出图是实用的,并且不会通过 getattr 节点访问图中的“状态”,例如参数或缓冲区。相反,export() 保证参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何更改也不包含在图中,而是将更改后的缓冲区的值建模为导出图的附加输出。

所有输入和输出的排序为

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果导出以下模块

class CustomModule(nn.Module):
    def __init__(self):
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

生成的图将为

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

生成的 ExportGraphSignature 将为

ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)
class torch.export.ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument]], outputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec)[source]
torch.export.ModuleCallEntry(fqn: str, signature: Union[torch.export.exported_program.ModuleCallSignature, NoneType] = None)[源代码]
torch.export.graph_signature.InputKind(value)[源代码]

枚举。

torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument], target: Union[str, NoneType], persistent: Union[bool, NoneType] = None)[源代码]
torch.export.graph_signature.OutputKind(value)[源代码]

枚举。

torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument], target: Union[str, NoneType])[源代码]
torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[源代码]

ExportGraphSignature 对导出图的输入/输出签名进行建模,导出图是带有更强不变性保证的 fx.Graph。

导出图是函数性的,不会通过 getattr 节点访问图中的“状态”,例如参数或缓冲区。相反,export() 保证参数、缓冲区和常量张量作为输入从图中提取出来。同样,对缓冲区的任何更改也不会包含在图中,而是将已更改的缓冲区的更新值建模为导出图的附加输出。

所有输入和输出的排序为

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果导出以下模块

class CustomModule(nn.Module):
    def __init__(self):
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

生成的图将为

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

生成的 ExportGraphSignature 将为

ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)
replace_all_uses(old, new)[源代码]

在签名中用新名称替换所有旧名称的使用。

get_replace_hook()[source]
class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str)[source]
class torch.export.unflatten.FlatArgsAdapter[source]

使用 input_spec 调整输入参数以对齐 target_spec

abstract adapt(target_spec, input_spec, input_args)[source]

注意:此适配器可能会改变给定的 input_args_with_path

返回类型

List[Any]

class torch.export.unflatten.InterpreterModule(graph)[source]

一个使用 torch.fx.Interpreter 执行的模块,而不是 GraphModule 使用的常规代码生成。这提供了更好的堆栈跟踪信息,并使调试执行变得更加容易。

torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]

展开 ExportedProgram,生成与原始 eager 模块具有相同模块层次结构的模块。如果你尝试将 torch.export 与另一个系统一起使用,该系统期望模块层次结构而不是 torch.export 通常生成的平面图,这可能很有用。

注意

展开模块的 args/kwargs 不一定与 eager 模块匹配,因此进行模块交换(例如 self.submod = new_mod)不一定有效。如果你需要交换模块,则需要设置 torch.export.export()preserve_module_call_signature 参数。

参数
  • module (ExportedProgram) – 要展开的 ExportedProgram。

  • flat_args_adapter (可选[FlatArgsAdapter]) – 如果输入 TreeSpec 与导出的模块不匹配,则调整平面 args。

返回

UnflattenedModule 的一个实例,它与导出前的原始 eager 模块具有相同的模块层次结构。

返回类型

UnflattenedModule

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源