快捷方式

基于 TorchScript 的 ONNX 导出器

注意

要使用 TorchDynamo 而不是 TorchScript 导出 ONNX 模型,请参阅了解更多关于基于 TorchDynamo 的 ONNX 导出器的信息

示例:从 PyTorch 到 ONNX 的 AlexNet

这是一个简单的脚本,将预训练的 AlexNet 导出到名为 alexnet.onnx 的 ONNX 文件。调用 torch.onnx.export 运行模型一次以追踪其执行,然后将追踪的模型导出到指定文件

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

生成的 alexnet.onnx 文件包含一个二进制协议缓冲区,其中包含您导出的模型的网络结构和参数(在本例中为 AlexNet)。参数 verbose=True 使导出器打印出模型的人工可读表示

# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- omitted for brevity ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # Every statement consists of some output tensors (and their types),
  # the operator to be run (with its attributes, e.g., kernels, strides,
  # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- omitted for brevity ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # Dynamic means that the shape is not known. This may be because of a
  # limitation of our implementation (which we would like to fix in a
  # future release) or shapes which are truly dynamic.
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- omitted for brevity ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
  return (%output1);
}

您还可以使用 ONNX 库验证输出,您可以使用 pip 安装它

pip install onnx

然后,您可以运行

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

您还可以使用许多支持 ONNX 的运行时之一运行导出的模型。例如,在安装 ONNX Runtime 后,您可以加载并运行模型

import onnxruntime as ort
import numpy as np

ort_session = ort.InferenceSession("alexnet.onnx")

outputs = ort_session.run(
    None,
    {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])

这是一个更详细的关于导出模型并使用 ONNX Runtime 运行它的教程

追踪 vs 脚本

在内部,torch.onnx.export() 需要 torch.jit.ScriptModule 而不是 torch.nn.Module。如果传入的模型还不是 ScriptModuleexport() 将使用追踪将其转换为一个

  • 追踪:如果使用 Module 调用 torch.onnx.export(),且该 Module 还不是 ScriptModule,则它首先执行等效于 torch.jit.trace() 的操作,即使用给定的 args 执行模型一次,并记录执行期间发生的所有操作。这意味着如果您的模型是动态的,例如,根据输入数据更改行为,则导出的模型将不会捕获此动态行为。我们建议检查导出的模型,并确保运算符看起来合理。追踪将展开循环和 if 语句,导出与追踪运行完全相同的静态图。如果您想导出具有动态控制流的模型,则需要使用脚本

  • 脚本:通过脚本编译模型可以保留动态控制流,并且对于不同大小的输入都有效。要使用脚本

    • 使用 torch.jit.script() 生成 ScriptModule

    • 使用 ScriptModule 作为模型调用 torch.onnx.export()args 仍然是必需的,但它们将在内部仅用于生成示例输出,以便可以捕获输出的类型和形状。不会执行追踪。

有关更多详细信息,包括如何组合追踪和脚本以适应不同模型的特定要求,请参阅TorchScript 简介TorchScript

避免陷阱

避免 NumPy 和内置 Python 类型

可以使用 NumPy 或 Python 类型和函数编写 PyTorch 模型,但在追踪期间,任何 NumPy 或 Python 类型的变量(而不是 torch.Tensor)都会转换为常量,如果这些值应根据输入而更改,则会产生错误的结果。

例如,不要在 numpy.ndarrays 上使用 numpy 函数

# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

在 torch.Tensors 上使用 torch 运算符

# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

并且不要使用 torch.Tensor.item()(它将 Tensor 转换为 Python 内置数字)

# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
    return x.reshape(y.item(), -1)

使用 torch 对单元素张量的隐式转换的支持

# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
    return x.reshape(y, -1)

避免 Tensor.data

使用 Tensor.data 字段可能会产生不正确的追踪,从而导致不正确的 ONNX 图。请改用 torch.Tensor.detach()。(正在进行完全删除 Tensor.data的工作)。

在追踪模式下使用 tensor.shape 时避免原地操作

在追踪模式下,从 tensor.shape 获取的形状被追踪为张量,并共享相同的内存。这可能会导致最终输出值不匹配。作为一种解决方法,在这些情况下避免使用原地操作。例如,在模型中

class Model(torch.nn.Module):
  def forward(self, states):
      batch_size, seq_length = states.shape[:2]
      real_seq_length = seq_length
      real_seq_length += 2
      return real_seq_length + seq_length

real_seq_lengthseq_length 在追踪模式下共享相同的内存。可以通过重写原地操作来避免这种情况

real_seq_length = real_seq_length + 2

局限性

类型

  • 只有 torch.Tensors,可以轻松转换为 torch.Tensors 的数字类型(例如 float、int),以及这些类型的元组和列表才支持作为模型输入或输出。字典和字符串输入和输出在追踪模式下被接受,但是

    • 任何依赖于字典或字符串输入值的计算将被替换为在一次追踪执行期间看到的常量值

    • 任何作为字典的输出都将被静默地替换为其值的扁平化序列(键将被删除)。例如,{"foo": 1, "bar": 2} 变为 (1, 2)

    • 任何作为字符串的输出都将被静默地删除。

  • 由于 ONNX 对嵌套序列的支持有限,因此在脚本模式下不支持某些涉及元组和列表的操作。特别是不支持将元组附加到列表。在追踪模式下,嵌套序列将在追踪期间自动展平。

运算符实现中的差异

由于运算符实现中的差异,在不同的运行时上运行导出的模型可能会产生彼此之间或与 PyTorch 不同的结果。通常,这些差异在数值上很小,因此只有当您的应用程序对这些小差异敏感时,这才会成为问题。

不支持的张量索引模式

下面列出了无法导出的张量索引模式。如果您在导出不包含以下任何不支持模式的模型时遇到问题,请仔细检查您是否使用了最新的 opset_version 进行导出。

读取 / 获取

当索引到张量以进行读取时,不支持以下模式

# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.

写入 / 设置

当索引到张量以进行写入时,不支持以下模式

# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
#              or multiple consecutive tensor indices with rank == 1.

# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.

# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.

# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
#   data shape: [3, 4, 5]
#   new_data shape: [5]
#   expected new_data shape after broadcasting: [2, 2, 2, 5]

添加对运算符的支持

当导出包含不支持的运算符的模型时,您将看到类似以下的错误消息

RuntimeError: ONNX export failed: Couldn't export operator foo

发生这种情况时,您可以执行以下几项操作

  1. 更改模型以不使用该运算符。

  2. 创建符号函数来转换运算符,并将其注册为自定义符号函数。

  3. 贡献给 PyTorch,将相同的符号函数添加到 torch.onnx 本身。

如果您决定实现符号函数(我们希望您将其贡献回 PyTorch!),以下是您可以开始使用的方法

ONNX 导出器内部机制

“符号函数”是将 PyTorch 运算符分解为一系列 ONNX 运算符的组合的函数。

在导出期间,导出器按拓扑顺序访问 TorchScript 图中的每个节点(其中包含 PyTorch 运算符)。在访问节点时,导出器会查找该运算符的已注册符号函数。符号函数在 Python 中实现。名为 foo 的运算符的符号函数看起来像这样

def foo(
  g,
  input_0: torch._C.Value,
  input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
  """
  Adds the ONNX operations representing this PyTorch function by updating the
  graph g with `g.op()` calls.

  Args:
    g (Graph): graph to write the ONNX representation into.
    input_0 (Value): value representing the variables which contain
        the first input for this operator.
    input_1 (Value): value representing the variables which contain
        the second input for this operator.

  Returns:
    A Value or List of Values specifying the ONNX nodes that compute something
    equivalent to the original PyTorch operator with the given inputs.

    None if it cannot be converted to ONNX.
  """
  ...

torch._C 类型是 C++ 中 ir.h 中定义的类型的 Python 包装器。

添加符号函数的过程取决于运算符的类型。

ATen 运算符

ATen 是 PyTorch 的内置张量库。如果运算符是 ATen 运算符(在带有前缀 aten:: 的 TorchScript 图中显示),请确保它尚未受支持。

支持的运算符列表

访问自动生成的支持的 TorchScript 运算符列表,以获取有关每个 opset_version 中支持哪些运算符的详细信息。

添加对 aten 或量化运算符的支持

如果运算符不在上面的列表中

  • torch/onnx/symbolic_opset<version>.py 中定义符号函数,例如 torch/onnx/symbolic_opset9.py。确保该函数具有与 ATen 函数相同的名称,该名称可能在 torch/_C/_VariableFunctions.pyitorch/nn/functional.pyi 中声明(这些文件在构建时生成,因此在您构建 PyTorch 之前不会出现在您的检出中)。

  • 默认情况下,第一个参数是 ONNX 图。其他参数名称必须与 .pyi 文件中的名称完全匹配,因为调度是使用关键字参数完成的。

  • 在符号函数中,如果运算符在ONNX 标准运算符集中,我们只需要创建一个节点来表示图中的 ONNX 运算符。如果不是,我们可以组合几个标准运算符,这些运算符具有与 ATen 运算符等效的语义。

这是一个处理 ELU 运算符的缺失符号函数的示例。

如果我们运行以下代码

print(
    torch.jit.trace(
        torch.nn.ELU(), # module
        torch.ones(1)   # example input
    ).graph
)

我们看到类似这样的内容

graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
      %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %4 : float = prim::Constant[value=1.]()
  %5 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=1]()
  %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
  return (%7)

由于我们在图中看到 aten::elu,我们知道这是一个 ATen 运算符。

我们检查ONNX 运算符列表,并确认 Elu 在 ONNX 中是标准化的。

我们在 torch/nn/functional.pyi 中找到 elu 的签名

def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

我们将以下行添加到 symbolic_opset9.py

def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
    return g.op("Elu", input, alpha_f=alpha)

现在 PyTorch 能够导出包含 aten::elu 运算符的模型了!

有关更多示例,请参阅 torch/onnx/symbolic_opset*.py 文件。

torch.autograd.Functions

如果运算符是 torch.autograd.Function 的子类,则有三种导出方法。

静态符号方法

您可以向您的函数类添加一个名为 symbolic 的静态方法。它应该返回表示 ONNX 中函数行为的 ONNX 运算符。例如

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

内联 Autograd 函数

如果在其后续 torch.autograd.Function 中未提供静态符号方法,或者未提供注册 prim::PythonOp 作为自定义符号函数的函数,则 torch.onnx.export() 尝试内联与该 torch.autograd.Function 对应的图,以便将此函数分解为函数中使用的各个运算符。只要支持这些单个运算符,导出就应该成功。例如

class MyLogExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        h = input.exp()
        return h.log().log()

此模型没有静态符号方法,但它按如下方式导出

graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
    %1 : float = onnx::Exp[](%input)
    %2 : float = onnx::Log[](%1)
    %3 : float = onnx::Log[](%2)
    return (%3)

如果您需要避免内联 torch.autograd.Function,则应使用设置为 ONNX_FALLTHROUGHONNX_ATEN_FALLBACKoperator_export_type 导出模型。

自定义运算符

您可以导出带有自定义运算符的模型,该运算符包括许多标准 ONNX 运算符的组合,或者由自定义的 C++ 后端驱动。

ONNX-script 函数

如果运算符不是标准 ONNX 运算符,但可以由多个现有 ONNX 运算符组成,则可以使用 ONNX-script 创建外部 ONNX 函数以支持该运算符。您可以通过遵循此示例来导出它

import onnxscript
# There are three opset version needed to be aligned
# This is (1) the opset version in ONNX function
from onnxscript.onnx_opset import opset15 as op
opset_version = 15

x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()

custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

@onnxscript.script(custom_opset)
def Selu(X):
    alpha = 1.67326  # auto wrapped as Constants
    gamma = 1.0507
    alphaX = op.CastLike(alpha, X)
    gammaX = op.CastLike(gamma, X)
    neg = gammaX * (alphaX * op.Exp(X) - alphaX)
    pos = gammaX * X
    zero = op.CastLike(0, X)
    return op.Where(X <= zero, neg, pos)

# setType API provides shape/type to ONNX shape/type inference
def custom_selu(g: jit_utils.GraphContext, X):
    return g.onnxscript_op(Selu, X).setType(X.type())

# Register custom symbolic function
# There are three opset version needed to be aligned
# This is (2) the opset version in registry
torch.onnx.register_custom_op_symbolic(
    symbolic_name="aten::selu",
    symbolic_fn=custom_selu,
    opset_version=opset_version,
)

# There are three opset version needed to be aligned
# This is (2) the opset version in exporter
torch.onnx.export(
    model,
    x,
    "model.onnx",
    opset_version=opset_version,
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"onnx-script": 2}
)

上面的示例将其作为 “onnx-script” 运算符集中的自定义运算符导出。导出自定义运算符时,您可以使用导出时的 custom_opsets 字典指定自定义域版本。如果未指定,则自定义运算符集版本默认为 1。

注意:请注意对齐上述示例中提到的运算符集版本,并确保它们在导出步骤中被使用。关于如何编写 onnx-script 函数的示例用法是 onnx-script 积极开发中的 beta 版本。请关注最新的ONNX-script

C++ 运算符

如果模型使用以 C++ 实现的自定义运算符,如使用自定义 C++ 运算符扩展 TorchScript中所述,您可以通过遵循此示例来导出它

from torch.onnx import symbolic_helper


# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super().__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


model = FooModel(attr1, attr2)
torch.onnx.export(
    model,
    (example_input1, example_input1),
    "model.onnx",
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"custom_domain": 2}
)

上面的示例将其作为 “custom_domain” 运算符集中的自定义运算符导出。导出自定义运算符时,您可以使用导出时的 custom_opsets 字典指定自定义域版本。如果未指定,则自定义运算符集版本默认为 1。

使用模型的运行时需要支持自定义运算符。请参阅 Caffe2 自定义运算符ONNX Runtime 自定义运算符或您选择的运行时的文档。

一次性发现所有不可转换的 ATen 运算符

当由于不可转换的 ATen 运算符导致导出失败时,实际上可能存在多个此类运算符,但错误消息仅提及第一个。要一次性发现所有不可转换的运算符,您可以

# prepare model, args, opset_version
...

torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    model, args, opset_version=opset_version
)

print(set(unconvertible_ops))

该集合是近似的,因为某些运算符可能会在转换过程中被删除,而不需要转换。其他一些运算符可能具有部分支持,这些支持在特定输入下会转换失败,但这应该让您大致了解哪些运算符不受支持。请随时为运算符支持请求打开 GitHub Issue。

常见问题解答

问:我导出了我的 LSTM 模型,但其输入大小似乎是固定的?

追踪器记录示例输入的形状。如果模型应接受动态形状的输入,请在调用 torch.onnx.export() 时设置 dynamic_axes

问:如何导出包含循环的模型?

请参阅追踪 vs 脚本

问:如何导出具有原始类型输入(例如 int、float)的模型?

对原始数字类型输入的支持已在 PyTorch 1.9 中添加。但是,导出器不支持带有字符串输入的模型。

问:ONNX 是否支持隐式标量数据类型转换?

ONNX 标准不支持,但导出器将尝试处理该部分。标量导出为常量张量。导出器将找出标量的正确数据类型。在极少数情况下,当它无法做到这一点时,您需要手动指定数据类型,例如 dtype=torch.float32。如果您看到任何错误,请[创建 GitHub issue](https://github.com/pytorch/pytorch/issues)。

问:张量列表可以导出到 ONNX 吗?

是的,对于 opset_version >= 11,因为 ONNX 在 opset 11 中引入了 Sequence 类型。

Python API

函数

torch.onnx.export(model, args=(), f=None, *, kwargs=None, export_params=True, verbose=None, input_names=None, output_names=None, opset_version=None, dynamic_axes=None, keep_initializers_as_inputs=False, dynamo=False, external_data=True, dynamic_shapes=None, custom_translation_table=None, report=False, optimize=False, verify=False, profile=False, dump_exported_program=False, artifacts_dir='.', fallback=False, training=<TrainingMode.EVAL: 0>, operator_export_type=<OperatorExportTypes.ONNX: 0>, do_constant_folding=True, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True, **_)[source][source]

将模型导出为 ONNX 格式。

参数
  • model (torch.nn.Module | torch.export.ExportedProgram | torch.jit.ScriptModule | torch.jit.ScriptFunction) – 要导出的模型。

  • args (tuple[Any, ...]) – 示例位置输入。任何非 Tensor 参数将被硬编码到导出的模型中;任何 Tensor 参数将成为导出模型的输入,顺序与它们在元组中出现的顺序相同。

  • f (str | os.PathLike | None) – 输出 ONNX 模型文件的路径。例如 “model.onnx”。

  • kwargs (dict[str, Any] | None) – 可选的示例关键字输入。

  • export_params (bool) – 如果为 false,则不会导出参数(权重)。

  • verbose (bool | None) – 是否启用详细日志记录。

  • input_names (Sequence[str] | None) – 要按顺序分配给图输入节点的名称。

  • output_names (Sequence[str] | None) – 要按顺序分配给图输出节点的名称。

  • opset_version (int | None) – 要定位的 默认 (ai.onnx) opset 版本。必须 >= 7。

  • dynamic_axes (Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None) –

    默认情况下,导出的模型将把所有输入和输出张量的形状设置为与 args 中给出的形状完全匹配。要将张量的轴指定为动态的(即仅在运行时已知),请将 dynamic_axes 设置为具有以下模式的字典

    • KEY (str):输入或输出名称。每个名称也必须在 input_names 或中提供

      output_names.

    • VALUE (dict 或 list):如果是 dict,则键是轴索引,值是轴名称。如果是

      list,则每个元素都是一个轴索引。

    例如

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    
    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
    )
    

    产生

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...
    

    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
        dynamic_axes={
            # dict value: manually named axes
            "x": {0: "my_custom_axis_name"},
            # list value: automatic names
            "sum": [0],
        },
    )
    

    产生

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
    

  • keep_initializers_as_inputs (bool) –

    如果为 True,则导出的图中所有初始值设定项(通常对应于模型权重)也将作为图的输入添加。如果为 False,则初始值设定项不会作为图的输入添加,并且仅用户输入作为输入添加。

    如果您打算在运行时提供模型权重,请将其设置为 True。如果权重是静态的,请将其设置为 False,以便后端/运行时进行更好的优化(例如,常量折叠)。

  • dynamo (bool) – 是否使用 torch.export ExportedProgram 而不是 TorchScript 导出模型。

  • external_data (bool) – 是否将模型权重另存为外部数据文件。对于权重很大的模型(超过 ONNX 文件大小限制 (2GB))这是必需的。如果为 False,则权重与模型架构一起保存在 ONNX 文件中。

  • dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) – 模型输入的动态形状的字典或元组。有关更多详细信息,请参阅 torch.export.export()。仅当 dynamo 为 True 时才使用(且首选)此选项。请注意,dynamic_shapes 旨在在模型使用 dynamo=True 导出时使用,而 dynamic_axes 在 dynamo=False 时使用。

  • custom_translation_table (dict[Callable, Callable | Sequence[Callable]] | None) – 模型中运算符的自定义分解字典。该字典应以 fx Node 中的可调用目标作为键(例如 torch.ops.aten.stft.default),值应是使用 ONNX Script 构建该图的函数。此选项仅在 dynamo 为 True 时有效。

  • report (bool) – 是否为导出过程生成 markdown 报告。此选项仅在 dynamo 为 True 时有效。

  • optimize (bool) – 是否优化导出的模型。此选项仅在 dynamo 为 True 时有效。

  • verify (bool) – 是否使用 ONNX Runtime 验证导出的模型。此选项仅在 dynamo 为 True 时有效。

  • profile (bool) – 是否分析导出过程。此选项仅在 dynamo 为 True 时有效。

  • dump_exported_program (bool) – 是否将 torch.export.ExportedProgram 转储到文件。这对于调试导出器很有用。此选项仅在 dynamo 为 True 时有效。

  • artifacts_dir (str | os.PathLike) – 保存调试工件(如报告和序列化的导出程序)的目录。此选项仅在 dynamo 为 True 时有效。

  • fallback (bool) – 如果 dynamo 导出器失败,是否回退到 TorchScript 导出器。此选项仅在 dynamo 为 True 时有效。启用回退后,即使提供了 dynamic_shapes,也建议设置 dynamic_axes。

  • training (_C_onnx.TrainingMode) – 已弃用的选项。请改为在导出前设置模型的训练模式。

  • operator_export_type (_C_onnx.OperatorExportTypes) – 已弃用的选项。仅支持 ONNX。

  • do_constant_folding (bool) – 已弃用的选项。

  • custom_opsets (Mapping[str, int] | None) –

    已弃用。一个字典

    • KEY (str):opset 域名称

    • VALUE (int):opset 版本

    如果 model 引用了自定义 opset,但未在此字典中提及,则 opset 版本设置为 1。只有自定义 opset 域名和版本应该通过此参数指示。

  • export_modules_as_functions (bool | Collection[type[torch.nn.Module]]) –

    已弃用的选项。

    标志,用于启用将所有 nn.Module 前向调用导出为 ONNX 中的本地函数。或者一个集合,用于指示要导出为 ONNX 中的本地函数的特定模块类型。此功能需要 opset_version >= 15,否则导出将失败。这是因为 opset_version < 15 意味着 IR 版本 < 8,这意味着没有本地函数支持。模块变量将导出为函数属性。函数属性分为两类。

    1. 注解属性:通过 PEP 526 风格 具有类型注解的类变量将导出为属性。注解属性在 ONNX 本地函数的子图中未使用,因为它们不是由 PyTorch JIT 跟踪创建的,但消费者可以使用它们来确定是否用特定的融合内核替换该函数。

    2. 推断属性:模块内部运算符使用的变量。属性名称将带有前缀 “inferred::”。这是为了区分从 python 模块注解检索到的预定义属性。推断属性在 ONNX 本地函数的子图中使用。

    • False(默认):将 nn.Module 前向调用导出为细粒度节点。

    • True:将所有 nn.Module 前向调用导出为本地函数节点。

    • nn.Module 类型集合:将 nn.Module 前向调用导出为本地函数节点,

      仅当在集合中找到 nn.Module 的类型时。

  • autograd_inlining (bool) – 已弃用。用于控制是否内联 autograd 函数的标志。有关更多详细信息,请参阅 https://github.com/pytorch/pytorch/pull/74765

返回类型

ONNXProgram | None

torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[source][source]

为自定义运算符注册符号函数。

当用户为自定义/contrib 运算符注册符号时,强烈建议通过 setType API 为该运算符添加形状推断,否则导出的图在某些极端情况下可能具有不正确的形状推断。setType 的一个示例是 test_aten_embedding_2test_operators.py 中。

有关示例用法,请参见模块文档中的“自定义运算符”。

参数
  • symbolic_name (str) – “<domain>::<op>” 格式的自定义运算符的名称。

  • symbolic_fn (Callable) – 一个函数,它接受 ONNX 图和当前运算符的输入参数,并返回要添加到图中的新运算符节点。

  • opset_version (int) – 要注册的 ONNX opset 版本。

torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)[source][source]

取消注册 symbolic_name

有关示例用法,请参见模块文档中的“自定义运算符”。

参数
  • symbolic_name (str) – “<domain>::<op>” 格式的自定义运算符的名称。

  • opset_version (int) – 要取消注册的 ONNX opset 版本。

torch.onnx.select_model_mode_for_export(model, mode)[source][source]

一个上下文管理器,用于临时将 model 的训练模式设置为 mode,并在我们退出 with 代码块时重置它。

参数
  • model – 与 export()model 参数类型和含义相同。

  • mode (TrainingMode) – 与 export()training 参数类型和含义相同。

torch.onnx.is_in_onnx_export()[source][source]

返回是否正处于 ONNX 导出过程中。

返回类型

bool

torch.onnx.verification.find_mismatch(model, input_args, do_constant_folding=True, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, options=None)[source][source]

查找原始模型和导出模型之间的所有不匹配之处。

实验性功能。API 可能会发生变化。

此工具帮助调试原始 PyTorch 模型和导出的 ONNX 模型之间的不匹配。它对模型图进行二分搜索,以找到展示不匹配的最小子图。

参数
返回

一个 GraphInfo 对象,其中包含不匹配信息。

返回类型

GraphInfo

示例

>>> import torch
>>> import torch.onnx.verification
>>> torch.manual_seed(0)
>>> opset_version = 15
>>> # Define a custom symbolic function for aten::relu.
>>> # The custom symbolic function is incorrect, which will result in mismatches.
>>> def incorrect_relu_symbolic_function(g, self):
...     return self
>>> torch.onnx.register_custom_op_symbolic(
...     "aten::relu",
...     incorrect_relu_symbolic_function,
...     opset_version=opset_version,
... )
>>> class Model(torch.nn.Module):
...     def __init__(self) -> None:
...         super().__init__()
...         self.layers = torch.nn.Sequential(
...             torch.nn.Linear(3, 4),
...             torch.nn.ReLU(),
...             torch.nn.Linear(4, 5),
...             torch.nn.ReLU(),
...             torch.nn.Linear(5, 6),
...         )
...     def forward(self, x):
...         return self.layers(x)
>>> graph_info = torch.onnx.verification.find_mismatch(
...     Model(),
...     (torch.randn(2, 3),),
...     opset_version=opset_version,
... )
===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!
Mismatched elements: 12 / 12 (100.0%)
Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
==================================== Tree: =====================================
5 X   __2 X    __1 \u2713
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 \u2713
        id: 1 |  id: 10
              |
              |__2 X     __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 \u2713
                           id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}

JitScalarType

torch 中定义的标量类型。

verification.GraphInfo

GraphInfo 包含 TorchScript 图及其转换后的 ONNX 图的验证信息。

verification.VerificationOptions

ONNX 导出验证的选项。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源