快捷方式

基于 TorchScript 的 ONNX 导出器

注意

若要使用 TorchDynamo(而不是 TorchScript)导出 ONNX 模型,请参阅 torch.onnx.dynamo_export()

示例:将 AlexNet 从 PyTorch 转换为 ONNX

这是一个简单的脚本,用于将预训练的 AlexNet 导出到名为 alexnet.onnx 的 ONNX 文件。调用 torch.onnx.export 一次运行模型以跟踪其执行情况,然后将跟踪的模型导出到指定的文件

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

生成的 alexnet.onnx 文件包含一个二进制 协议缓冲区,其中包含您导出的模型(在本例中为 AlexNet)的网络结构和参数。参数 verbose=True 会导致导出器打印出模型的可读表示

# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- omitted for brevity ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # Every statement consists of some output tensors (and their types),
  # the operator to be run (with its attributes, e.g., kernels, strides,
  # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- omitted for brevity ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # Dynamic means that the shape is not known. This may be because of a
  # limitation of our implementation (which we would like to fix in a
  # future release) or shapes which are truly dynamic.
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- omitted for brevity ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
  return (%output1);
}

您还可以使用 ONNX 库来验证输出,您可以使用 pip 安装该库

pip install onnx

然后,您可以运行

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

您还可以使用众多支持 ONNX 的 运行时 中的一种来运行导出的模型。例如,在安装 ONNX Runtime 后,您可以加载并运行模型

import onnxruntime as ort
import numpy as np

ort_session = ort.InferenceSession("alexnet.onnx")

outputs = ort_session.run(
    None,
    {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])

这是一个更复杂的 教程,介绍如何导出模型并使用 ONNX Runtime 运行它

跟踪与脚本编写

从内部来看,torch.onnx.export() 要求 torch.jit.ScriptModule 而不是 torch.nn.Module。如果传入的模型不是 ScriptModuleexport() 会使用追踪功能将其转换为 ScriptModule

  • 追踪:如果 torch.onnx.export() 与不是 ScriptModule 的模块一起调用,那么它会首先执行相当于 torch.jit.trace() 的操作,即使用给定的 args 执行模型一次,并记录执行期间发生的所有操作。这意味着如果模型是动态的,例如根据输入数据改变行为,那么导出的模型不会捕获此动态行为。我们建议检查导出的模型,并确保操作符看起来合理。追踪会展开循环和 if 语句,并导出与追踪运行完全相同的静态图。如果想要动态控制流程地导出模型,则需要使用脚本编制

  • 脚本编制:通过脚本编制编译模型可以保留动态控制流程,并且对于不同大小的输入有效。要使用脚本编制,请执行以下操作

    • 使用 torch.jit.script() 生成 ScriptModule

    • 使用 ScriptModule 作为模型调用 torch.onnx.export()args 仍然是必需的,但它们仅在内部用于生成示例输出,以便可以捕获输出的类型和形状。不会执行追踪。

有关更多详细信息,包括如何整合追踪和脚本编制以满足不同模型的特定需求,请参阅 TorchScript 简介TorchScript

避免陷入误区

避免使用 NumPy 和内置 Python 类型

PyTorch 模型可以使用 NumPy 或 Python 类型和函数编写,但在 跟踪 期间,任何 NumPy 或 Python 类型(而非 torch.Tensor)的变量都会转换为常量,如果这些值应随输入而更改,这将产生错误的结果。

例如,不要在 numpy.ndarrays 上使用 numpy 函数,而要

# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

在 torch.Tensors 上使用 torch 运算符

# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

不要使用 torch.Tensor.item()(它将张量转换为 Python 内置数字),而要

# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
    return x.reshape(y.item(), -1)

使用 torch 对单元素张量的隐式转换的支持

# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
    return x.reshape(y, -1)

避免使用 Tensor.data

使用 Tensor.data 字段可能会产生错误的跟踪信息,进而产生错误的 ONNX 图。请改用 torch.Tensor.detach()。(正在进行移除 Tensor.data 的工作,请在此处查看)。

在追踪模式下使用 tensor.shape 时避免就地操作

在追踪模式下,从 tensor.shape 获取的形状被追踪为张量,并共享相同的内存。这可能会导致最终输出值不匹配。作为一种解决方法,在此类场景中避免使用就地操作。例如,在模型中

class Model(torch.nn.Module):
  def forward(self, states):
      batch_size, seq_length = states.shape[:2]
      real_seq_length = seq_length
      real_seq_length += 2
      return real_seq_length + seq_length

real_seq_lengthseq_length 在追踪模式下共享相同的内存。这可以通过重写就地操作来避免

real_seq_length = real_seq_length + 2

局限性

类型

  • torch.Tensors、可以轻松转换为 torch.Tensors 的数字类型(如 float、int)以及这些类型的元组和列表受支持作为模型输入或输出。字典和 str 输入和输出在 跟踪 模式下被接受,但

    • 任何依赖于词典或 str 输入值的计算都将用在一条追踪执行期间看到的常数值替换

    • 任何输出是字典的都会被静默地替换为其值的扁平序列(键将被移除)。例如:{"foo": 1, "bar": 2} 将变成 (1, 2)

    • 任何输出是 str 的都会被静默地移除。

  • 由于 ONNX 对嵌套序列的支持有限,因此在 脚本 模式中不支持涉及元组和列表的某些操作。特别是并不支持将元组追加到列表中。在跟踪模式中,将在跟踪期间自动使嵌套序列变平。

算子实现中的差异

由于算子实现的不同,在不同的运行时上运行导出模型可能会产生彼此或与 PyTorch 不同的结果。通常情况下,这些差异在数值上非常小,所以只有当您的应用程序对此类微小差异很敏感时,这才是您需要担心的问题。

不受支持的张量索引模式

以下是无法导出的张量索引模式的列表。如果您正在导出不包含任何以下不受支持模式的模型时遇到问题,请仔细检查您是否使用最新的 opset_version 进行导出。

读取/获取

在对张量进行索引以进行读取时,不支持以下模式

# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.

写入/设置

在对张量进行索引以进行写入时,不支持以下模式

# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
#              or multiple consecutive tensor indices with rank == 1.

# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.

# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.

# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
#   data shape: [3, 4, 5]
#   new_data shape: [5]
#   expected new_data shape after broadcasting: [2, 2, 2, 5]

添加对算子的支持

在导出包含不受支持算子的模型时,您将会看到这样的错误消息

RuntimeError: ONNX export failed: Couldn't export operator foo

当出现这种情况时,您可以执行以下几项操作

  1. 更改模型以避免使用该算子。

  2. 创建一个符号函数来转换算子并将其注册为自定义符号函数。

  3. 为 PyTorch 作出贡献,将同一个符号函数添加到 torch.onnx 本身。

如果您决定实现一个符号函数(我们希望您能将它贡献回 PyTorch!),请使用以下方法开始

ONNX 导出器内部

“符号函数”是一个函数,它将 PyTorch 算子分解为一系列 ONNX 算子的组合。

在导出期间,导出器将按拓扑顺序访问 TorchScript 图中的每个节点(包含一个 PyTorch 算子)。在访问一个节点时,导出器将查找针对该算子的已注册的符号函数。符号函数在 Python 中实现。名为 foo 的算子的符号函数看起来类似于

def foo(
  g,
  input_0: torch._C.Value,
  input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
  """
  Adds the ONNX operations representing this PyTorch function by updating the
  graph g with `g.op()` calls.

  Args:
    g (Graph): graph to write the ONNX representation into.
    input_0 (Value): value representing the variables which contain
        the first input for this operator.
    input_1 (Value): value representing the variables which contain
        the second input for this operator.

  Returns:
    A Value or List of Values specifying the ONNX nodes that compute something
    equivalent to the original PyTorch operator with the given inputs.

    None if it cannot be converted to ONNX.
  """
  ...

torch._C 类型是 C++ 中定义的类型的 Python 包装器,这些类型在 ir.h 中定义。

添加符号函数的过程取决于算子类型。

ATen 算子

ATen 是 PyTorch 的内部张量库。如果该算子是 ATen 算子(在 TorchScript 图中以前缀 aten:: 出现),确保还没有支持它。

支持的算子列表

访问自动生成的 支持的 TorchScript 算子列表,详细了解每个 opset_version 中支持哪些算子。

为一个 aten 或量化算子添加支持

如果算子没有出现在上述列表中

  • torch/onnx/symbolic_opset<version>.py 中定义符号函数,例如 torch/onnx/symbolic_opset9.py”。确保此函数具有与 ATen 函数相同的名字,该函数可能在 torch/_C/_VariableFunctions.pyitorch/nn/functional.pyi 中声明(这些文件在生成时生成,因此在你构建 PyTorch 之前,这些文件不会出现在你的签出中)。

  • 默认情况下,第一个参数是 ONNX 图。其他参数名必须严格匹配 .pyi 文件中的名称,因为分派使用关键字参数完成。

  • 在符号函数中,如果算子在 ONNX 标准算子集 中,我们只需要创建一个节点来表示图中的 ONNX 算子。如果不这样做,我们可以组成几个具有与 ATen 算子等效语义的标准算子。

下面是处理 ELU 算子的缺失符号函数的一个示例。

如果我们运行以下代码

print(
    torch.jit.trace(
        torch.nn.ELU(), # module
        torch.ones(1)   # example input
    ).graph
)

我们会看到类似于

graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
      %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %4 : float = prim::Constant[value=1.]()
  %5 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=1]()
  %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
  return (%7)

由于我们在图中看到了 aten::elu,所以我们知道这是一个 ATen 运算符。

我们查看 ONNX 运算符列表,并确认 Elu 在 ONNX 中是标准化的。

我们在 torch/nn/functional.pyi 中找到了 elu 的签名

def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

我们将下列行添加到 symbolic_opset9.py

def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
    return g.op("Elu", input, alpha_f=alpha)

现在 PyTorch 能够导出包含 aten::elu 运算符的模型!

请参阅 torch/onnx/symbolic_opset*.py 文件以了解更多示例。

torch.autograd.Functions

如果运算符是 torch.autograd.Function 的子类,则有三种方法可以导出它。

静态符号方法

可以向你的函数类添加一个名为 symbolic 的静态方法。它应该返回在 ONNX 中表示函数行为的 ONNX 运算符。例如

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

内联自动微分函数

在针对其后续 torch.autograd.Function 未提供静态符号方法或者在未提供将 prim::PythonOp 注册为自定义符号函数的函数的情况下,torch.onnx.export() 会尝试内联对应于该 torch.autograd.Function 的图,以便将此函数分解为函数中使用的单独运算符。只要支持这些单独的运算符,导出就应成功。例如

class MyLogExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        h = input.exp()
        return h.log().log()

此模型没有出现静态符号方法,但会按如下方式导出

graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
    %1 : float = onnx::Exp[](%input)
    %2 : float = onnx::Log[](%1)
    %3 : float = onnx::Log[](%2)
    return (%3)

如果您需要避免内联torch.autograd.Function,应将模型导出为 operator_export_type 设置为 ONNX_FALLTHROUGHONNX_ATEN_FALLBACK

自定义算子

您可以导出包含多种标准 ONNX op 组合的模型或由自定义 C++ 后端驱动模型。

ONNX-script 函数

如果算子不是标准 ONNX op,但可以由多个现有的 ONNX op 组成,您可以利用ONNX-script创建一个外部 ONNX 函数来支持该算子。您可以按照以下示例进行导出

import onnxscript
# There are three opset version needed to be aligned
# This is (1) the opset version in ONNX function
from onnxscript.onnx_opset import opset15 as op
opset_version = 15

x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()

custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

@onnxscript.script(custom_opset)
def Selu(X):
    alpha = 1.67326  # auto wrapped as Constants
    gamma = 1.0507
    alphaX = op.CastLike(alpha, X)
    gammaX = op.CastLike(gamma, X)
    neg = gammaX * (alphaX * op.Exp(X) - alphaX)
    pos = gammaX * X
    zero = op.CastLike(0, X)
    return op.Where(X <= zero, neg, pos)

# setType API provides shape/type to ONNX shape/type inference
def custom_selu(g: jit_utils.GraphContext, X):
    return g.onnxscript_op(Selu, X).setType(X.type())

# Register custom symbolic function
# There are three opset version needed to be aligned
# This is (2) the opset version in registry
torch.onnx.register_custom_op_symbolic(
    symbolic_name="aten::selu",
    symbolic_fn=custom_selu,
    opset_version=opset_version,
)

# There are three opset version needed to be aligned
# This is (2) the opset version in exporter
torch.onnx.export(
    model,
    x,
    "model.onnx",
    opset_version=opset_version,
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"onnx-script": 2}
)

上述示例将其导出为“onnx-script”opset 中的自定义算子。导出自定义算子时,可以使用 custom_opsets 词典在导出时指定自定义域版本。如果未指定,自定义 opset 版本默认为 1。

注意:请小心对齐上述示例中提到的 opset 版本,并确保在导出步骤中使用它们。如何编写 onnx-script 函数的示例用法是 onnx-script 中积极开发的一个 beta 版本。请遵循最新的ONNX-script

C++ 算子

如果模型使用以 C++ 实现的自定义算子(如使用自定义 C++ 算子扩展 TorchScript中所述),您可以按照以下示例进行导出

from torch.onnx import symbolic_helper


# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super().__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


model = FooModel(attr1, attr2)
torch.onnx.export(
    model,
    (example_input1, example_input1),
    "model.onnx",
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"custom_domain": 2}
)

上述示例将其导出为“custom_domain”opset 中的自定义算子。导出自定义算子时,可以使用 custom_opsets 词典在导出时指定自定义域版本。如果未指定,自定义 opset 版本默认为 1。

使用该模型的运行时需要支持自定义 op。请参阅Caffe2 自定义 opONNX Runtime 自定义 op或所选运行时的文档。

一次发现所有不可转换的 Aten op

如果导出因无法转换的 ATen op 而失败,实际上可能存在多于一个这样的 op,但错误消息仅提及第一个 op。你可以一次性发现所有无法转换的 op

# prepare model, args, opset_version
...

torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    model, args, opset_version=opset_version
)

print(set(unconvertible_ops))

该集合是个近似,因为某些 op 可能会在转换过程中被移除且无需转换。某些其他 op 可能具备部分支持,但使用特定输入时会转换失败,但这应能让你大致了解不支持哪些 op。欢迎在 GitHub 上提出 op 支持请求。

常见问题

问:我已导出 LSTM 模型,但其输入大小似乎已固定?

追踪器会记录示例输入的形状。如果模型应接受动态形状的输入,请在调用 torch.onnx.export() 时设置 dynamic_axes

问:如何导出包含循环的模型?

请参阅 追踪与脚本

问:如何导出带有原始类型输入(如 int、float)的模型?

在 PyTorch 1.9 中添加了对原始数字类型输入的支持。但是,导出器不支持带有 str 输入的模型。

问:ONNX 是否支持隐式标量数据类型强制转换?

ONNX 标准不支持,但导出器会尝试处理该部分。标量以常量张量形式导出。导出器会找出标量正确的类型。在少数无法完成此操作的情况下,你需要手动用 dtype=torch.float32 等指定数据类型。如果你看到任何错误,请 [创建 GitHub issue](https://github.com/pytorch/pytorch/issues).

问:张量列表可以导出到 ONNX 吗?

可以,对于 opset_version >= 11,因为 ONNX 在 opset 11 中引入了 Sequence 类型。

Python API

函数

torch.onnx.export(model, args, f=None, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True, dynamo=False)[source]

以 ONNX 格式导出模型。

如果 model 不是 torch.jit.ScriptModuletorch.jit.ScriptFunction,这会运行 model 一次,以便将其转换为要导出的 TorchScript 图(相当于 torch.jit.trace())。因此,它对动态控制流的支持与 torch.jit.trace() 的支持相同。

参数
  • model (torch.nn.Module, torch.jit.ScriptModuletorch.jit.ScriptFunction) –要导出的模型。

  • args (元组torch.Tensor) –

    args 可以构成如下形式:

    1. 仅参数元组

      args = (x, y, z)
      

    元组应包含模型输入,确保 model(*args) 是对模型的有效调用。任何非张量参数都将硬编码到导出的模型中;任何张量参数都将按其在元组中出现的顺序成为导出模型的输入。

    1. 张量

      args = torch.Tensor([1])
      

    这等效于该张量的 1 元组。

    1. 以命名参数字典结尾的参数元组

      args = (
          x,
          {
              "y": input_y,
              "z": input_z
          }
      )
      

    元组的所有元素都作为非关键字参数传递,而命名参数将从最后一个元素中设定。如果命名参数不存在于字典中,则会向其分配默认值,如果没有提供默认值,则该值为空。

    注意

    如果字典是 args 元组的最后一个元素,则会在其中解释为包含命名参数。为了将字典作为最后一个非关键字参数传递,请提供一个空字典作为 args 元组的最后一个元素。例如,请执行以下操作,而不是

    torch.onnx.export(
        model,
        (
            x,
            # WRONG: will be interpreted as named arguments
            {y: z}
        ),
        "test.onnx.pb"
    )
    

    编写以下内容

    torch.onnx.export(
        model,
        (
            x,
            {y: z},
            {}
        ),
        "test.onnx.pb"
    )
    

  • f (可选项[并集[str, BytesIO]]) - 文件的类似对象(此类对象 f.fileno() 会返回文件描述符)或包含文件名字符串。二进制协议缓冲区将被写入此文件。

  • export_params (布尔值, 默认值为 True) - 如果为 True,则将导出所有参数。如果您想导出未训练的模型,请将其设置为 False。在这种情况下,导出的模型将首先获取其所有参数作为参数,排序则根据 model.state_dict().values() 指定的先后顺序

  • verbose (布尔值, 默认值为 False) - 如果为 True,则会打印要导出到 stdout 的模型说明。此外,最终的 ONNX 图表将包含导出模型中 doc_string` 域,其中会提及 model 的源代码位置。如果为 True,则将启动 ONNX 导出程序记录。

  • training (枚举, 默认值为 TrainingMode.EVAL) -

    • TrainingMode.EVAL:以推理模式导出模型。

    • TrainingMode.PRESERVE:如果 model.training 为

      False,则以推理模式导出模型,如果 model.training 为 True,则以训练模式导出模型。

    • TrainingMode.TRAINING:以训练模式导出模型。禁用优化

      可能会干扰训练。

  • input_names (列表 of str, 默认空列表) - 按顺序为图表的输入节点分配的名称。

  • output_names (列表 of str, 默认空列表) - 按顺序为图表的输出节点分配的名称。

  • operator_export_type (枚举, 默认 OperatorExportTypes.ONNX) -

    • OperatorExportTypes.ONNX:将所有运算符导出为常规 ONNX 运算符

      (在默认的 opset 域中)。

    • OperatorExportTypes.ONNX_FALLTHROUGH:尝试转换所有运算符

      到默认 opset 域中的标准 ONNX 运算符。如果无法这样做(例如,因为尚未添加将特定 PyTorch 运算符转换为 ONNX 的支持),则退回到将运算符导出到自定义 opset 域而不进行转换。适用于 自定义运算符 和 ATen 运算符。为了使导出的模型可用,运行时必须支持这些非标准运算符。

    • OperatorExportTypes.ONNX_ATEN:所有 ATen 运算符(在 TorchScript 名称空间 “aten” 中)

      被导出为 ATen 运算符(在 opset 域 “org.pytorch.aten” 中)。ATen 是 PyTorch 的内置张量库,所以这会指示运行时使用 PyTorch 的这些运算符的实现。

      警告

      通过这种方式导出的模型可能只能由 Caffe2 运行。

      如果算子实现中的数值差异导致 PyTorch 和 Caffe2 之间存在较大的行为差异(在未训练的模型中更为常见),这可能很有用。

    • OperatorExportTypes.ONNX_ATEN_FALLBACK:尝试导出每个 ATen op

      (在 TorchScript 名称空间 “aten” 中) 作为常规 ONNX op。如果我们无法这样做(例如,因为尚未添加将特定 PyTorch 运算符转换为 ONNX 的支持),则退回到导出 ATen op。请参阅 OperatorExportTypes.ONNX_ATEN 的文档以了解上下文。例如

      graph(%0 : Float):
      %3 : int = prim::Constant[value=0]()
      # conversion unsupported
      %4 : Float = aten::triu(%0, %3)
      # conversion supported
      %5 : Float = aten::mul(%4, %0)
      return (%5)
      

      假设 aten::triu 在 ONNX 中不受支持,这将导出为

      graph(%0 : Float):
      %1 : Long() = onnx::Constant[value={0}]()
      # not converted
      %2 : Float = aten::ATen[operator="triu"](%0, %1)
      # converted
      %3 : Float = onnx::Mul(%2, %0)
      return (%3)
      

      警告

      通过这种方式导出的模型可能只能由 Caffe2 运行。

  • opset_version (整数, 默认 17) - 默认的 (ai.onnx) opset 的版本的目标。必须大于等于 7 且小于等于 17。

  • do_constant_folding (bool, 默认值 True) – 应用常量折叠优化。常量折叠将用预先计算好的常量节点替换掉所有常量输入的某些 op。

  • dynamic_axes (dict[string, dict[int, string]] 或 dict[string, list(int)], 默认值为空字典) –

    默认情况下,导出的模型将具有所有输入和输出张量的形状,确切匹配在 args 中给定的形状。如需将张量的轴指定为动态(即在运行时才知道),请将 dynamic_axes 设置为符合架构的字典

    • 键 (str):输入或输出名称。每个名称也必须提供在 input_names

      output_names 中.

    • 值 (字典或列表):如果为字典,则键为轴索引,值为轴名称。如果为

      列表,每个元素都是轴索引。

    例如

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"]
    )
    

    产生

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...
    

    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
        dynamic_axes={
            # dict value: manually named axes
            "x": {0: "my_custom_axis_name"},
            # list value: automatic names
            "sum": [0],
        }
    )
    

    产生

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
    

  • keep_initializers_as_inputs (bool, 默认值 None) –

    如果为 True,则导出的图中的所有初始化器(通常对应于参数)也将作为输入添加到图中。如果为 False,则不会将初始化器作为输入添加到图中,仅将非参数输入作为输入添加。这可能允许后端/运行时进行更好的优化(例如常量折叠)。

    如果为 True,则不会执行 deduplicate_initializers 传递。这意味着具有重复值的初始化器不会进行重复数据删除,并会被视为图中的不同输入。这允许在导出后在运行时提供不同的输入初始化器。

    如果 opset_version < 9,则初始化器必须是图输入的一部分,并且将忽略此参数,并且行为等同于将此参数设置为 True。

    如果为 None,则自动选择行为,如下所示

    • 如果 operator_export_type=OperatorExportTypes.ONNX,则行为等同于

      将此参数设置为 False。

    • 其他情况下,行为等同于将此参数设置为 True。

  • custom_opsets (dict[str, int], 默认 empty dict) –

    具有架构的字典

    • 键 (str): opset 域名称

    • 值 (int): opset 版本

    如果 model 引用了自定义 opset,但此字典中未提及,则 opset 版本设置为 1。只能通过此参数指示自定义 opset 域名称和版本。

  • export_modules_as_functions (boolset of type of nn.Module, 默认值 False) –

    启用将所有 nn.Module 转发调用作为 ONNX 中的局部函数导出的标志。或指示要作为 ONNX 中局部函数导出的模块类型。此功能需要 opset_version >= 15,否则导出会失败。这是因为 opset_version < 15 意味着 IR 版本 < 8,这意味着不支持局部函数。模块变量将作为函数属性导出。共有两类函数属性。

    1. 带注释的属性:通过 PEP 526 样式 进行类型注释的类变量将作为属性导出。ONNX 局部函数的子图内未使用带注释的属性,因为它们不是由 PyTorch JIT 跟踪创建的,但使用者可能会使用它们来确定是否使用特定融合核替换函数。

    2. 推断出的属性:模块中运算符使用的变量。属性名称将带有前缀“inferred::”。这是为了区别于从 Python 模块注释中检索到的预定义属性。推断出的属性用于 ONNX 局部函数的子图内。

    • False(默认值):将 nn.Module 转发调用导出为细粒度的节点。

    • True:将所有 nn.Module 转发调用导出为局部函数节点。

    • 类型 nn.Module 的集合:导出 nn.Module forward 调用为局部函数节点,

      仅当该 nn.Module 的类型存在于该集合中时。

  • autograd_inlining (布尔值默认为 True) - 用于控制是否内联自动梯度函数的标志。有关更多详细信息,请参阅https://github.com/pytorch/pytorch/pull/74765

  • dynamo (布尔值默认为 False) - 是否使用 Dynamo(而不是 TorchScript)导出模型。

引发
  • torch.onnx.errors.CheckerError - 如果 ONNX 检查器检测到无效的 ONNX 图形。

  • torch.onnx.errors.UnsupportedOperatorError - 如果 ONNX 图形因使用导出器不支持的操作符而无法导出。

  • torch.onnx.errors.OnnxExporterError - 导出过程中可能发生的其他错误。所有错误均为 errors.OnnxExporterError 的子类。

返回类型

可选[ONNXProgram]

torch.onnx.export_to_pretty_string(model, args, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, export_type=None, google_printer=False, opset_version=None, keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True, do_constant_folding=True, dynamic_axes=None)[源代码]

类似于 export(),但返回 ONNX 模型的文本表示。仅在下面列出参数时存在差异。所有其他参数都与 export() 相同。

参数
  • add_node_names (bool, 默认 True) – 是否设置 NodeProto.name。除非 google_printer=True,否则这一点没有区别。

  • google_printer (bool, 默认 False) – 如果为 False,将返回模型的自定义紧凑表示。如果为 True,将返回 protobuf 的 Message::DebugString(),它更详细。

返回

包含 ONNX 模型可读表示形式的 UTF-8 字符串。

torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[source]

为自定义运算符注册符号函数。

当用户为自定义/contrib 运算符注册符号时强烈建议通过 setType API 为该运算符添加形状推断,否则在一些极端情况下导出的图可能具有不正确的形状推断。setType 的示例是 test_operators.py 中的 test_aten_embedding_2

有关示例用法,请参见模块文档中的“自定义运算符”。

参数
  • symbolic_name (str) – “<domain>::<op>” 格式的自定义运算符的名称。

  • symbolic_fn (可调用对象) – 接受 ONNX 图以及当前运算符的输入参数的函数,并返回要添加到图中的新运算符节点。

  • opset_version (int) – 在其中注册的 ONNX opset 版本。

torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)[源代码]

注销 symbolic_name

有关示例用法,请参见模块文档中的“自定义运算符”。

参数
  • symbolic_name (str) – “<domain>::<op>” 格式的自定义运算符的名称。

  • opset_version (int) – 注销的 ONNX opset 版本。

torch.onnx.select_model_mode_for_export(model, mode)[源代码]

一个上下文管理器,用于暂时设置 model 的训练模式为 mode,并当离开 with-block 时将其重置。

参数
  • model – 与 export()model 参数类型和含义相同。

  • mode (TrainingMode) – 与 export()training 参数类型和含义相同。

torch.onnx.is_in_onnx_export()[源代码]

返回当前是否处于 ONNX 导出阶段。

返回类型

布尔类型

torch.onnx.enable_log()[源代码]

启用 ONNX 日志。

torch.onnx.disable_log()[源代码]

禁用 ONNX 日志记录。

torch.onnx.verification.find_mismatch(model, input_args, do_constant_folding=True, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, options=None)[source]

找出原始模型和导出模型之间的所有不匹配项。

实验性。此 API 会发生更改。

此工具有助于调试原始 PyTorch 模型和导出的 ONNX 模型之间的不匹配。它对模型图进行二分查找,以找出出现不匹配的最小子图。

参数
返回

一个包含不匹配信息 GraphInfo 对象。

返回类型

GraphInfo

实例

>>> import torch
>>> import torch.onnx.verification
>>> torch.manual_seed(0)
>>> opset_version = 15
>>> # Define a custom symbolic function for aten::relu.
>>> # The custom symbolic function is incorrect, which will result in mismatches.
>>> def incorrect_relu_symbolic_function(g, self):
...     return self
>>> torch.onnx.register_custom_op_symbolic(
...     "aten::relu",
...     incorrect_relu_symbolic_function,
...     opset_version=opset_version,
... )
>>> class Model(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.layers = torch.nn.Sequential(
...             torch.nn.Linear(3, 4),
...             torch.nn.ReLU(),
...             torch.nn.Linear(4, 5),
...             torch.nn.ReLU(),
...             torch.nn.Linear(5, 6),
...         )
...     def forward(self, x):
...         return self.layers(x)
>>> graph_info = torch.onnx.verification.find_mismatch(
...     Model(),
...     (torch.randn(2, 3),),
...     opset_version=opset_version,
... )
===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!
Mismatched elements: 12 / 12 (100.0%)
Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
==================================== Tree: =====================================
5 X   __2 X    __1 \u2713
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 \u2713
        id: 1 |  id: 10
              |
              |__2 X     __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 \u2713
                           id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}

JitScalarType

在 torch 中定义的标量类型。

torch.onnx.verification.GraphInfo

GraphInfo 包含 TorchScript 图及其转换后的 ONNX 图的验证信息。

torch.onnx.verification.VerificationOptions

ONNX 导出验证的选项。

文档

获取 PyTorch 全面的开发人员文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源