快捷方式

基于 TorchDynamo 的 ONNX 导出器

警告

TorchDynamo 的 ONNX 导出器是一项快速发展的测试版技术。

概述

ONNX 导出器利用 TorchDynamo 引擎挂接到 Python 的帧评估 API,并将它的字节码动态地重写为 FX 图。生成的 FX 图将在最终转换为 ONNX 图之前进行完善。

这种方法的主要优点是,FX 图 是使用字节码分析捕获的,该分析保留了模型的动态特性,而不是使用传统的静态跟踪技术。

导出器旨在模块化且可扩展。它由以下组件组成:

  • ONNX 导出器Exporter 主要类,负责协调导出过程。

  • ONNX 导出选项ExportOptions 有一组选项,用于控制导出过程。

  • ONNX 注册表OnnxRegistry 是 ONNX 算子和函数的注册表。

  • FX 图提取器FXGraphExtractor 从 PyTorch 模型中提取 FX 图。

  • 假模式ONNXFakeContext 是一个上下文管理器,用于为大型模型启用假模式。

  • ONNX 程序ONNXProgram 是导出器的输出,它包含导出的 ONNX 图和诊断信息。

  • ONNX 诊断选项DiagnosticOptions 有一组选项,用于控制导出器发出的诊断信息。

依赖关系

ONNX 导出器依赖于额外的 Python 包

可以通过 pip 安装它们。

pip install --upgrade onnx onnxscript

onnxruntime 然后可以用来在各种处理器上执行模型。

一个简单的示例

以下展示了导出器 API 在一个简单的多层感知器 (MLP) 示例中的应用。

import torch
import torch.nn as nn

class MLPModel(nn.Module):
  def __init__(self):
      super().__init__()
      self.fc0 = nn.Linear(8, 8, bias=True)
      self.fc1 = nn.Linear(8, 4, bias=True)
      self.fc2 = nn.Linear(4, 2, bias=True)
      self.fc3 = nn.Linear(2, 2, bias=True)

  def forward(self, tensor_x: torch.Tensor):
      tensor_x = self.fc0(tensor_x)
      tensor_x = torch.sigmoid(tensor_x)
      tensor_x = self.fc1(tensor_x)
      tensor_x = torch.sigmoid(tensor_x)
      tensor_x = self.fc2(tensor_x)
      tensor_x = torch.sigmoid(tensor_x)
      output = self.fc3(tensor_x)
      return output

model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
onnx_program = torch.onnx.export(model, (tensor_x,), dynamo=True)

如上所示,您只需要为 torch.onnx.export() 提供模型实例及其输入即可。导出器将返回一个 torch.onnx.ONNXProgram 实例,其中包含导出的 ONNX 图以及其他信息。

通过 onnx_program.model_proto 可用的内存中模型是一个符合 ONNX IR 规范onnx.ModelProto 对象。ONNX 模型可以使用 torch.onnx.ONNXProgram.save() API 序列化为一个 Protobuf 文件

onnx_program.save("mlp.onnx")

有两个函数可以基于 TorchDynamo 引擎将模型导出到 ONNX。它们在生成 ExportedProgram 的方式上略有不同。torch.onnx.dynamo_export() 是在 PyTorch 2.1 中引入的,而 torch.onnx.export() 在 PyTorch 2.5 中进行了扩展,以便轻松地在 TorchScript 和 TorchDynamo 之间切换。要调用前一个函数,可以将上一个示例的最后一行替换为以下内容。

onnx_program = torch.onnx.dynamo_export(model, tensor_x)

使用 GUI 检查 ONNX 模型

您可以使用 Netron 查看导出的模型。

MLP model as viewed using Netron

请注意,每个层都用一个矩形框表示,右上角有一个f图标。

ONNX function highlighted on MLP model

展开后,将显示函数体。

ONNX function body

函数体是 ONNX 算子或其他函数的序列。

转换失败时

函数 torch.onnx.export() 应该再次调用,参数为 report=True。会生成一个 Markdown 报告来帮助用户解决问题。

函数 torch.onnx.dynamo_export() 使用 ‘SARIF’ 格式生成报告。ONNX 诊断通过采用 静态分析结果交换格式 (SARIF) 超越了常规日志,帮助用户使用 GUI(如 Visual Studio Code 的 SARIF 查看器)调试和改进模型。

主要优势是

  • 诊断以机器可解析的 静态分析结果交换格式 (SARIF) 发出。

  • 一种新的更清晰、结构化的方式来添加新的诊断规则并跟踪它们。

  • 作为未来更多消耗诊断改进的基础。

API 参考

torch.onnx.dynamo_export(model, /, *model_args, export_options=None, **model_kwargs)[source]

将 torch.nn.Module 导出到 ONNX 图。

参数
返回

导出的 ONNX 模型的内存表示。

返回类型

ONNXProgram | Any

示例 1 - 最简单的导出

class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x, bias=None):
        out = self.linear(x)
        out = out + bias
        return out


model = MyModel()
kwargs = {"bias": 3.0}
args = (torch.randn(2, 2, 2),)
onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
    "my_simple_model.onnx"
)

示例 2 - 使用动态形状导出

# The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
    model, *args, **kwargs, export_options=export_options
)
onnx_program.save("my_dynamic_model.onnx")
class torch.onnx.ExportOptions(*, dynamic_shapes=None, fake_context=None, onnx_registry=None, diagnostic_options=None)

影响 TorchDynamo ONNX 导出器的选项。

变量
  • dynamic_shapes (bool | None) – 输入/输出张量的形状信息提示。当为 None 时,导出器确定最兼容的设置。当为 True 时,所有输入形状都被认为是动态的。当为 False 时,所有输入形状都被认为是静态的。

  • diagnostic_options (DiagnosticOptions) – 导出器的诊断选项。

  • fake_context (ONNXFakeContext | None) – 用于符号跟踪的伪上下文。

  • onnx_registry (OnnxRegistry | None) – 用于将 ATen 算子注册到 ONNX 函数的 ONNX 注册表。

torch.onnx.enable_fake_mode()

在上下文持续时间内启用伪模式。

在内部,它实例化一个 torch._subclasses.fake_tensor.FakeTensorMode 上下文管理器,该管理器将用户输入和模型参数转换为 torch._subclasses.fake_tensor.FakeTensor

一个 torch._subclasses.fake_tensor.FakeTensor 是一个 torch.Tensor,它能够在不通过分配在 meta 设备上的张量进行实际计算的情况下运行 PyTorch 代码。因为没有实际数据分配在设备上,所以此 API 允许导出大型模型,而无需实际执行它所需的内存占用。

在导出超出内存容量的模型时,强烈建议启用伪模式。

返回

一个 ONNXFakeContext 对象。

示例

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> class MyModel(torch.nn.Module):  # Dummy model
...     def __init__(self) -> None:
...         super().__init__()
...         self.linear = torch.nn.Linear(2, 2)
...     def forward(self, x):
...         out = self.linear(x)
...         return out
>>> with torch.onnx.enable_fake_mode() as fake_context:
...     my_nn_module = MyModel()
...     arg1 = torch.randn(2, 2, 2)  # positional input 1
>>> export_options = torch.onnx.ExportOptions(fake_context=fake_context)
>>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True)
>>> onnx_program.apply_weights(MyModel().state_dict())
>>> # Saving model WITHOUT initializers
>>> onnx_program.save(
...     "my_model_without_initializers.onnx",
...     include_initializers=False,
...     keep_initializers_as_inputs=True,
... )
>>> # Saving model WITH initializers
>>> onnx_program.save("my_model_with_initializers.onnx")

警告

此 API 处于实验阶段,不向后兼容。

class torch.onnx.ONNXProgram(model_proto, input_adapter, output_adapter, diagnostic_context, *, fake_context=None, export_exception=None, model_torch=None)

已导出到 ONNX 的 PyTorch 模型的内存表示。

参数
  • model_proto (onnx.ModelProto) – 导出的 ONNX 模型,作为 onnx.ModelProto

  • input_adapter (io_adapter.InputAdapter) – 用于将 PyTorch 输入转换为 ONNX 输入的输入适配器。

  • output_adapter (io_adapter.OutputAdapter) – 用于将 PyTorch 输出转换为 ONNX 输出的输出适配器。

  • diagnostic_context (diagnostics.DiagnosticContext) – SARIF 诊断系统负责记录错误和元数据的上下文对象。

  • fake_context (ONNXFakeContext | None) – 用于符号跟踪的伪上下文。

  • export_exception (Exception | None) – 导出过程中发生的异常(如果有)。

adapt_torch_inputs_to_onnx(*model_args, model_with_state_dict=None, **model_kwargs)[source]

将 PyTorch 模型输入转换为导出的 ONNX 模型输入格式。

由于设计上的差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常并不相同。例如,PyTorch 模型允许 None,但 ONNX 不支持。PyTorch 模型允许张量的嵌套结构,但 ONNX 只支持扁平化的张量,等等。

实际的适应步骤与每个单独的导出相关联。它取决于 PyTorch 模型、用于导出的特定 model_args 和 model_kwargs 集以及导出选项。

此方法重播导出期间记录的适应步骤。

参数
  • model_args – PyTorch 模型输入。

  • model_with_state_dict (torch.nn.Module | Callable | None) – 用于获取额外状态的 PyTorch 模型。如果未指定,则使用导出期间使用的模型。当使用 enable_fake_mode() 从 ONNX 图所需的真实初始化器中提取时,需要此参数。

  • model_kwargs – PyTorch 模型关键字输入。

返回

从 PyTorch 模型输入转换而来的张量序列。

返回类型

Sequence[torch.Tensor | int | float | bool | torch.dtype]

示例

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> from typing import Dict, Tuple
>>> def func_nested_input(
...     x_dict: Dict[str, torch.Tensor],
...     y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
... ):
...     if "a" in x_dict:
...         x = x_dict["a"]
...     elif "b" in x_dict:
...         x = x_dict["b"]
...     else:
...         x = torch.randn(3)
...
...     y1, (y2, y3) = y_tuple
...
...     return x + y1 + y2 + y3
>>> x_dict = {"a": torch.tensor(1.)}
>>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.)))
>>> onnx_program = torch.onnx.dynamo_export(func_nested_input, x_dict, y_tuple)
>>> print(x_dict, y_tuple)
{'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.)))
>>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model_with_state_dict=func_nested_input))
(tensor(1.), tensor(2.), tensor(3.), tensor(4.))

警告

此 API 处于实验阶段,不向后兼容。

adapt_torch_outputs_to_onnx(model_outputs, model_with_state_dict=None)[source]

将 PyTorch 模型输出转换为导出的 ONNX 模型输出格式。

由于设计上的差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常并不相同。例如,PyTorch 模型允许 None,但 ONNX 不支持。PyTorch 模型允许张量的嵌套结构,但 ONNX 只支持扁平化的张量,等等。

实际的适应步骤与每个单独的导出相关联。它取决于 PyTorch 模型、用于导出的特定 model_args 和 model_kwargs 集以及导出选项。

此方法重播导出期间记录的适应步骤。

参数
  • model_outputs (Any) – PyTorch 模型输出。

  • model_with_state_dict (torch.nn.Module | Callable | None) – 用于获取额外状态的 PyTorch 模型。如果未指定,则使用导出期间使用的模型。当使用 enable_fake_mode() 从 ONNX 图所需的真实初始化器中提取时,需要此参数。

返回

导出的 ONNX 模型输出格式中的 PyTorch 模型输出。

返回类型

Sequence[torch.Tensor | int | float | bool]

示例

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> def func_returning_tuples(x, y, z):
...     x = x + y
...     y = y + z
...     z = x + y
...     return (x, (y, z))
>>> x = torch.tensor(1.)
>>> y = torch.tensor(2.)
>>> z = torch.tensor(3.)
>>> onnx_program = torch.onnx.dynamo_export(func_returning_tuples, x, y, z)
>>> pt_output = func_returning_tuples(x, y, z)
>>> print(pt_output)
(tensor(3.), (tensor(5.), tensor(8.)))
>>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output, model_with_state_dict=func_returning_tuples))
[tensor(3.), tensor(5.), tensor(8.)]

警告

此 API 处于实验阶段,不向后兼容。

apply_weights(state_dict)[source]

将指定状态字典中的权重应用于 ONNX 模型。 :param state_dict: 包含要应用于 ONNX 模型的权重的状态字典。

property diagnostic_context: diagnostics.DiagnosticContext

与导出相关的诊断上下文。

property fake_context: ONNXFakeContext | None

与导出相关的假上下文。

property model_proto: onnx.ModelProto

作为 onnx.ModelProto 导出的 ONNX 模型。

save(destination, *, include_initializers=True, model_state=None)[source]

使用指定的 serializer 将内存中的 ONNX 模型保存到 destination

参数
  • destination (str | io.BufferedIOBase) – 保存 ONNX 模型的目标位置。它可以是字符串或类文件对象。当与 model_state 一起使用时,它必须是一个指向目标位置的完整路径字符串。如果 destination 是字符串,除了将 ONNX 模型保存到文件外,模型权重还会存储在与 ONNX 模型相同的目录中的单独文件中。例如,对于 destination=”/path/model.onnx”,初始化器将保存在“/path/”文件夹中,以及“onnx.model”。

  • include_initializers (bool) – 是否将初始化器作为外部数据包含在 ONNX 图中。不能与 model_state_dict 结合使用。

  • model_state (dict[str, Any] | str | None) – 包含其所有权重的 PyTorch 模型的状态字典。它可以是指向检查点的路径的字符串,也可以是包含实际模型状态的字典。支持的文件格式与 torch.loadsafetensors.safe_open 支持的格式相同。当使用 enable_fake_mode() 但 ONNX 图需要真实的初始化器时,这是必需的。

save_diagnostics(destination)[source]

将导出诊断作为 SARIF 日志保存到指定的目标路径。

参数

destination (str) – 保存诊断 SARIF 日志的目标位置。它必须具有 .sarif 扩展名。

Raises

ValueError – 如果目标路径不以 .sarif 扩展名结尾。

class torch.onnx.ONNXRuntimeOptions(*, session_options=None, execution_providers=None, execution_provider_options=None)

影响通过 ONNX 运行时执行 ONNX 模型的选项。

变量
  • session_options (Sequence[onnxruntime.SessionOptions] | None) – ONNX 运行时会话选项。

  • execution_providers (Sequence[str | tuple[str, dict[Any, Any]]] | None) – 模型执行期间要使用的 ONNX 运行时执行提供程序。

  • execution_provider_options (Sequence[dict[Any, Any]] | None) – ONNX 运行时执行提供程序选项。

class torch.onnx.OnnxExporterError

ONNX 导出器引发的错误。这是所有导出器错误的基类。

class torch.onnx.OnnxRegistry

ONNX 函数的注册表。

该注册表维护着从限定名到固定操作集版本的符号函数的映射。它支持注册自定义 onnx-script 函数,并支持调度程序将调用调度到相应的函数。

get_op_functions(namespace, op_name, overload=None)[source]

返回给定操作的 ONNXFunction 列表:torch.ops.<namespace>.<op_name>.<overload>。

该列表按注册时间排序。自定义运算符应位于列表的后半部分。

参数
  • namespace (str) – 要获取的运算符的命名空间。

  • op_name (str) – 要获取的运算符的名称。

  • overload (str | None) – 要获取的运算符的重载。如果是默认重载,将其保留为 None。

返回

与给定名称相对应的 ONNXFunction 列表,如果名称不在注册表中,则为 None。

返回类型

list[registration.ONNXFunction] | None

is_registered_op(namespace, op_name, overload=None)[source]

返回给定操作是否已注册:torch.ops.<namespace>.<op_name>.<overload>。

参数
  • namespace (str) – 要检查的运算符的命名空间。

  • op_name (str) – 要检查的运算符的名称。

  • overload (str | None) – 要检查的运算符的重载。如果是默认重载,将其保留为 None。

返回

如果给定操作已注册,则为 True,否则为 False。

返回类型

bool

property opset_version: int

导出器应针对的 ONNX 操作集版本。默认为最新支持的 ONNX 操作集版本:18。随着 ONNX 的不断发展,默认版本会随着时间的推移而增加。

register_op(function, namespace, op_name, overload=None, is_complex=False)[source]

注册自定义运算符:torch.ops.<namespace>.<op_name>.<overload>。

参数
  • function (onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction) – 要注册的onnx-sctip函数。

  • namespace (str) – 要注册的运算符的命名空间。

  • op_name (str) – 要注册的运算符的名称。

  • overload (str | None) – 要注册的运算符的重载。如果它是默认重载,将其保留为None。

  • is_complex (bool) – 该函数是否为处理复数值输入的函数。

Raises

ValueError – 如果名称不符合‘namespace::op’的形式。

class torch.onnx.DiagnosticOptions(verbosity_level=20, warnings_as_errors=False)

诊断上下文的选项。

变量
  • verbosity_level (int) – 设置每个诊断记录的信息量,相当于 Python 日志模块中的‘level’。

  • warnings_as_errors (bool) – 当为 True 时,警告诊断将被视为错误诊断。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源