• 教程 >
  • 扩展 ONNX 注册表
快捷方式

ONNX 入门 || 将 PyTorch 模型导出到 ONNX || 扩展 ONNX 注册表

扩展 ONNX 注册表

作者:王帝泰 (titaiwang@microsoft.com)

概述

本教程介绍了 ONNX 注册表,它使用户能够实现新的 ONNX 运算符,甚至用新的实现替换现有的运算符。

在将模型导出到 ONNX 期间,PyTorch 模型会被降低到一个由 ATen 运算符 组成的中间表示形式。虽然 ATen 运算符由 PyTorch 核心团队维护,但 ONNX 导出团队有责任通过 ONNX Script 将这些运算符中的每一个独立地实现到 ONNX 中。用户还可以用自己的实现替换 ONNX 导出团队实现的行为,以修复错误或提高特定 ONNX 运行时的性能。

ONNX 注册表管理 PyTorch 运算符与其对应的 ONNX 运算符之间的映射,并提供用于扩展注册表的 API。

在本教程中,我们将介绍三种需要使用自定义运算符扩展 ONNX 注册表的场景

  • 不受支持的 ATen 运算符

  • 具有现有 ONNX 运行时支持的自定义运算符

  • 没有 ONNX 运行时支持的自定义运算符

不受支持的 ATen 运算符

尽管 ONNX 导出团队尽最大努力支持所有 ATen 运算符,但其中一些可能尚未得到支持。在本节中,我们将演示如何将不受支持的 ATen 运算符添加到 ONNX 注册表中。

注意

实现不受支持的 ATen 算子的步骤与使用自定义实现替换现有 ATen 算子的实现相同。因为在本教程中我们实际上没有不受支持的 ATen 算子可以使用,所以我们将利用这一点,并以与算子不存在于 ONNX 注册表中的相同方式,使用自定义实现替换 aten::add.Tensor 的实现。

当模型由于不支持的算子而无法导出到 ONNX 时,ONNX 导出器将显示类似以下的错误消息

RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.add.Tensor']}.

错误消息表明不受支持的 ATen 算子的完全限定名称为 aten::add.Tensor。算子的完全限定名称由命名空间、算子名称和重载组成,格式为 namespace::operator_name.overload

要添加对不受支持的 ATen 算子的支持或替换现有算子的实现,我们需要

  • ATen 算子的完全限定名称(例如 aten::add.Tensor)。此信息始终存在于上述错误消息中。

  • 使用 ONNX Script 实现算子。ONNX Script 是本教程的先决条件。请确保在继续之前已阅读 ONNX Script 教程

由于 aten::add.Tensor 已经得到 ONNX 注册表的支持,我们将演示如何用自定义实现替换它,但请记住,相同的步骤也适用于支持新的不受支持的 ATen 算子。

这是可能的,因为 OnnxRegistry 允许用户覆盖算子注册。我们将使用我们的自定义实现覆盖 aten::add.Tensor 的注册,并验证其是否存在。

import torch
import onnxruntime
import onnxscript
from onnxscript import opset18  # opset 18 is the latest (and only) supported version for now

class Model(torch.nn.Module):
    def forward(self, input_x, input_y):
        return torch.ops.aten.add(input_x, input_y)  # generates a aten::add.Tensor node

input_add_x = torch.randn(3, 4)
input_add_y = torch.randn(3, 4)
aten_add_model = Model()


# Now we create a ONNX Script function that implements ``aten::add.Tensor``.
# The function name (e.g. ``custom_aten_add``) is displayed in the ONNX graph, so we recommend to use intuitive names.
custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)

# NOTE: The function signature must match the signature of the unsupported ATen operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_aten)
def custom_aten_add(input_x, input_y, alpha: float = 1.0):
    input_y = opset18.Mul(input_y, alpha)
    return opset18.Add(input_x, input_y)


# Now we have everything we need to support unsupported ATen operators.
# Let's register the ``custom_aten_add`` function to ONNX registry, and export the model to ONNX again.
onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
    namespace="aten", op_name="add", overload="Tensor", function=custom_aten_add
    )
print(f"aten::add.Tensor is supported by ONNX registry: \
      {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}"
      )
export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
onnx_program = torch.onnx.dynamo_export(
    aten_add_model, input_add_x, input_add_y, export_options=export_options
    )
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning:

torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.

aten::add.Tensor is supported by ONNX registry:       True

现在让我们检查模型并验证模型是否具有 custom_aten_add 而不是 aten::add.Tensor。图中有一个 custom_aten_add 的图节点,在其内部有四个函数节点,每个算子一个,以及一个常量属性。

# graph node domain is the custom domain we registered
assert onnx_program.model_proto.graph.node[0].domain == "custom.aten"
assert len(onnx_program.model_proto.graph.node) == 1
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_add"
# function node domain is empty because we use standard ONNX operators
assert {node.domain for node in onnx_program.model_proto.functions[0].node} == {""}
# function node name is the standard ONNX operator name
assert {node.op_type for node in onnx_program.model_proto.functions[0].node} == {"Add", "Mul", "Constant"}

这是使用 Netron 在 ONNX 图中 custom_aten_add_model 的外观

../../_images/custom_aten_add_model.png

custom_aten_add 函数内部,我们可以看到我们在函数中使用的三个 ONNX 节点(CastLikeAddMul),以及一个 Constant 属性

../../_images/custom_aten_add_function.png

这就是我们将新 ATen 算子注册到 ONNX 注册表中所需做的全部操作。作为额外步骤,我们可以使用 ONNX Runtime 运行模型,并将结果与 PyTorch 进行比较。

# Use ONNX Runtime to run the model, and compare the results with PyTorch
onnx_program.save("./custom_add_model.onnx")
ort_session = onnxruntime.InferenceSession(
    "./custom_add_model.onnx", providers=['CPUExecutionProvider']
    )

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_add_x, input_add_y)
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

torch_outputs = aten_add_model(input_add_x, input_add_y)
torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))

具有现有 ONNX Runtime 支持的自定义算子

在这种情况下,用户使用标准 PyTorch 算子创建模型,但 ONNX 运行时(例如 Microsoft 的 ONNX Runtime)可以为该内核提供自定义实现,有效地替换 ONNX 注册表中的现有实现。另一个用例是当用户想要使用现有 ONNX 算子的自定义实现来修复错误或提高特定算子的性能时。为了实现这一点,我们只需要将新的实现与现有的 ATen 完全限定名称一起注册。

在以下示例中,我们使用 ONNX Runtime 中的 com.microsoft.Gelu,它与 ONNX 规范中的 Gelu 不同。因此,我们将 Gelu 与命名空间 com.microsoft 和算子名称 Gelu 一起注册。

在开始之前,让我们检查 aten::gelu.default 是否确实得到 ONNX 注册表的支持。

onnx_registry = torch.onnx.OnnxRegistry()
print(f"aten::gelu.default is supported by ONNX registry: \
    {onnx_registry.is_registered_op(namespace='aten', op_name='gelu', overload='default')}")
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning:

torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.

aten::gelu.default is supported by ONNX registry:     True

在我们的示例中,aten::gelu.default 算子得到 ONNX 注册表的支持,因此 onnx_registry.is_registered_op() 返回 True

class CustomGelu(torch.nn.Module):
    def forward(self, input_x):
        return torch.ops.aten.gelu(input_x)

# com.microsoft is an official ONNX Runtime namspace
custom_ort = onnxscript.values.Opset(domain="com.microsoft", version=1)

# NOTE: The function signature must match the signature of the unsupported ATen operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_ort)
def custom_aten_gelu(input_x, approximate: str = "none"):
    # We know com.microsoft::Gelu is supported by ONNX Runtime
    # It's only not supported by ONNX
    return custom_ort.Gelu(input_x)


onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
    namespace="aten", op_name="gelu", overload="default", function=custom_aten_gelu)
export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)

aten_gelu_model = CustomGelu()
input_gelu_x = torch.randn(3, 3)

onnx_program = torch.onnx.dynamo_export(
    aten_gelu_model, input_gelu_x, export_options=export_options
    )
'Gelu' is not a known op in 'com.microsoft'
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning:

torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.

让我们检查模型并验证模型是否使用了来自命名空间 com.microsoft 的 op_type Gelu

注意

custom_aten_gelu() 不存在于图中,因为少于三个算子的函数会自动内联。

# graph node domain is the custom domain we registered
assert onnx_program.model_proto.graph.node[0].domain == "com.microsoft"
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "Gelu"

下图显示了使用 Netron 的 custom_aten_gelu_model ONNX 图,我们可以看到函数中使用了来自模块 com.microsoftGelu 节点

../../_images/custom_aten_gelu_model.png

这就是我们需要做的全部操作。作为额外步骤,我们可以使用 ONNX Runtime 运行模型,并将结果与 PyTorch 进行比较。

onnx_program.save("./custom_gelu_model.onnx")
ort_session = onnxruntime.InferenceSession(
    "./custom_gelu_model.onnx", providers=['CPUExecutionProvider']
    )

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_gelu_x)
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

torch_outputs = aten_gelu_model(input_gelu_x)
torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))

没有 ONNX Runtime 支持的自定义算子

在这种情况下,算子不受任何 ONNX 运行时支持,但我们希望将其用作 ONNX 图中的自定义算子。因此,我们需要在三个地方实现该算子

  1. PyTorch FX 图

  2. ONNX 注册表

  3. ONNX Runtime

在以下示例中,我们希望使用一个自定义算子,该算子接收一个张量输入并返回一个输出。该算子将输入添加到自身,并返回舍入后的结果。

在 PyTorch FX 图中注册自定义算子(Beta)

首先,我们需要在 PyTorch FX 图中实现该算子。这可以通过使用 torch._custom_op 来完成。

# NOTE: This is a beta feature in PyTorch, and is subject to change.
from torch._custom_op import impl as custom_op

@custom_op.custom_op("mylibrary::addandround_op")
def addandround_op(tensor_x: torch.Tensor) -> torch.Tensor:
    ...

@addandround_op.impl_abstract()
def addandround_op_impl_abstract(tensor_x):
    return torch.empty_like(tensor_x)

@addandround_op.impl("cpu")
def addandround_op_impl(tensor_x):
    return torch.round(tensor_x + tensor_x)  # add x to itself, and round the result

torch._dynamo.allow_in_graph(addandround_op)

class CustomFoo(torch.nn.Module):
    def forward(self, tensor_x):
        return addandround_op(tensor_x)

input_addandround_x = torch.randn(3)
custom_addandround_model = CustomFoo()

在 ONNX 注册表中注册自定义算子

对于步骤 2 和 3,我们需要在 ONNX 注册表中实现该算子。在本示例中,我们将在 ONNX 注册表中使用命名空间 test.customop 和算子名称 CustomOpOne 以及 CustomOpTwo 来实现该算子。这两个算子在 cpu_ops.cc 中注册和构建。

custom_opset = onnxscript.values.Opset(domain="test.customop", version=1)

# NOTE: The function signature must match the signature of the unsupported ATen operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_opset)
def custom_addandround(input_x):
    # The same as opset18.Add(x, x)
    add_x = custom_opset.CustomOpOne(input_x, input_x)
    # The same as opset18.Round(x, x)
    round_x = custom_opset.CustomOpTwo(add_x)
    # Cast to FLOAT to match the ONNX type
    return opset18.Cast(round_x, to=1)


onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
    namespace="mylibrary", op_name="addandround_op", overload="default", function=custom_addandround
    )

export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
onnx_program = torch.onnx.dynamo_export(
    custom_addandround_model, input_addandround_x, export_options=export_options
    )
onnx_program.save("./custom_addandround_model.onnx")
'CustomOpOne' is not a known op in 'test.customop'
'CustomOpTwo' is not a known op in 'test.customop'
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning:

torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.

onnx_program 通过 onnx_program.model_proto 将导出的模型作为 protobuf 公开。图中有一个 custom_addandround 的图节点,在 custom_addandround 内部有两个函数节点,每个算子一个。

assert onnx_program.model_proto.graph.node[0].domain == "test.customop"
assert onnx_program.model_proto.graph.node[0].op_type == "custom_addandround"
assert onnx_program.model_proto.functions[0].node[0].domain == "test.customop"
assert onnx_program.model_proto.functions[0].node[0].op_type == "CustomOpOne"
assert onnx_program.model_proto.functions[0].node[1].domain == "test.customop"
assert onnx_program.model_proto.functions[0].node[1].op_type == "CustomOpTwo"

这是使用 Netron 查看 custom_addandround_model ONNX 图的外观

../../_images/custom_addandround_model.png

custom_addandround 函数内部,我们可以看到我们在函数中使用的两个自定义算子(CustomOpOneCustomOpTwo),它们来自模块 test.customop

../../_images/custom_addandround_function.png

在 ONNX Runtime 中注册自定义算子

要将自定义算子库链接到 ONNX Runtime,您需要将 C++ 代码编译成共享库并将其链接到 ONNX Runtime。请按照以下说明操作

  1. 按照 ONNX Runtime 指令 使用 C++ 实现自定义算子。

  2. ONNX Runtime 版本 下载 ONNX Runtime 源代码分发版。

  3. 编译并将自定义算子库链接到 ONNX Runtime,例如

$ gcc -shared -o libcustom_op_library.so custom_op_library.cc -L /path/to/downloaded/ort/lib/ -lonnxruntime -fPIC
  1. 使用 ONNX Runtime Python API 运行模型,并将结果与 PyTorch 进行比较。

ort_session_options = onnxruntime.SessionOptions()

# NOTE: Link the custom op library to ONNX Runtime and replace the path
# with the path to your custom op library
ort_session_options.register_custom_ops_library(
    "/path/to/libcustom_op_library.so"
)
ort_session = onnxruntime.InferenceSession(
    "./custom_addandround_model.onnx", providers=['CPUExecutionProvider'], sess_options=ort_session_options)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_addandround_x)
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

torch_outputs = custom_addandround_model(input_addandround_x)
torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))

结论

恭喜!在本教程中,我们探索了 ONNXRegistry API,并了解了如何使用 ONNX Script 为不受支持或现有的 ATen 算子创建自定义实现。最后,我们利用 ONNX Runtime 执行模型并将结果与 PyTorch 进行比较,从而使我们全面了解了在 ONNX 生态系统中处理不受支持的算子的方法。

进一步阅读

以下列表列出了从基本示例到高级场景的教程,不一定按照列出的顺序排列。您可以随意直接跳转到您感兴趣的特定主题,或者坐下来享受所有教程,学习有关 ONNX 导出器的所有内容。

脚本的总运行时间:(0 分 3.184 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源