快捷方式

基于 TorchDynamo 的 ONNX 导出器

警告

TorchDynamo 的 ONNX 导出器是一项快速发展的测试版技术。

概述

ONNX 导出器利用 TorchDynamo 引擎挂钩到 Python 的帧评估 API,并动态地将字节码重写为 FX 图。生成的 FX 图在最终转换为 ONNX 图之前会进行优化。

这种方法的主要优势在于 FX 图 是使用字节码分析捕获的,它保留了模型的动态特性,而不是使用传统的静态跟踪技术。

导出器旨在模块化和可扩展。它由以下组件组成

  • ONNX 导出器: Exporter 是协调导出过程的主要类。

  • ONNX 导出选项: ExportOptions 包含一组控制导出过程的选项。

  • ONNX 注册表: OnnxRegistry 是 ONNX 算子和函数的注册表。

  • FX 图提取器: FXGraphExtractor 从 PyTorch 模型中提取 FX 图。

  • 模拟模式: ONNXFakeContext 是一个上下文管理器,它为大型模型启用模拟模式。

  • ONNX 程序: ONNXProgram 是导出器的输出,包含导出的 ONNX 图和诊断信息。

  • ONNX 程序序列化器: ONNXProgramSerializer 将导出的模型序列化到文件。

  • ONNX 诊断选项: DiagnosticOptions 包含一组控制导出器发出的诊断信息的选项。

依赖项

ONNX 导出器依赖于额外的 Python 包

它们可以通过 pip 安装

pip install --upgrade onnx onnxscript

一个简单的例子

以下展示了导出器 API 的实际应用,以一个简单的多层感知器 (MLP) 为例

import torch
import torch.nn as nn

class MLPModel(nn.Module):
  def __init__(self):
      super().__init__()
      self.fc0 = nn.Linear(8, 8, bias=True)
      self.fc1 = nn.Linear(8, 4, bias=True)
      self.fc2 = nn.Linear(4, 2, bias=True)
      self.fc3 = nn.Linear(2, 2, bias=True)

  def forward(self, tensor_x: torch.Tensor):
      tensor_x = self.fc0(tensor_x)
      tensor_x = torch.sigmoid(tensor_x)
      tensor_x = self.fc1(tensor_x)
      tensor_x = torch.sigmoid(tensor_x)
      tensor_x = self.fc2(tensor_x)
      tensor_x = torch.sigmoid(tensor_x)
      output = self.fc3(tensor_x)
      return output

model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
onnx_program = torch.onnx.dynamo_export(model, tensor_x)

如上代码所示,您只需提供 torch.onnx.dynamo_export() 模型实例及其输入。导出器将返回一个 torch.onnx.ONNXProgram 实例,其中包含导出的 ONNX 图以及其他信息。

通过 onnx_program.model_proto 可用的内存模型是符合 ONNX IR 规范onnx.ModelProto 对象。然后可以使用 torch.onnx.ONNXProgram.save() API 将 ONNX 模型序列化为 Protobuf 文件

onnx_program.save("mlp.onnx")

使用 GUI 检查 ONNX 模型

您可以使用 Netron 查看导出的模型。

MLP model as viewed using Netron

请注意,每个层都用一个矩形框表示,右上角有一个 *f* 图标。

ONNX function highlighted on MLP model

展开它,将显示函数体。

ONNX function body

函数体是 ONNX 运算符或其他函数的序列。

使用 SARIF 诊断问题

ONNX 诊断通过采用 静态分析结果交换格式(又名 SARIF) 超越了常规日志,以帮助用户使用 GUI(如 Visual Studio Code 的 SARIF 查看器)调试和改进其模型。

主要优点是

  • 诊断以机器可解析的 静态分析结果交换格式 (SARIF) 发出。

  • 一种更清晰、结构化的方式来添加新的诊断规则并跟踪它们。

  • 作为未来更多使用诊断的改进的基础。

API 参考

torch.onnx.dynamo_export(model, /, *model_args, export_options=None, **model_kwargs)

将 torch.nn.Module 导出到 ONNX 图。

参数
返回值

导出的 ONNX 模型的内存表示。

返回类型

ONNXProgram

示例 1 - 最简单的导出

class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
    def forward(self, x, bias=None):
        out = self.linear(x)
        out = out + bias
        return out
model = MyModel()
kwargs = {"bias": 3.}
args = (torch.randn(2, 2, 2),)
onnx_program = torch.onnx.dynamo_export(
    model,
    *args,
    **kwargs).save("my_simple_model.onnx")

示例 2 - 使用动态形状导出

# The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
    model,
    *args,
    **kwargs,
    export_options=export_options)
onnx_program.save("my_dynamic_model.onnx")

通过打印输入动态维度,我们可以看到输入形状不再是 (2,2,2)

>>> print(onnx_program.model_proto.graph.input[0])
name: "arg0"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "arg0_dim_0"
      }
      dim {
        dim_param: "arg0_dim_1"
      }
      dim {
        dim_param: "arg0_dim_2"
      }
    }
  }
}
class torch.onnx.ExportOptions(*, dynamic_shapes=None, op_level_debug=None, fake_context=None, onnx_registry=None, diagnostic_options=None)

影响 TorchDynamo ONNX 导出器的选项。

变量
  • dynamic_shapes (Optional[bool]) – 输入/输出张量的形状信息提示。当为 None 时,导出器会确定最兼容的设置。当为 True 时,所有输入形状都被视为动态。当为 False 时,所有输入形状都被视为静态。

  • op_level_debug (Optional[bool]) – 是否以操作级别调试信息导出模型

  • diagnostic_options (DiagnosticOptions) – 导出器的诊断选项。

  • fake_context (Optional[ONNXFakeContext]) – 用于符号跟踪的假上下文。

  • onnx_registry (Optional[OnnxRegistry]) – 用于将 ATen 运算符注册到 ONNX 函数的 ONNX 注册表。

torch.onnx.enable_fake_mode()

在上下文持续时间内启用假模式。

在内部,它实例化一个 torch._subclasses.fake_tensor.FakeTensorMode 上下文管理器,该管理器将用户输入和模型参数转换为 torch._subclasses.fake_tensor.FakeTensor

一个 torch._subclasses.fake_tensor.FakeTensor 是一个 torch.Tensor,它能够运行 PyTorch 代码,而无需实际通过在 meta 设备上分配的张量进行计算。由于设备上没有实际分配数据,因此此 API 允许导出大型模型,而无需执行模型所需的实际内存占用。

强烈建议在导出内存不足以容纳的模型时启用假模式。

返回值

一个 ONNXFakeContext 对象,必须通过 ExportOptions.fake_context 参数传递给 dynamo_export()

示例

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> class MyModel(torch.nn.Module):  # Dummy model
...     def __init__(self) -> None:
...         super().__init__()
...         self.linear = torch.nn.Linear(2, 2)
...     def forward(self, x):
...         out = self.linear(x)
...         return out
>>> with torch.onnx.enable_fake_mode() as fake_context:
...     my_nn_module = MyModel()
...     arg1 = torch.randn(2, 2, 2)  # positional input 1
>>> export_options = torch.onnx.ExportOptions(fake_context=fake_context)
>>> onnx_program = torch.onnx.dynamo_export(
...     my_nn_module,
...     arg1,
...     export_options=export_options
... )
>>> # Saving model WITHOUT initializers
>>> onnx_program.save("my_model_without_initializers.onnx")
>>> # Saving model WITH initializers
>>> onnx_program.save("my_model_with_initializers.onnx", model_state=MyModel().state_dict())

警告

此 API 处于实验阶段,并且向后兼容。

class torch.onnx.ONNXProgram(model_proto, input_adapter, output_adapter, diagnostic_context, *, fake_context=None, export_exception=None, model_signature=None, model_torch=None)

已导出到 ONNX 的 PyTorch 模型的内存中表示。

参数
  • model_proto (onnx.ModelProto) – 导出的 ONNX 模型,作为 onnx.ModelProto

  • input_adapter (io_adapter.InputAdapter) – 用于将 PyTorch 输入转换为 ONNX 输入的输入适配器。

  • output_adapter (io_adapter.OutputAdapter) – 用于将 PyTorch 输出转换为 ONNX 输出的输出适配器。

  • diagnostic_context (diagnostics.DiagnosticContext) – 用于 SARIF 诊断系统的上下文对象,负责记录错误和元数据。

  • fake_context (Optional[ONNXFakeContext]) – 用于符号跟踪的假上下文。

  • export_exception (Optional[Exception]) – 导出过程中发生的异常(如果有)。

  • model_signature (Optional[torch.export.ExportGraphSignature]) – 导出 ONNX 图的模型签名。

adapt_torch_inputs_to_onnx(*model_args, model_with_state_dict=None, **model_kwargs)[source]

将 PyTorch 模型输入转换为导出的 ONNX 模型输入格式。

由于设计差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常不相同。例如,PyTorch 模型允许 None,但 ONNX 不支持。PyTorch 模型允许嵌套的张量结构,但 ONNX 仅支持扁平化的张量,等等。

实际的适配步骤与每个单独的导出相关联。它取决于 PyTorch 模型、用于导出的特定模型参数和模型关键字参数集以及导出选项。

此方法重放导出期间记录的适配步骤。

参数
  • model_args – PyTorch 模型输入。

  • model_with_state_dict (可选[联合[模块, 可调用, 导出程序]]) – 要获取额外状态的 PyTorch 模型。如果未指定,则使用导出期间使用的模型。当 enable_fake_mode() 用于提取 ONNX 图表所需的真实初始值时,需要此参数。

  • model_kwargs – PyTorch 模型关键字输入。

返回值

从 PyTorch 模型输入转换而来的张量序列。

返回类型

序列[联合[张量, int, float, bool]]

示例

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> from typing import Dict, Tuple
>>> def func_nested_input(
...     x_dict: Dict[str, torch.Tensor],
...     y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
... ):
...     if "a" in x_dict:
...         x = x_dict["a"]
...     elif "b" in x_dict:
...         x = x_dict["b"]
...     else:
...         x = torch.randn(3)
...
...     y1, (y2, y3) = y_tuple
...
...     return x + y1 + y2 + y3
>>> x_dict = {"a": torch.tensor(1.)}
>>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.)))
>>> onnx_program = torch.onnx.dynamo_export(func_nested_input, x_dict, y_tuple)
>>> print(x_dict, y_tuple)
{'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.)))
>>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model_with_state_dict=func_nested_input))
(tensor(1.), tensor(2.), tensor(3.), tensor(4.))

警告

此 API 处于实验阶段,并且向后兼容。

adapt_torch_outputs_to_onnx(model_outputs, model_with_state_dict=None)[源代码]

将 PyTorch 模型输出转换为导出的 ONNX 模型输出格式。

由于设计差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常不相同。例如,PyTorch 模型允许 None,但 ONNX 不支持。PyTorch 模型允许嵌套的张量结构,但 ONNX 仅支持扁平化的张量,等等。

实际的适配步骤与每个单独的导出相关联。它取决于 PyTorch 模型、用于导出的特定模型参数和模型关键字参数集以及导出选项。

此方法重放导出期间记录的适配步骤。

参数
  • model_outputs (任何) – PyTorch 模型输出。

  • model_with_state_dict (可选[联合[模块, 可调用, 导出程序]]) – 要获取额外状态的 PyTorch 模型。如果未指定,则使用导出期间使用的模型。当 enable_fake_mode() 用于提取 ONNX 图表所需的真实初始值时,需要此参数。

返回值

导出 ONNX 模型输出格式的 PyTorch 模型输出。

返回类型

序列[联合[张量, int, float, bool]]

示例

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> def func_returning_tuples(x, y, z):
...     x = x + y
...     y = y + z
...     z = x + y
...     return (x, (y, z))
>>> x = torch.tensor(1.)
>>> y = torch.tensor(2.)
>>> z = torch.tensor(3.)
>>> onnx_program = torch.onnx.dynamo_export(func_returning_tuples, x, y, z)
>>> pt_output = func_returning_tuples(x, y, z)
>>> print(pt_output)
(tensor(3.), (tensor(5.), tensor(8.)))
>>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output, model_with_state_dict=func_returning_tuples))
[tensor(3.), tensor(5.), tensor(8.)]

警告

此 API 处于实验阶段,并且向后兼容。

property diagnostic_context: diagnostics.DiagnosticContext

与导出相关的诊断上下文。

property fake_context: Optional[ONNXFakeContext]

与导出相关的假上下文。

property model_proto: onnx.ModelProto

导出的 ONNX 模型,以 onnx.ModelProto 格式。

property model_signature: Optional[ExportGraphSignature]

导出 ONNX 图的模型签名。

此信息很重要,因为 ONNX 规范通常与 PyTorch 的规范不同,导致 ONNX 图的输入和输出模式与实际 PyTorch 模型实现不同。通过使用模型签名,用户可以了解输入和输出的差异,并正确地在 ONNX Runtime 中执行模型。

注意:模型签名仅在从 torch.export.ExportedProgram 对象导出 ONNX 图时可用。

注意:对模型进行的任何更改模型签名的转换都必须通过 InputAdaptStep 和/或 OutputAdaptStep 对此模型签名进行更新。

示例

以下模型产生不同的输入和输出集。前 4 个输入是模型参数(即 conv1.weight、conv2.weight、fc1.weight、fc2.weight),接下来的 2 个输入是注册的缓冲区(即 my_buffer2、my_buffer1),最后 2 个输入是用户输入(即 x 和 b)。第一个输出是缓冲区变异(即 my_buffer2),最后一个输出是实际的模型输出。

>>> class CustomModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
...         self.register_buffer("my_buffer1", torch.tensor(3.0))
...         self.register_buffer("my_buffer2", torch.tensor(4.0))
...         self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, bias=False)
...         self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False)
...         self.fc1 = torch.nn.Linear(9216, 128, bias=False)
...         self.fc2 = torch.nn.Linear(128, 10, bias=False)
...     def forward(self, x, b):
...         tensor_x = self.conv1(x)
...         tensor_x = torch.nn.functional.sigmoid(tensor_x)
...         tensor_x = self.conv2(tensor_x)
...         tensor_x = torch.nn.functional.sigmoid(tensor_x)
...         tensor_x = torch.nn.functional.max_pool2d(tensor_x, 2)
...         tensor_x = torch.flatten(tensor_x, 1)
...         tensor_x = self.fc1(tensor_x)
...         tensor_x = torch.nn.functional.sigmoid(tensor_x)
...         tensor_x = self.fc2(tensor_x)
...         output = torch.nn.functional.log_softmax(tensor_x, dim=1)
...         (
...         self.my_buffer2.add_(1.0) + self.my_buffer1
...         )  # Mutate buffer through in-place addition
...         return output
>>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3))
>>> exported_program = torch.export.export(CustomModule(), args=inputs)
>>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs)
>>> print(onnx_program.model_signature)
ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'),
                  target='conv1.weight', persistent=None),
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg1_1'),
                  target='conv2.weight', persistent=None),
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg2_1'),
                  target='fc1.weight', persistent=None),
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg3_1'),
                  target='fc2.weight', persistent=None),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg4_1'),
                  target='my_buffer2', persistent=True),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg5_1'),
                  target='my_buffer1', persistent=True),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg6_1'),
                  target=None, persistent=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg7_1'),
                  target=None, persistent=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='_log_softmax'), target=None)
    ]
)
save(destination, *, model_state=None, serializer=None)[source]

使用指定的 serializer 将内存中的 ONNX 模型保存到 destination

参数
  • destination (Union[str, BufferedIOBase]) – 保存 ONNX 模型的目标位置。它可以是字符串或文件类对象。当与 model_state 一起使用时,它必须是目标位置的完整路径字符串。如果 destination 是字符串,除了将 ONNX 模型保存到文件中之外,模型权重也会存储在与 ONNX 模型相同目录中的单独文件中。例如,对于 destination=”/path/model.onnx”,初始化器将保存在“/path/”文件夹中,以及“onnx.model”。

  • model_state (Optional[Union[Dict[str, Any], str]]) – 包含所有权重的 PyTorch 模型的 state_dict。它可以是包含检查点路径的字符串,也可以是包含实际模型状态的字典。支持的文件格式与 torch.loadsafetensors.safe_open 支持的格式相同。当使用 enable_fake_mode() 但需要 ONNX 图上的真实初始化器时,这是必需的。

  • serializer (Optional[ONNXProgramSerializer]) – 要使用的序列化器。如果未指定,则模型将被序列化为 Protobuf。

save_diagnostics(destination)[source]

将导出诊断信息保存为 SARIF 日志到指定的目标路径。

参数

destination (str) – 保存诊断 SARIF 日志的目标路径。它必须具有 .sarif 扩展名。

引发

ValueError – 如果目标路径不以 .sarif 扩展名结尾。

class torch.onnx.ONNXProgramSerializer(*args, **kwargs)

用于将 ONNX 图序列化为特定格式(例如 Protobuf)的协议。请注意,这是一种高级用法场景。

serialize(onnx_program, destination)[source]

必须实现的序列化协议方法。

参数
  • onnx_program (ONNXProgram) – 表示内存中导出的 ONNX 模型

  • destination (BufferedIOBase) – 二进制 IO 流或预分配的缓冲区,序列化后的模型应写入其中。

示例

一个简单的序列化器,它将导出的 onnx.ModelProto 以 Protobuf 格式写入 destination

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import io
>>> import torch
>>> import torch.onnx
>>> class MyModel(torch.nn.Module):  # Dummy model
...     def __init__(self) -> None:
...         super().__init__()
...         self.linear = torch.nn.Linear(2, 2)
...     def forward(self, x):
...         out = self.linear(x)
...         return out
>>> class ProtobufONNXProgramSerializer:
...     def serialize(
...         self, onnx_program: torch.onnx.ONNXProgram, destination: io.BufferedIOBase
...     ) -> None:
...         destination.write(onnx_program.model_proto.SerializeToString())
>>> model = MyModel()
>>> arg1 = torch.randn(2, 2, 2)  # positional input 1
>>> torch.onnx.dynamo_export(model, arg1).save(
...     destination="exported_model.onnx",
...     serializer=ProtobufONNXProgramSerializer(),
... )
class torch.onnx.ONNXRuntimeOptions(*, session_options=None, execution_providers=None, execution_provider_options=None)

通过 ONNX Runtime 影响 ONNX 模型执行的选项。

变量
  • session_options (可选[序列['onnxruntime.SessionOptions']]) – ONNX Runtime 会话选项。

  • execution_providers (可选[序列[联合[str, 元组[str, 字典[任何, 任何]]]]]) – 在模型执行期间使用的 ONNX Runtime 执行提供程序。

  • execution_provider_options (可选[序列[字典[任何, 任何]]]) – ONNX Runtime 执行提供程序选项。

torch.onnx.InvalidExportOptionsError

当用户为 ExportOptions 指定无效值时引发。

torch.onnx.OnnxExporterError(onnx_program, message)

当 ONNX 导出器错误发生时引发。

当 ONNX 导出过程中出现错误时,会抛出此异常。它封装了在失败之前生成的 ONNXProgram 对象,允许访问部分导出结果和相关元数据。

torch.onnx.OnnxRegistry

ONNX 函数的注册表。

注册表维护着从限定名称到固定操作集版本下的符号函数的映射。它支持注册自定义 onnx-script 函数,并支持调度程序将调用调度到相应的函数。

get_op_functions(namespace, op_name, overload=None)[source]

返回给定操作的 ONNXFunctions 列表:torch.ops.<namespace>.<op_name>.<overload>。

该列表按注册时间排序。自定义操作符应位于列表的后半部分。

参数
  • namespace (str) – 要获取的操作符的命名空间。

  • op_name (str) – 要获取的操作符的名称。

  • overload (Optional[str]) – 要获取的操作符的重载。如果为默认重载,则保留为 None。

返回值

与给定名称相对应的 ONNXFunctions 列表,如果注册表中没有该名称,则为 None。

返回类型

Optional[List[ONNXFunction]]

is_registered_op(namespace, op_name, overload=None)[source]

返回给定操作符是否已注册:torch.ops.<namespace>.<op_name>.<overload>。

参数
  • namespace (str) – 要检查的操作符的命名空间。

  • op_name (str) – 要检查的操作符的名称。

  • overload (Optional[str]) – 要检查的操作符的重载。如果为默认重载,则保留为 None。

返回值

如果给定操作符已注册,则为 True,否则为 False。

返回类型

bool

property opset_version: int

导出器应针对的 ONNX 操作符集版本。默认为最新支持的 ONNX 操作符集版本:18。随着 ONNX 的不断发展,默认版本将随着时间的推移而增加。

register_op(function, namespace, op_name, overload=None, is_complex=False)[source]

注册自定义算子:torch.ops.<namespace>.<op_name>.<overload>。

参数
  • function (Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]) – 要注册的 onnx-sctip 函数。

  • namespace (str) – 要注册的算子的命名空间。

  • op_name (str) – 要注册的算子的名称。

  • overload (Optional[str]) – 要注册的算子的重载。如果是默认重载,则将其保留为 None。

  • is_complex (bool) – 该函数是否为处理复数值输入的函数。

引发

ValueError – 如果名称不符合 ‘namespace::op’ 的形式。

class torch.onnx.DiagnosticOptions(verbosity_level=20, warnings_as_errors=False)

诊断上下文的选项。

变量
  • verbosity_level (int) – 设置每个诊断记录的信息量,相当于 Python 日志模块中的 ‘level’。

  • warnings_as_errors (bool) – 当为 True 时,警告诊断将被视为错误诊断。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源