基于 TorchDynamo 的 ONNX 导出器¶
警告
TorchDynamo 的 ONNX 导出器是一项快速发展的测试版技术。
概述¶
ONNX 导出器利用 TorchDynamo 引擎挂接到 Python 的帧评估 API,并将它的字节码动态地重写为 FX 图。生成的 FX 图将在最终转换为 ONNX 图之前进行完善。
这种方法的主要优点是,FX 图 是使用字节码分析捕获的,该分析保留了模型的动态特性,而不是使用传统的静态跟踪技术。
导出器旨在模块化且可扩展。它由以下组件组成:
ONNX 导出器:
Exporter
主要类,负责协调导出过程。ONNX 导出选项:
ExportOptions
有一组选项,用于控制导出过程。ONNX 注册表:
OnnxRegistry
是 ONNX 算子和函数的注册表。FX 图提取器:
FXGraphExtractor
从 PyTorch 模型中提取 FX 图。假模式:
ONNXFakeContext
是一个上下文管理器,用于为大型模型启用假模式。ONNX 程序:
ONNXProgram
是导出器的输出,它包含导出的 ONNX 图和诊断信息。ONNX 诊断选项:
DiagnosticOptions
有一组选项,用于控制导出器发出的诊断信息。
依赖关系¶
ONNX 导出器依赖于额外的 Python 包
可以通过 pip 安装它们。
pip install --upgrade onnx onnxscript
onnxruntime 然后可以用来在各种处理器上执行模型。
一个简单的示例¶
以下展示了导出器 API 在一个简单的多层感知器 (MLP) 示例中的应用。
import torch
import torch.nn as nn
class MLPModel(nn.Module):
def __init__(self):
super().__init__()
self.fc0 = nn.Linear(8, 8, bias=True)
self.fc1 = nn.Linear(8, 4, bias=True)
self.fc2 = nn.Linear(4, 2, bias=True)
self.fc3 = nn.Linear(2, 2, bias=True)
def forward(self, tensor_x: torch.Tensor):
tensor_x = self.fc0(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
tensor_x = self.fc1(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
tensor_x = self.fc2(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
output = self.fc3(tensor_x)
return output
model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
onnx_program = torch.onnx.export(model, (tensor_x,), dynamo=True)
如上所示,您只需要为 torch.onnx.export()
提供模型实例及其输入即可。导出器将返回一个 torch.onnx.ONNXProgram
实例,其中包含导出的 ONNX 图以及其他信息。
通过 onnx_program.model_proto
可用的内存中模型是一个符合 ONNX IR 规范 的 onnx.ModelProto
对象。ONNX 模型可以使用 torch.onnx.ONNXProgram.save()
API 序列化为一个 Protobuf 文件。
onnx_program.save("mlp.onnx")
有两个函数可以基于 TorchDynamo 引擎将模型导出到 ONNX。它们在生成 ExportedProgram
的方式上略有不同。torch.onnx.dynamo_export()
是在 PyTorch 2.1 中引入的,而 torch.onnx.export()
在 PyTorch 2.5 中进行了扩展,以便轻松地在 TorchScript 和 TorchDynamo 之间切换。要调用前一个函数,可以将上一个示例的最后一行替换为以下内容。
onnx_program = torch.onnx.dynamo_export(model, tensor_x)
使用 GUI 检查 ONNX 模型¶
您可以使用 Netron 查看导出的模型。
请注意,每个层都用一个矩形框表示,右上角有一个f图标。
展开后,将显示函数体。
函数体是 ONNX 算子或其他函数的序列。
转换失败时¶
函数 torch.onnx.export()
应该再次调用,参数为 report=True
。会生成一个 Markdown 报告来帮助用户解决问题。
函数 torch.onnx.dynamo_export()
使用 ‘SARIF’ 格式生成报告。ONNX 诊断通过采用 静态分析结果交换格式 (SARIF) 超越了常规日志,帮助用户使用 GUI(如 Visual Studio Code 的 SARIF 查看器)调试和改进模型。
主要优势是
诊断以机器可解析的 静态分析结果交换格式 (SARIF) 发出。
一种新的更清晰、结构化的方式来添加新的诊断规则并跟踪它们。
作为未来更多消耗诊断改进的基础。
- FXE0007:fx-graph-to-onnx
- FXE0008:fx-node-to-onnx
- FXE0010:fx-pass
- FXE0011:no-symbolic-function-for-call-function
- FXE0012:unsupported-fx-node-analysis
- FXE0013:op-level-debugging
- FXE0014:find-opschema-matched-symbolic-function
- FXE0015:fx-node-insert-type-promotion
- FXE0016:find-operator-overloads-in-onnx-registry
API 参考¶
- torch.onnx.dynamo_export(model, /, *model_args, export_options=None, **model_kwargs)[source]¶
将 torch.nn.Module 导出到 ONNX 图。
- 参数
model (torch.nn.Module | Callable | torch.export.ExportedProgram) – 要导出到 ONNX 的 PyTorch 模型。
model_args –
model
的位置输入。model_kwargs –
model
的关键字输入。export_options (ExportOptions | None) – 影响导出到 ONNX 的选项。
- 返回
导出的 ONNX 模型的内存表示。
- 返回类型
ONNXProgram | Any
示例 1 - 最简单的导出
class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(2, 2) def forward(self, x, bias=None): out = self.linear(x) out = out + bias return out model = MyModel() kwargs = {"bias": 3.0} args = (torch.randn(2, 2, 2),) onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save( "my_simple_model.onnx" )
示例 2 - 使用动态形状导出
# The previous model can be exported with dynamic shapes export_options = torch.onnx.ExportOptions(dynamic_shapes=True) onnx_program = torch.onnx.dynamo_export( model, *args, **kwargs, export_options=export_options ) onnx_program.save("my_dynamic_model.onnx")
- class torch.onnx.ExportOptions(*, dynamic_shapes=None, fake_context=None, onnx_registry=None, diagnostic_options=None)¶
影响 TorchDynamo ONNX 导出器的选项。
- 变量
dynamic_shapes (bool | None) – 输入/输出张量的形状信息提示。当为
None
时,导出器确定最兼容的设置。当为True
时,所有输入形状都被认为是动态的。当为False
时,所有输入形状都被认为是静态的。diagnostic_options (DiagnosticOptions) – 导出器的诊断选项。
fake_context (ONNXFakeContext | None) – 用于符号跟踪的伪上下文。
onnx_registry (OnnxRegistry | None) – 用于将 ATen 算子注册到 ONNX 函数的 ONNX 注册表。
- torch.onnx.enable_fake_mode()¶
在上下文持续时间内启用伪模式。
在内部,它实例化一个
torch._subclasses.fake_tensor.FakeTensorMode
上下文管理器,该管理器将用户输入和模型参数转换为torch._subclasses.fake_tensor.FakeTensor
。一个
torch._subclasses.fake_tensor.FakeTensor
是一个torch.Tensor
,它能够在不通过分配在meta
设备上的张量进行实际计算的情况下运行 PyTorch 代码。因为没有实际数据分配在设备上,所以此 API 允许导出大型模型,而无需实际执行它所需的内存占用。在导出超出内存容量的模型时,强烈建议启用伪模式。
- 返回
一个
ONNXFakeContext
对象。
示例
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) >>> import torch >>> import torch.onnx >>> class MyModel(torch.nn.Module): # Dummy model ... def __init__(self) -> None: ... super().__init__() ... self.linear = torch.nn.Linear(2, 2) ... def forward(self, x): ... out = self.linear(x) ... return out >>> with torch.onnx.enable_fake_mode() as fake_context: ... my_nn_module = MyModel() ... arg1 = torch.randn(2, 2, 2) # positional input 1 >>> export_options = torch.onnx.ExportOptions(fake_context=fake_context) >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True) >>> onnx_program.apply_weights(MyModel().state_dict()) >>> # Saving model WITHOUT initializers >>> onnx_program.save( ... "my_model_without_initializers.onnx", ... include_initializers=False, ... keep_initializers_as_inputs=True, ... ) >>> # Saving model WITH initializers >>> onnx_program.save("my_model_with_initializers.onnx")
警告
此 API 处于实验阶段,不向后兼容。
- class torch.onnx.ONNXProgram(model_proto, input_adapter, output_adapter, diagnostic_context, *, fake_context=None, export_exception=None, model_torch=None)¶
已导出到 ONNX 的 PyTorch 模型的内存表示。
- 参数
model_proto (onnx.ModelProto) – 导出的 ONNX 模型,作为
onnx.ModelProto
。input_adapter (io_adapter.InputAdapter) – 用于将 PyTorch 输入转换为 ONNX 输入的输入适配器。
output_adapter (io_adapter.OutputAdapter) – 用于将 PyTorch 输出转换为 ONNX 输出的输出适配器。
diagnostic_context (diagnostics.DiagnosticContext) – SARIF 诊断系统负责记录错误和元数据的上下文对象。
fake_context (ONNXFakeContext | None) – 用于符号跟踪的伪上下文。
export_exception (Exception | None) – 导出过程中发生的异常(如果有)。
- adapt_torch_inputs_to_onnx(*model_args, model_with_state_dict=None, **model_kwargs)[source]¶
将 PyTorch 模型输入转换为导出的 ONNX 模型输入格式。
由于设计上的差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常并不相同。例如,PyTorch 模型允许 None,但 ONNX 不支持。PyTorch 模型允许张量的嵌套结构,但 ONNX 只支持扁平化的张量,等等。
实际的适应步骤与每个单独的导出相关联。它取决于 PyTorch 模型、用于导出的特定 model_args 和 model_kwargs 集以及导出选项。
此方法重播导出期间记录的适应步骤。
- 参数
model_args – PyTorch 模型输入。
model_with_state_dict (torch.nn.Module | Callable | None) – 用于获取额外状态的 PyTorch 模型。如果未指定,则使用导出期间使用的模型。当使用
enable_fake_mode()
从 ONNX 图所需的真实初始化器中提取时,需要此参数。model_kwargs – PyTorch 模型关键字输入。
- 返回
从 PyTorch 模型输入转换而来的张量序列。
- 返回类型
Sequence[torch.Tensor | int | float | bool | torch.dtype]
示例
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) >>> import torch >>> import torch.onnx >>> from typing import Dict, Tuple >>> def func_nested_input( ... x_dict: Dict[str, torch.Tensor], ... y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ... ): ... if "a" in x_dict: ... x = x_dict["a"] ... elif "b" in x_dict: ... x = x_dict["b"] ... else: ... x = torch.randn(3) ... ... y1, (y2, y3) = y_tuple ... ... return x + y1 + y2 + y3 >>> x_dict = {"a": torch.tensor(1.)} >>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.))) >>> onnx_program = torch.onnx.dynamo_export(func_nested_input, x_dict, y_tuple) >>> print(x_dict, y_tuple) {'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.))) >>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model_with_state_dict=func_nested_input)) (tensor(1.), tensor(2.), tensor(3.), tensor(4.))
警告
此 API 处于实验阶段,不向后兼容。
- adapt_torch_outputs_to_onnx(model_outputs, model_with_state_dict=None)[source]¶
将 PyTorch 模型输出转换为导出的 ONNX 模型输出格式。
由于设计上的差异,PyTorch 模型和导出的 ONNX 模型之间的输入/输出格式通常并不相同。例如,PyTorch 模型允许 None,但 ONNX 不支持。PyTorch 模型允许张量的嵌套结构,但 ONNX 只支持扁平化的张量,等等。
实际的适应步骤与每个单独的导出相关联。它取决于 PyTorch 模型、用于导出的特定 model_args 和 model_kwargs 集以及导出选项。
此方法重播导出期间记录的适应步骤。
- 参数
model_outputs (Any) – PyTorch 模型输出。
model_with_state_dict (torch.nn.Module | Callable | None) – 用于获取额外状态的 PyTorch 模型。如果未指定,则使用导出期间使用的模型。当使用
enable_fake_mode()
从 ONNX 图所需的真实初始化器中提取时,需要此参数。
- 返回
导出的 ONNX 模型输出格式中的 PyTorch 模型输出。
- 返回类型
Sequence[torch.Tensor | int | float | bool]
示例
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) >>> import torch >>> import torch.onnx >>> def func_returning_tuples(x, y, z): ... x = x + y ... y = y + z ... z = x + y ... return (x, (y, z)) >>> x = torch.tensor(1.) >>> y = torch.tensor(2.) >>> z = torch.tensor(3.) >>> onnx_program = torch.onnx.dynamo_export(func_returning_tuples, x, y, z) >>> pt_output = func_returning_tuples(x, y, z) >>> print(pt_output) (tensor(3.), (tensor(5.), tensor(8.))) >>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output, model_with_state_dict=func_returning_tuples)) [tensor(3.), tensor(5.), tensor(8.)]
警告
此 API 处于实验阶段,不向后兼容。
- apply_weights(state_dict)[source]¶
将指定状态字典中的权重应用于 ONNX 模型。 :param state_dict: 包含要应用于 ONNX 模型的权重的状态字典。
- property diagnostic_context: diagnostics.DiagnosticContext¶
与导出相关的诊断上下文。
- property model_proto: onnx.ModelProto¶
作为
onnx.ModelProto
导出的 ONNX 模型。
- save(destination, *, include_initializers=True, model_state=None)[source]¶
使用指定的
serializer
将内存中的 ONNX 模型保存到destination
。- 参数
destination (str | io.BufferedIOBase) – 保存 ONNX 模型的目标位置。它可以是字符串或类文件对象。当与
model_state
一起使用时,它必须是一个指向目标位置的完整路径字符串。如果 destination 是字符串,除了将 ONNX 模型保存到文件外,模型权重还会存储在与 ONNX 模型相同的目录中的单独文件中。例如,对于 destination=”/path/model.onnx”,初始化器将保存在“/path/”文件夹中,以及“onnx.model”。include_initializers (bool) – 是否将初始化器作为外部数据包含在 ONNX 图中。不能与 model_state_dict 结合使用。
model_state (dict[str, Any] | str | None) – 包含其所有权重的 PyTorch 模型的状态字典。它可以是指向检查点的路径的字符串,也可以是包含实际模型状态的字典。支持的文件格式与 torch.load 和 safetensors.safe_open 支持的格式相同。当使用
enable_fake_mode()
但 ONNX 图需要真实的初始化器时,这是必需的。
- save_diagnostics(destination)[source]¶
将导出诊断作为 SARIF 日志保存到指定的目标路径。
- 参数
destination (str) – 保存诊断 SARIF 日志的目标位置。它必须具有 .sarif 扩展名。
- Raises
ValueError – 如果目标路径不以 .sarif 扩展名结尾。
- class torch.onnx.ONNXRuntimeOptions(*, session_options=None, execution_providers=None, execution_provider_options=None)¶
影响通过 ONNX 运行时执行 ONNX 模型的选项。
- class torch.onnx.OnnxExporterError¶
ONNX 导出器引发的错误。这是所有导出器错误的基类。
- class torch.onnx.OnnxRegistry¶
ONNX 函数的注册表。
该注册表维护着从限定名到固定操作集版本的符号函数的映射。它支持注册自定义 onnx-script 函数,并支持调度程序将调用调度到相应的函数。
- get_op_functions(namespace, op_name, overload=None)[source]¶
返回给定操作的 ONNXFunction 列表:torch.ops.<namespace>.<op_name>.<overload>。
该列表按注册时间排序。自定义运算符应位于列表的后半部分。
- is_registered_op(namespace, op_name, overload=None)[source]¶
返回给定操作是否已注册:torch.ops.<namespace>.<op_name>.<overload>。
- property opset_version: int¶
导出器应针对的 ONNX 操作集版本。默认为最新支持的 ONNX 操作集版本:18。随着 ONNX 的不断发展,默认版本会随着时间的推移而增加。
- register_op(function, namespace, op_name, overload=None, is_complex=False)[source]¶
注册自定义运算符:torch.ops.<namespace>.<op_name>.<overload>。
- 参数
- Raises
ValueError – 如果名称不符合‘namespace::op’的形式。