快捷方式

基于 TorchScript 的 ONNX 导出器

注意

若要使用 TorchDynamo 而不是 TorchScript 导出 ONNX 模型,请参阅 torch.onnx.dynamo_export().

示例:将 AlexNet 从 PyTorch 导出到 ONNX

这是一个简单的脚本,它将一个预训练的 AlexNet 导出到名为 alexnet.onnx 的 ONNX 文件。对 torch.onnx.export 的调用会运行模型一次以跟踪其执行情况,然后将跟踪到的模型导出到指定文件

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

生成的 alexnet.onnx 文件包含一个二进制 协议缓冲区,其中包含您导出的模型的网络结构和参数(在本例中为 AlexNet)。参数 verbose=True 会使导出器打印出模型的人类可读表示形式

# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- omitted for brevity ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # Every statement consists of some output tensors (and their types),
  # the operator to be run (with its attributes, e.g., kernels, strides,
  # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- omitted for brevity ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # Dynamic means that the shape is not known. This may be because of a
  # limitation of our implementation (which we would like to fix in a
  # future release) or shapes which are truly dynamic.
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- omitted for brevity ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
  return (%output1);
}

您还可以使用 ONNX 库验证输出,您可以使用 pip 安装该库

pip install onnx

然后,您可以运行

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

您还可以使用许多 支持 ONNX 的运行时 中的一个运行导出的模型。例如,在安装 ONNX 运行时 后,您可以加载并运行模型

import onnxruntime as ort
import numpy as np

ort_session = ort.InferenceSession("alexnet.onnx")

outputs = ort_session.run(
    None,
    {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])

这里有一个更复杂的 有关导出模型并使用 ONNX 运行时运行它的教程.

跟踪与脚本化

在内部,torch.onnx.export() 需要 torch.jit.ScriptModule 而不是 torch.nn.Module。如果传入的模型不是 ScriptModule,则 export() 将使用跟踪将其转换为一个

  • 跟踪:如果 torch.onnx.export() 被调用时传入的模型不是 ScriptModule,则它首先执行等效于 torch.jit.trace() 的操作,该操作会使用给定的 args 运行模型一次,并记录在该运行期间发生的所有操作。这意味着,如果您的模型是动态的,例如,其行为会根据输入数据而改变,则导出的模型将不会捕获此动态行为。我们建议检查导出的模型,并确保运算符看起来合理。跟踪将展开循环和 if 语句,导出与跟踪运行完全相同的静态图。如果您想要导出具有动态控制流的模型,则需要使用脚本化

  • 脚本化:通过脚本化编译模型可以保留动态控制流,并且对不同大小的输入有效。要使用脚本化

    • 使用 torch.jit.script() 生成 ScriptModule

    • 使用 ScriptModule 作为模型调用 torch.onnx.export()。仍然需要 args,但它们仅用于在内部生成示例输出,以便可以捕获输出的类型和形状。不会执行跟踪。

有关更多详细信息,包括如何组合跟踪和脚本以满足不同模型的特定要求,请参阅TorchScript 简介TorchScript

避免陷阱

避免使用 NumPy 和内置 Python 类型

PyTorch 模型可以使用 NumPy 或 Python 类型和函数编写,但在进行 跟踪 时,任何 NumPy 或 Python 类型变量(而不是 torch.Tensor)都会被转换为常量,如果这些值应该根据输入进行改变,就会产生错误的结果。

例如,不要在 numpy.ndarrays 上使用 numpy 函数

# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

在 torch.Tensors 上使用 torch 运算符

# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

并且不要使用 torch.Tensor.item()(它将 Tensor 转换为 Python 内置数字)

# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
    return x.reshape(y.item(), -1)

使用 torch 对单元素张量进行隐式转换的支持

# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
    return x.reshape(y, -1)

避免使用 Tensor.data

使用 Tensor.data 字段可能会产生错误的跟踪,从而产生错误的 ONNX 图。请改用 torch.Tensor.detach()。(正在进行的工作是 完全删除 Tensor.data)。

在跟踪模式下使用 tensor.shape 时避免使用就地操作

在跟踪模式下,从 tensor.shape 获取的形状将被跟踪为张量,并共享同一内存。这可能会导致最终输出值的差异。作为解决方法,请避免在这些情况下使用就地操作。例如,在模型中

class Model(torch.nn.Module):
  def forward(self, states):
      batch_size, seq_length = states.shape[:2]
      real_seq_length = seq_length
      real_seq_length += 2
      return real_seq_length + seq_length

real_seq_lengthseq_length 在跟踪模式下共享同一内存。可以通过重写就地操作来避免这种情况

real_seq_length = real_seq_length + 2

限制

类型

  • 仅支持 torch.Tensors、可以轻松转换为 torch.Tensors 的数字类型(例如 float、int)以及这些类型的元组和列表作为模型输入或输出。在 跟踪 模式下接受 Dict 和 str 输入和输出,但

    • 任何依赖于 dict 或 str 输入值的计算都将被替换为在一次跟踪执行期间看到的常量值。

    • 任何作为 dict 的输出将被静默替换为其值的扁平化序列(键将被删除)。例如,{"foo": 1, "bar": 2} 变为 (1, 2)

    • 任何作为 str 的输出将被静默删除。

  • 由于 ONNX 对嵌套序列的支持有限,某些涉及元组和列表的操作在 脚本 模式下不受支持。特别是将元组追加到列表中不受支持。在跟踪模式下,嵌套序列将在跟踪期间自动扁平化。

运算符实现的差异

由于运算符实现的差异,在不同运行时运行导出的模型可能会产生彼此之间或与 PyTorch 不同的结果。通常这些差异在数值上很小,因此只有在您的应用程序对这些微小差异敏感时才会造成问题。

不受支持的张量索引模式

下面列出了无法导出的张量索引模式。如果您遇到导出模型时出现问题,但该模型不包含下面任何不受支持的模式,请仔细检查您是否使用最新的 opset_version 导出。

读取/获取

从张量中索引以进行读取时,以下模式不受支持

# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.

写入/设置

从张量中索引以进行写入时,以下模式不受支持

# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
#              or multiple consecutive tensor indices with rank == 1.

# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.

# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.

# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
#   data shape: [3, 4, 5]
#   new_data shape: [5]
#   expected new_data shape after broadcasting: [2, 2, 2, 5]

添加对运算符的支持

导出包含不受支持的运算符的模型时,您会看到类似以下的错误消息

RuntimeError: ONNX export failed: Couldn't export operator foo

发生这种情况时,您可以采取以下几种措施

  1. 修改模型以不再使用该运算符。

  2. 创建一个符号函数来转换运算符,并将它注册为自定义符号函数。

  3. 为 PyTorch 贡献代码,将同一个符号函数添加到 torch.onnx 本身。

如果您决定实现符号函数(我们希望您将其贡献回 PyTorch!),以下是如何开始操作的

ONNX 导出器内部

“符号函数”是一个将 PyTorch 运算符分解为一系列 ONNX 运算符组成的函数。

在导出过程中,导出器将按拓扑顺序访问 TorchScript 图中的每个节点(包含 PyTorch 运算符)。访问节点时,导出器会为该运算符查找已注册的符号函数。符号函数是用 Python 实现的。名为 foo 的 op 的符号函数看起来像这样

def foo(
  g,
  input_0: torch._C.Value,
  input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
  """
  Adds the ONNX operations representing this PyTorch function by updating the
  graph g with `g.op()` calls.

  Args:
    g (Graph): graph to write the ONNX representation into.
    input_0 (Value): value representing the variables which contain
        the first input for this operator.
    input_1 (Value): value representing the variables which contain
        the second input for this operator.

  Returns:
    A Value or List of Values specifying the ONNX nodes that compute something
    equivalent to the original PyTorch operator with the given inputs.

    None if it cannot be converted to ONNX.
  """
  ...

torch._C 类型是围绕 ir.h 中用 C++ 定义的类型的 Python 包装器。

添加符号函数的过程取决于运算符的类型。

ATen 运算符

ATen 是 PyTorch 的内置张量库。如果运算符是 ATen 运算符(在 TorchScript 图中显示为前缀为 aten::),请确保它尚未得到支持。

支持的运算符列表

访问自动生成的 支持的 TorchScript 运算符列表,了解每个 opset_version 支持哪些运算符的详细信息。

添加对 aten 或量化运算符的支持

如果运算符不在上面的列表中

  • torch/onnx/symbolic_opset<version>.py 中定义符号函数,例如 torch/onnx/symbolic_opset9.py。确保函数的名称与 ATen 函数的名称相同,ATen 函数的名称可能在 torch/_C/_VariableFunctions.pyitorch/nn/functional.pyi 中声明(这些文件是在构建时生成的,因此在您构建 PyTorch 之前不会出现在您的检出中)。

  • 默认情况下,第一个参数是 ONNX 图。其他参数名称必须与 .pyi 文件中的名称完全匹配,因为调度是使用关键字参数完成的。

  • 在符号函数中,如果运算符位于 ONNX 标准运算符集中,我们只需要创建一个节点来在图中表示 ONNX 运算符。否则,我们可以组合几个具有与 ATen 运算符等效语义的标准运算符。

以下是处理 ELU 运算符的符号函数缺失的示例。

如果我们运行以下代码

print(
    torch.jit.trace(
        torch.nn.ELU(), # module
        torch.ones(1)   # example input
    ).graph
)

我们会看到类似以下的内容

graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
      %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %4 : float = prim::Constant[value=1.]()
  %5 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=1]()
  %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
  return (%7)

由于我们在图中看到了 aten::elu,因此我们知道这是一个 ATen 运算符。

我们检查 ONNX 运算符列表,并确认 Elu 在 ONNX 中是标准化的。

我们在 torch/nn/functional.pyi 中找到了 elu 的签名

def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

我们向 symbolic_opset9.py 添加以下几行

def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
    return g.op("Elu", input, alpha_f=alpha)

现在,PyTorch 能够导出包含 aten::elu 运算符的模型!

请参阅 torch/onnx/symbolic_opset*.py 文件以获取更多示例。

torch.autograd.Functions

如果运算符是 torch.autograd.Function 的子类,则有三种方法可以导出它。

静态符号方法

您可以向函数类添加一个名为 symbolic 的静态方法。它应该返回表示函数在 ONNX 中的行为的 ONNX 运算符。例如

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

内联 Autograd 函数

在没有为其后续 torch.autograd.Function 提供静态符号方法或没有为其提供注册 prim::PythonOp 作为自定义符号函数的函数的情况下,torch.onnx.export() 尝试内联对应于该 torch.autograd.Function 的图,以便将该函数分解为函数中使用的各个运算符。只要这些单个运算符得到支持,导出就会成功。例如

class MyLogExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        h = input.exp()
        return h.log().log()

该模型没有静态符号方法,但可以导出如下

graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
    %1 : float = onnx::Exp[](%input)
    %2 : float = onnx::Log[](%1)
    %3 : float = onnx::Log[](%2)
    return (%3)

如果您需要避免内联 torch.autograd.Function,则应使用 operator_export_type 设置为 ONNX_FALLTHROUGHONNX_ATEN_FALLBACK 来导出模型。

自定义算子

您可以导出包含多个标准 ONNX 算子组合或由自定义 C++ 后端驱动的自定义算子的模型。

ONNX-script 函数

如果一个算子不是标准 ONNX 算子,但可以由多个现有的 ONNX 算子组成,您可以利用 ONNX-script 创建一个外部 ONNX 函数来支持该算子。您可以按照以下示例进行导出

import onnxscript
# There are three opset version needed to be aligned
# This is (1) the opset version in ONNX function
from onnxscript.onnx_opset import opset15 as op
opset_version = 15

x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()

custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

@onnxscript.script(custom_opset)
def Selu(X):
    alpha = 1.67326  # auto wrapped as Constants
    gamma = 1.0507
    alphaX = op.CastLike(alpha, X)
    gammaX = op.CastLike(gamma, X)
    neg = gammaX * (alphaX * op.Exp(X) - alphaX)
    pos = gammaX * X
    zero = op.CastLike(0, X)
    return op.Where(X <= zero, neg, pos)

# setType API provides shape/type to ONNX shape/type inference
def custom_selu(g: jit_utils.GraphContext, X):
    return g.onnxscript_op(Selu, X).setType(X.type())

# Register custom symbolic function
# There are three opset version needed to be aligned
# This is (2) the opset version in registry
torch.onnx.register_custom_op_symbolic(
    symbolic_name="aten::selu",
    symbolic_fn=custom_selu,
    opset_version=opset_version,
)

# There are three opset version needed to be aligned
# This is (2) the opset version in exporter
torch.onnx.export(
    model,
    x,
    "model.onnx",
    opset_version=opset_version,
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"onnx-script": 2}
)

上面的示例将它导出为“onnx-script” opset 中的自定义算子。导出自定义算子时,您可以使用导出时的 custom_opsets 字典来指定自定义域版本。如果未指定,则自定义 opset 版本默认为 1。

注意:请注意对齐上述示例中提到的 opset 版本,并确保它们在导出步骤中被使用。有关如何编写 onnx-script 函数的示例用法是关于 onnx-script 的活跃开发的测试版。请遵循最新的 ONNX-script

C++ 算子

如果模型使用在 C++ 中实现的自定义算子,如 使用自定义 C++ 算子扩展 TorchScript 中所述,您可以按照以下示例进行导出

from torch.onnx import symbolic_helper


# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super().__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


model = FooModel(attr1, attr2)
torch.onnx.export(
    model,
    (example_input1, example_input1),
    "model.onnx",
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"custom_domain": 2}
)

上面的示例将它导出为“custom_domain” opset 中的自定义算子。导出自定义算子时,您可以使用导出时的 custom_opsets 字典来指定自定义域版本。如果未指定,则自定义 opset 版本默认为 1。

使用模型的运行时需要支持自定义 op。请参阅 Caffe2 自定义 opONNX Runtime 自定义 op 或您选择的运行时的文档。

一次发现所有无法转换的 ATen op

当导出由于无法转换的 ATen op 而失败时,实际上可能存在多个这样的 op,但错误消息只提到了第一个。要一次性发现所有无法转换的 op,您可以

# prepare model, args, opset_version
...

torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    model, args, opset_version=opset_version
)

print(set(unconvertible_ops))

该集合是近似的,因为某些 op 可能会在转换过程中被移除,因此不需要转换。其他一些 op 可能会具有部分支持,这些支持将在特定输入的情况下导致转换失败,但这应该可以让你对哪些 op 不受支持有一个大致的了解。如果您需要 op 支持请求,请随时在 GitHub 上创建问题。

常见问题解答

问:我已经导出了我的 LSTM 模型,但它的输入大小似乎是固定的?

跟踪器记录了示例输入的形状。如果模型应该接受动态形状的输入,请在调用 torch.onnx.export() 时设置 dynamic_axes

问:如何导出包含循环的模型?

请参阅 跟踪与脚本

问:如何导出具有基本类型输入的模型(例如 int、float)?

PyTorch 1.9 中添加了对基本数字类型输入的支持。但是,导出器不支持具有 str 输入的模型。

问:ONNX 支持隐式标量数据类型转换吗?

ONNX 标准不支持,但导出器会尝试处理该部分。标量被导出为常量张量。导出器将确定标量的正确数据类型。在极少数情况下,如果它无法这样做,您需要使用例如 dtype=torch.float32 手动指定数据类型。如果您看到任何错误,请 [在 GitHub 上创建问题](https://github.com/pytorch/pytorch/issues).

问:张量列表是否可以导出到 ONNX?

是的,对于 opset_version >= 11,因为 ONNX 在 opset 11 中引入了 Sequence 类型。

Python API

函数

torch.onnx.export(model, args=(), f=None, *, kwargs=None, export_params=True, verbose=None, input_names=None, output_names=None, opset_version=None, dynamic_axes=None, keep_initializers_as_inputs=False, dynamo=False, external_data=True, dynamic_shapes=None, report=False, verify=False, profile=False, dump_exported_program=False, artifacts_dir='.', fallback=False, training=<TrainingMode.EVAL: 0>, operator_export_type=<OperatorExportTypes.ONNX: 0>, do_constant_folding=True, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True, **_)[source]

将模型导出为 ONNX 格式。

参数
  • model (torch.nn.Module | torch.export.ExportedProgram | torch.jit.ScriptModule | torch.jit.ScriptFunction) – 要导出的模型。

  • args (tuple[Any, ...]) – 示例位置输入。任何非张量参数都将被硬编码到导出的模型中;任何张量参数都将成为导出的模型的输入,按照它们在元组中出现的顺序。

  • f (str | os.PathLike | None) – 输出 ONNX 模型文件的路径。例如“model.onnx”。

  • kwargs (dict[str, Any] | None) – 可选的示例关键字输入。

  • export_params (bool) – 如果为 false,则不会导出参数(权重)。

  • verbose (bool | None) – 是否启用详细日志记录。

  • input_names (Sequence[str] | None) – 要分配给图的输入节点的名称,按顺序排列。

  • output_names (Sequence[str] | None) – 要分配给图的输出节点的名称,按顺序排列。

  • opset_version (int | None) – 要目标的 默认 (ai.onnx) opset 的版本。必须 >= 7。

  • dynamic_axes (Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None) –

    默认情况下,导出的模型将具有所有输入和输出张量的形状,这些形状将被设置为完全匹配 args 中给出的形状。要将张量的轴指定为动态的(即仅在运行时才知道),请将 dynamic_axes 设置为一个具有以下模式的字典

    • 键 (str):输入或输出名称。每个名称也必须在 input_names

      output_names 中提供.

    • 值 (字典或列表):如果是一个字典,键是轴索引,值是轴名称。如果是一个

      列表,则每个元素都是一个轴索引。

    例如

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    
    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
    )
    

    生成

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...
    

    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
        dynamic_axes={
            # dict value: manually named axes
            "x": {0: "my_custom_axis_name"},
            # list value: automatic names
            "sum": [0],
        },
    )
    

    生成

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
    

  • keep_initializers_as_inputs (bool) –

    如果为 True,则导出的图中所有初始值(通常对应于模型权重)也将作为图的输入添加。如果为 False,则初始值不会作为图的输入添加,只有用户输入会作为输入添加。

    如果您打算在运行时提供模型权重,请将其设置为 True。如果权重是静态的,请将其设置为 False,以便允许后端/运行时进行更好的优化(例如常量折叠)。

  • dynamo (bool) – 是否使用 torch.export ExportedProgram 而不是 TorchScript 导出模型。

  • external_data (bool) – 是否将模型权重保存为外部数据文件。对于权重过大而超过 ONNX 文件大小限制 (2GB) 的模型,这必不可少。如果为 False,则权重将与模型架构一起保存在 ONNX 文件中。

  • dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) – 模型输入的动态形状字典。有关更多详细信息,请参阅 torch.export.export()。这仅在 dynamo 为 True 时使用(且优先使用)。一次只能设置一个参数 dynamic_axesdynamic_shapes

  • report (bool) – 是否为导出过程生成 Markdown 报告。

  • verify (bool) – 是否使用 ONNX Runtime 验证导出的模型。

  • profile (bool) – 是否分析导出过程。

  • dump_exported_program (bool) – 是否将 torch.export.ExportedProgram 导出到文件。这对调试导出器很有用。

  • artifacts_dir (str | os.PathLike) – 保存调试工件(如报告和序列化导出程序)的目录。

  • fallback (bool) – dynamo 导出器失败时是否回退到 TorchScript 导出器。

  • training (_C_onnx.TrainingMode) – 已弃用选项。改为在导出之前设置模型的训练模式。

  • operator_export_type (_C_onnx.OperatorExportTypes) – 已弃用选项。仅支持 ONNX。

  • do_constant_folding (bool) – 已弃用选项。导出的图始终被优化。

  • custom_opsets (Mapping[str, int] | None) –

    已弃用。字典

    • KEY (str): 操作集域名称

    • VALUE (int): 操作集版本

    如果 model 引用了自定义操作集,但未在此字典中提及,则操作集版本将设置为 1。仅应通过此参数指示自定义操作集域名称和版本。

  • export_modules_as_functions (bool | Collection[type[torch.nn.Module]]) –

    已弃用选项。

    标志用于启用将所有 nn.Module 正向调用导出为 ONNX 中的本地函数。或者是一组用于指示要导出为 ONNX 中的本地函数的特定模块类型。此功能需要 opset_version >= 15,否则导出将失败。这是因为 opset_version < 15 意味着 IR 版本 < 8,这意味着不支持本地函数。模块变量将导出为函数属性。函数属性有两类。

    1. 注释属性:通过 PEP 526 风格 具有类型注释的类变量将导出为属性。注释属性不会在 ONNX 本地函数的子图中使用,因为它们不是由 PyTorch JIT 跟踪创建的,但消费者可以使用它们来确定是否用特定融合内核替换函数。

    2. 推断属性:模块内部运算符使用的变量。属性名称将带有前缀“inferred::”。这是为了与从 Python 模块注释中检索到的预定义属性区分开来。推断属性在 ONNX 本地函数的子图中使用。

    • False(默认值):将 nn.Module 正向调用导出为细粒度节点。

    • True:将所有 nn.Module 正向调用导出为本地函数节点。

    • 一组 nn.Module 类型:将 nn.Module 正向调用导出为本地函数节点,

      仅当 nn.Module 的类型在该集合中找到时。

  • autograd_inlining (bool) – 已弃用。标志用于控制是否内联 autograd 函数。有关更多详细信息,请参阅 https://github.com/pytorch/pytorch/pull/74765

返回类型

Any | None

torch.onnx.export_to_pretty_string(model, args, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, export_type=None, google_printer=False, opset_version=None, keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True, do_constant_folding=True, dynamic_axes=None)[source]

Deprecated since version 2.5: 已弃用,将在未来版本中删除。请改用 onnx.printer.to_text()。

export() 类似,但返回 ONNX 模型的文本表示。仅以下列出的参数不同。所有其他参数与 export() 相同。

参数
  • add_node_names (bool, default True) – 是否设置 NodeProto.name。除非 google_printer=True,否则这没有区别。

  • google_printer (bool, default False) – 如果为 False,将返回模型的自定义、紧凑表示。如果为 True,将返回 protobuf 的 Message::DebugString(),它更加详细。

返回

包含 ONNX 模型的可读表示的 UTF-8 字符串。

torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[source]

为自定义运算符注册符号函数。

当用户为自定义/contrib 运算符注册符号时,强烈建议通过 setType API 为该运算符添加形状推断,否则导出的图在某些极端情况下可能会出现错误的形状推断。setType 的一个示例是 test_operators.py 中的 test_aten_embedding_2

请参阅模块文档中的“自定义运算符”以了解示例用法。

参数
  • symbolic_name (str) – 自定义运算符的名称,采用“<domain>::<op>” 格式。

  • symbolic_fn (Callable) – 一个函数,它接收 ONNX 图以及当前运算符的输入参数,并返回要添加到图中的新运算符节点。

  • opset_version (int) – 要注册的 ONNX 操作集版本。

torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)[source]

取消注册 symbolic_name

请参阅模块文档中的“自定义运算符”以了解示例用法。

参数
  • symbolic_name (str) – 自定义运算符的名称,采用“<domain>::<op>” 格式。

  • opset_version (int) – 要取消注册的 ONNX 操作集版本。

torch.onnx.select_model_mode_for_export(model, mode)[source]

一个上下文管理器,用于临时将 model 的训练模式设置为 mode,并在退出 with 块时将其重置。

参数
  • model – 与 export()model 参数类型和含义相同。

  • mode (TrainingMode) – 与 export()training 参数类型和含义相同。

torch.onnx.is_in_onnx_export()[source]

返回是否处于 ONNX 导出过程中。

返回类型

bool

torch.onnx.enable_log()[source]

启用 ONNX 日志记录。

torch.onnx.disable_log()[source]

禁用 ONNX 日志记录。

torch.onnx.verification.find_mismatch(model, input_args, do_constant_folding=True, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, options=None)[source]

查找原始模型和导出模型之间的所有不匹配项。

实验性。API 可能会发生变化。

此工具有助于调试原始 PyTorch 模型和导出的 ONNX 模型之间的不匹配。它对模型图进行二进制搜索,以查找表现出不匹配的最小子图。

参数
返回

包含不匹配信息的 GraphInfo 对象。

返回类型

GraphInfo

示例

>>> import torch
>>> import torch.onnx.verification
>>> torch.manual_seed(0)
>>> opset_version = 15
>>> # Define a custom symbolic function for aten::relu.
>>> # The custom symbolic function is incorrect, which will result in mismatches.
>>> def incorrect_relu_symbolic_function(g, self):
...     return self
>>> torch.onnx.register_custom_op_symbolic(
...     "aten::relu",
...     incorrect_relu_symbolic_function,
...     opset_version=opset_version,
... )
>>> class Model(torch.nn.Module):
...     def __init__(self) -> None:
...         super().__init__()
...         self.layers = torch.nn.Sequential(
...             torch.nn.Linear(3, 4),
...             torch.nn.ReLU(),
...             torch.nn.Linear(4, 5),
...             torch.nn.ReLU(),
...             torch.nn.Linear(5, 6),
...         )
...     def forward(self, x):
...         return self.layers(x)
>>> graph_info = torch.onnx.verification.find_mismatch(
...     Model(),
...     (torch.randn(2, 3),),
...     opset_version=opset_version,
... )
===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!
Mismatched elements: 12 / 12 (100.0%)
Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
==================================== Tree: =====================================
5 X   __2 X    __1 \u2713
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 \u2713
        id: 1 |  id: 10
              |
              |__2 X     __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 \u2713
                           id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}

JitScalarType

在 torch 中定义的标量类型。

verification.GraphInfo

GraphInfo 包含 TorchScript 图与其转换的 ONNX 图的验证信息。

verification.VerificationOptions

ONNX 导出验证的选项。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源