基于 TorchScript 的 ONNX 导出器¶
注意
要使用 TorchDynamo 而不是 TorchScript 导出 ONNX 模型,请参见 torch.onnx.dynamo_export()
.
示例:将 PyTorch 中的 AlexNet 转换为 ONNX¶
这是一个简单的脚本,它将预训练的 AlexNet 导出到名为 alexnet.onnx
的 ONNX 文件中。对 torch.onnx.export
的调用会运行一次模型以追踪其执行,然后将追踪的模型导出到指定的文件
import torch
import torchvision
dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()
# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
生成的 alexnet.onnx
文件包含一个二进制 协议缓冲区,其中包含您导出的模型的网络结构和参数(在本例中为 AlexNet)。参数 verbose=True
会导致导出器打印出模型的人类可读表示
# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
%learned_0 : Float(64, 3, 11, 11)
%learned_1 : Float(64)
%learned_2 : Float(192, 64, 5, 5)
%learned_3 : Float(192)
# ---- omitted for brevity ----
%learned_14 : Float(1000, 4096)
%learned_15 : Float(1000)) {
# Every statement consists of some output tensors (and their types),
# the operator to be run (with its attributes, e.g., kernels, strides,
# etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
%17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
%18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
%19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
# ---- omitted for brevity ----
%29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
# Dynamic means that the shape is not known. This may be because of a
# limitation of our implementation (which we would like to fix in a
# future release) or shapes which are truly dynamic.
%30 : Dynamic = onnx::Shape(%29), scope: AlexNet
%31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
%32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
%33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
# ---- omitted for brevity ----
%output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
return (%output1);
}
您还可以使用 ONNX 库验证输出,您可以使用 pip
安装该库
pip install onnx
然后,您可以运行
import onnx
# Load the ONNX model
model = onnx.load("alexnet.onnx")
# Check that the model is well formed
onnx.checker.check_model(model)
# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
您也可以使用众多支持 ONNX 的运行时来运行导出的模型。例如,在安装ONNX Runtime后,您可以加载并运行模型。
import onnxruntime as ort
import numpy as np
ort_session = ort.InferenceSession("alexnet.onnx")
outputs = ort_session.run(
None,
{"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])
这里有一个更详细的关于导出模型并使用 ONNX Runtime 运行它的教程。
跟踪与脚本化¶
在内部,torch.onnx.export()
需要一个torch.jit.ScriptModule
而不是一个torch.nn.Module
。如果传入的模型不是 ScriptModule
,export()
将使用跟踪将其转换为一个。
跟踪:如果
torch.onnx.export()
被调用时,模型不是一个ScriptModule
,它首先会执行等效于torch.jit.trace()
的操作,该操作会使用给定的args
执行模型一次,并记录执行过程中发生的所有操作。这意味着,如果您的模型是动态的,例如,根据输入数据改变行为,导出的模型不会捕获这种动态行为。我们建议检查导出的模型,并确保操作符看起来合理。跟踪会展开循环和 if 语句,导出一个与跟踪运行完全相同的静态图。如果您想导出具有动态控制流的模型,则需要使用脚本化。脚本化:通过脚本化编译模型可以保留动态控制流,并且对不同大小的输入有效。要使用脚本化
使用
torch.jit.script()
生成一个ScriptModule
。使用
ScriptModule
作为模型调用torch.onnx.export()
。仍然需要args
,但它们只会在内部使用以生成示例输出,以便可以捕获输出的类型和形状。不会执行跟踪。
有关更多详细信息,包括如何组合跟踪和脚本以满足不同模型的特定要求,请参阅TorchScript 简介和TorchScript。
避免陷阱¶
避免使用 NumPy 和内置 Python 类型¶
PyTorch 模型可以使用 NumPy 或 Python 类型和函数编写,但在跟踪期间,任何 NumPy 或 Python 类型(而不是 torch.Tensor)的变量都会被转换为常量,如果这些值应该根据输入而改变,则会产生错误的结果。
例如,不要在 numpy.ndarrays 上使用 numpy 函数
# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)
在 torch.Tensors 上使用 torch 运算符
# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)
不要使用torch.Tensor.item()
(它将 Tensor 转换为 Python 内置数字)
# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
return x.reshape(y.item(), -1)
使用 torch 对单元素张量的隐式转换的支持
# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
return x.reshape(y, -1)
避免使用 Tensor.data¶
使用 Tensor.data 字段可能会产生不正确的跟踪,从而导致不正确的 ONNX 图。请改用torch.Tensor.detach()
。(正在进行的工作是完全删除 Tensor.data)。
在跟踪模式下使用 tensor.shape 时避免就地操作¶
在跟踪模式下,从 tensor.shape
获取的形状被跟踪为张量,并共享相同的内存。这可能会导致最终输出值不匹配。作为解决方法,请避免在这些情况下使用就地操作。例如,在模型中
class Model(torch.nn.Module):
def forward(self, states):
batch_size, seq_length = states.shape[:2]
real_seq_length = seq_length
real_seq_length += 2
return real_seq_length + seq_length
real_seq_length
和 seq_length
在跟踪模式下共享相同的内存。可以通过重写就地操作来避免这种情况
real_seq_length = real_seq_length + 2
限制¶
类型¶
仅支持
torch.Tensors
、可以轻松转换为 torch.Tensors 的数字类型(例如 float、int)以及这些类型的元组和列表作为模型输入或输出。字典和字符串输入和输出在跟踪模式下被接受,但是任何依赖于字典或字符串输入值的计算**将被替换为**在一次跟踪执行期间看到的常数值。
任何字典类型的输出将被静默地替换为**其值的扁平化序列(键将被移除)**。例如,
{"foo": 1, "bar": 2}
将变为(1, 2)
。任何字符串类型的输出将被静默地移除。
由于 ONNX 对嵌套序列的支持有限,在脚本模式下,某些涉及元组和列表的操作不受支持。特别是将元组追加到列表的操作不受支持。在跟踪模式下,嵌套序列将在跟踪期间自动扁平化。
操作符实现差异¶
由于操作符实现的差异,在不同运行时上运行导出的模型可能会产生与彼此或与 PyTorch 不同的结果。通常这些差异在数值上很小,因此只有当您的应用程序对这些微小差异敏感时,这才是需要关注的问题。
不支持的张量索引模式¶
无法导出的张量索引模式列在下面。如果您在导出不包含以下任何不支持模式的模型时遇到问题,请仔细检查您是否使用最新的 opset_version
进行导出。
读取/获取¶
当索引到张量中进行读取时,以下模式不受支持
# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.
写入/设置¶
当索引到张量中进行写入时,以下模式不受支持
# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
# or multiple consecutive tensor indices with rank == 1.
# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.
# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.
# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
# data shape: [3, 4, 5]
# new_data shape: [5]
# expected new_data shape after broadcasting: [2, 2, 2, 5]
添加对操作符的支持¶
当导出包含不支持操作符的模型时,您将看到类似以下的错误消息
RuntimeError: ONNX export failed: Couldn't export operator foo
当这种情况发生时,您可以做几件事
更改模型以不使用该操作符。
创建一个符号函数来转换操作符,并将其注册为自定义符号函数。
为 PyTorch 贡献代码,将相同的符号函数添加到
torch.onnx
本身。
如果您决定实现一个符号函数(我们希望您能将其贡献回 PyTorch!),以下是如何开始。
ONNX 导出器内部¶
“符号函数”是一个将 PyTorch 运算符分解为一系列 ONNX 运算符组合的函数。
在导出期间,导出器会以拓扑顺序访问 TorchScript 图中的每个节点(包含 PyTorch 运算符)。访问节点时,导出器会为该运算符查找已注册的符号函数。符号函数是用 Python 实现的。名为 foo
的运算符的符号函数看起来像这样
def foo(
g,
input_0: torch._C.Value,
input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
"""
Adds the ONNX operations representing this PyTorch function by updating the
graph g with `g.op()` calls.
Args:
g (Graph): graph to write the ONNX representation into.
input_0 (Value): value representing the variables which contain
the first input for this operator.
input_1 (Value): value representing the variables which contain
the second input for this operator.
Returns:
A Value or List of Values specifying the ONNX nodes that compute something
equivalent to the original PyTorch operator with the given inputs.
None if it cannot be converted to ONNX.
"""
...
torch._C
类型是围绕 ir.h 中用 C++ 定义的类型的 Python 包装器。
添加符号函数的过程取决于运算符的类型。
ATen 运算符¶
ATen 是 PyTorch 的内置张量库。如果运算符是 ATen 运算符(在 TorchScript 图中以 aten::
为前缀显示),请确保它尚未得到支持。
支持的运算符列表¶
访问自动生成的 支持的 TorchScript 运算符列表,了解每个 opset_version
中支持哪些运算符的详细信息。
添加对 aten 或量化运算符的支持¶
如果操作符不在上面的列表中
在
torch/onnx/symbolic_opset<version>.py
中定义符号函数,例如 torch/onnx/symbolic_opset9.py。确保函数名称与 ATen 函数相同,该函数可能在torch/_C/_VariableFunctions.pyi
或torch/nn/functional.pyi
中声明(这些文件在构建时生成,因此在构建 PyTorch 之前不会出现在您的检出中)。默认情况下,第一个参数是 ONNX 图。其他参数名称必须与
.pyi
文件中的名称完全匹配,因为调度是使用关键字参数完成的。在符号函数中,如果操作符在 ONNX 标准操作符集中,我们只需要创建一个节点来表示图中的 ONNX 操作符。如果不是,我们可以组合几个具有等效语义的标准操作符来表示 ATen 操作符。
以下是如何处理 ELU
操作符缺少符号函数的示例。
如果我们运行以下代码
print(
torch.jit.trace(
torch.nn.ELU(), # module
torch.ones(1) # example input
).graph
)
我们会看到类似的东西
graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%4 : float = prim::Constant[value=1.]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=1]()
%7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
return (%7)
由于我们在图中看到了 aten::elu
,我们知道这是一个 ATen 操作符。
我们检查 ONNX 操作符列表,并确认 Elu
在 ONNX 中是标准化的。
我们在 torch/nn/functional.pyi
中找到了 elu
的签名
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
我们在 symbolic_opset9.py
中添加以下几行
def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
return g.op("Elu", input, alpha_f=alpha)
现在 PyTorch 能够导出包含 aten::elu
操作符的模型!
有关更多示例,请参阅 torch/onnx/symbolic_opset*.py
文件。
torch.autograd.Functions¶
如果操作符是 torch.autograd.Function
的子类,则有三种方法可以导出它。
静态符号方法¶
您可以在函数类中添加一个名为 symbolic
的静态方法。它应该返回表示函数行为的 ONNX 运算符。例如
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))
内联自动微分函数¶
在没有为其后续 torch.autograd.Function
提供静态符号方法,或没有提供用于注册 prim::PythonOp
作为自定义符号函数的函数的情况下,torch.onnx.export()
尝试内联对应于该 torch.autograd.Function
的图,以便将此函数分解为函数中使用的单个运算符。只要这些单个运算符受支持,导出应该成功。例如
class MyLogExp(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
h = input.exp()
return h.log().log()
此模型没有静态符号方法,但它按如下方式导出
graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
%1 : float = onnx::Exp[](%input)
%2 : float = onnx::Log[](%1)
%3 : float = onnx::Log[](%2)
return (%3)
如果您需要避免内联 torch.autograd.Function
,您应该使用 operator_export_type
设置为 ONNX_FALLTHROUGH
或 ONNX_ATEN_FALLBACK
来导出模型。
自定义运算符¶
您可以使用自定义运算符导出您的模型,这些运算符包括许多标准 ONNX 运算符的组合,或者由自定义的 C++ 后端驱动。
ONNX-script 函数¶
如果一个运算符不是标准的 ONNX 运算符,但可以由多个现有的 ONNX 运算符组成,您可以利用 ONNX-script 创建一个外部 ONNX 函数来支持该运算符。您可以按照此示例进行导出
import onnxscript
# There are three opset version needed to be aligned
# This is (1) the opset version in ONNX function
from onnxscript.onnx_opset import opset15 as op
opset_version = 15
x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()
custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
@onnxscript.script(custom_opset)
def Selu(X):
alpha = 1.67326 # auto wrapped as Constants
gamma = 1.0507
alphaX = op.CastLike(alpha, X)
gammaX = op.CastLike(gamma, X)
neg = gammaX * (alphaX * op.Exp(X) - alphaX)
pos = gammaX * X
zero = op.CastLike(0, X)
return op.Where(X <= zero, neg, pos)
# setType API provides shape/type to ONNX shape/type inference
def custom_selu(g: jit_utils.GraphContext, X):
return g.onnxscript_op(Selu, X).setType(X.type())
# Register custom symbolic function
# There are three opset version needed to be aligned
# This is (2) the opset version in registry
torch.onnx.register_custom_op_symbolic(
symbolic_name="aten::selu",
symbolic_fn=custom_selu,
opset_version=opset_version,
)
# There are three opset version needed to be aligned
# This is (2) the opset version in exporter
torch.onnx.export(
model,
x,
"model.onnx",
opset_version=opset_version,
# only needed if you want to specify an opset version > 1.
custom_opsets={"onnx-script": 2}
)
上面的示例将其导出为“onnx-script” opset 中的自定义运算符。在导出自定义运算符时,您可以使用导出时的 custom_opsets
字典指定自定义域版本。如果未指定,自定义 opset 版本默认为 1。
注意:请注意上面示例中提到的 opset 版本,并确保它们在导出步骤中被使用。关于如何编写 onnx-script 函数的示例用法是 onnx-script 积极开发中的 beta 版本。请遵循最新的 ONNX-script
C++ 运算符¶
如果模型使用在 C++ 中实现的自定义运算符,如 使用自定义 C++ 运算符扩展 TorchScript 中所述,您可以按照以下示例导出它
from torch.onnx import symbolic_helper
# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)
# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)
class FooModel(torch.nn.Module):
def __init__(self, attr1, attr2):
super().__init__()
self.attr1 = attr1
self.attr2 = attr2
def forward(self, input1, input2):
# Calling custom op
return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)
model = FooModel(attr1, attr2)
torch.onnx.export(
model,
(example_input1, example_input1),
"model.onnx",
# only needed if you want to specify an opset version > 1.
custom_opsets={"custom_domain": 2}
)
上面的示例将其导出为“custom_domain” opset 中的自定义运算符。导出自定义运算符时,可以使用导出时的 custom_opsets
字典指定自定义域版本。如果未指定,自定义 opset 版本默认为 1。
使用模型的运行时需要支持自定义运算符。请参阅 Caffe2 自定义运算符、ONNX Runtime 自定义运算符 或您选择的运行时的文档。
一次性发现所有不可转换的 ATen 运算符¶
如果导出失败是因为不可转换的 ATen 运算符,实际上可能存在多个这样的运算符,但错误消息只提到了第一个。要一次性发现所有不可转换的运算符,您可以
# prepare model, args, opset_version
...
torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
model, args, opset_version=opset_version
)
print(set(unconvertible_ops))
该集合是近似的,因为某些运算符可能会在转换过程中被移除,不需要转换。其他一些运算符可能具有部分支持,在特定输入下会失败转换,但这应该可以让你大致了解哪些运算符不受支持。如果您需要支持运算符,请随时在 GitHub 上提交问题。
常见问题解答¶
问:我已经导出了我的 LSTM 模型,但它的输入大小似乎是固定的?
跟踪器记录了示例输入的形状。如果模型应该接受动态形状的输入,请在调用
torch.onnx.export()
时设置dynamic_axes
。
Q: 如何导出包含循环的模型?
参见 跟踪与脚本.
Q: 如何导出具有原始类型输入(例如 int、float)的模型?
PyTorch 1.9 中添加了对原始数字类型输入的支持。但是,导出器不支持具有 str 输入的模型。
Q: ONNX 是否支持隐式标量数据类型转换?
ONNX 标准不支持,但导出器将尝试处理该部分。标量将导出为常量张量。导出器将确定标量的正确数据类型。在极少数情况下,如果它无法做到这一点,您需要使用例如 dtype=torch.float32 手动指定数据类型。如果您看到任何错误,请 [创建 GitHub 问题](https://github.com/pytorch/pytorch/issues).
Q: 张量列表是否可以导出到 ONNX?
是的,对于
opset_version
>= 11,因为 ONNX 在 opset 11 中引入了 Sequence 类型。
Python API¶
函数¶
- torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True)[source]¶
将模型导出到 ONNX 格式。
如果
model
不是torch.jit.ScriptModule
也不 是torch.jit.ScriptFunction
,这将运行model
一次以将其转换为要导出的 TorchScript 图(等效于torch.jit.trace()
)。因此,这与torch.jit.trace()
具有相同的有限动态控制流支持。- 参数
model (
torch.nn.Module
,torch.jit.ScriptModule
或torch.jit.ScriptFunction
) – 要导出的模型。args (tuple 或 torch.Tensor) –
args 可以结构化为
仅参数元组
args = (x, y, z)
元组应包含模型输入,以便
model(*args)
是模型的有效调用。任何非张量参数将被硬编码到导出的模型中;任何张量参数都将成为导出模型的输入,其顺序与它们在元组中的顺序相同。张量
args = torch.Tensor([1])
这等效于该张量的 1 元元组。
以命名参数字典结尾的参数元组
args = ( x, { "y": input_y, "z": input_z } )
元组中除最后一个元素之外的所有元素都将作为非关键字参数传递,并且命名参数将从最后一个元素设置。如果字典中不存在命名参数,则将其分配默认值,如果未提供默认值,则分配 None。
注意
如果字典是 args 元组的最后一个元素,它将被解释为包含命名参数。为了将字典作为最后一个非关键字参数传递,请在 args 元组的最后一个元素中提供一个空字典。例如,不要写
torch.onnx.export( model, ( x, # WRONG: will be interpreted as named arguments {y: z} ), "test.onnx.pb" )
写
torch.onnx.export( model, ( x, {y: z}, {} ), "test.onnx.pb" )
f (Union[str, BytesIO]) – 类文件对象(使得
f.fileno()
返回文件描述符)或包含文件名字符串。二进制协议缓冲区将写入此文件。export_params (bool, 默认值为 True) – 如果为 True,所有参数都将被导出。如果要导出未训练的模型,请将其设置为 False。在这种情况下,导出的模型将首先将所有参数作为参数,其顺序由
model.state_dict().values()
指定verbose (bool, 默认值为 False) – 如果为 True,则将模型的描述打印到标准输出。此外,最终的 ONNX 图将包含来自导出模型的字段
doc_string`
,其中提到了model
的源代码位置。如果为 True,ONNX 导出器日志记录将被打开。training (枚举, 默认值为 TrainingMode.EVAL) –
TrainingMode.EVAL
: 以推理模式导出模型。TrainingMode.PRESERVE
: 如果 model.training 为 False,则以推理模式导出模型;如果 model.training 为 True,则以训练模式导出模型。False,如果 model.training 为 True,则以训练模式导出模型。
TrainingMode.TRAINING
: 以训练模式导出模型。禁用可能干扰训练的优化可能会干扰训练。
operator_export_type (枚举, 默认值为 OperatorExportTypes.ONNX) –
OperatorExportTypes.ONNX
: 将所有操作导出为常规 ONNX 操作(在默认的 opset 域中)。
OperatorExportTypes.ONNX_FALLTHROUGH
: 尝试将所有操作符转换为默认 opset 域中的标准 ONNX 操作符。如果无法做到这一点(例如,因为尚未添加支持将特定 torch 操作符转换为 ONNX),则回退到将操作符导出到自定义 opset 域而无需转换。适用于 自定义操作符 以及 ATen 操作符。为了使导出的模型可用,运行时必须支持这些非标准操作符。
OperatorExportTypes.ONNX_ATEN
: 所有 ATen 操作符(在 TorchScript 命名空间“aten”中)都导出为 ATen 操作符(在 opset 域“org.pytorch.aten”中)。ATen 是 PyTorch 的内置张量库,因此这指示运行时使用 PyTorch 的这些操作符的实现。
警告
以这种方式导出的模型可能只能由 Caffe2 运行。
如果操作符实现中的数值差异导致 PyTorch 和 Caffe2 之间出现较大行为差异(这在未训练的模型中更为常见),这可能会有用。
OperatorExportTypes.ONNX_ATEN_FALLBACK
: 尝试将每个 ATen 操作符(在 TorchScript 命名空间“aten”中)导出为常规 ONNX 操作符。如果我们无法做到这一点(例如,因为尚未添加支持将特定 torch 操作符转换为 ONNX),则回退到导出 ATen 操作符。有关上下文,请参阅 OperatorExportTypes.ONNX_ATEN 上的文档。例如
graph(%0 : Float): %3 : int = prim::Constant[value=0]() # conversion unsupported %4 : Float = aten::triu(%0, %3) # conversion supported %5 : Float = aten::mul(%4, %0) return (%5)
假设
aten::triu
在 ONNX 中不受支持,这将导出为graph(%0 : Float): %1 : Long() = onnx::Constant[value={0}]() # not converted %2 : Float = aten::ATen[operator="triu"](%0, %1) # converted %3 : Float = onnx::Mul(%2, %0) return (%3)
如果 PyTorch 是使用 Caffe2 构建的(即使用
BUILD_CAFFE2=1
),那么将启用 Caffe2 特定的行为,包括对 量化 中描述的模块产生的操作符的特殊支持。警告
以这种方式导出的模型可能只能由 Caffe2 运行。
opset_version (int, default 17) – 要针对的 默认 (ai.onnx) opset 的版本。必须 >= 7 且 <= 17。
do_constant_folding (布尔值, 默认值为 True) – 应用常量折叠优化。常量折叠将用预先计算的常量节点替换一些所有输入都是常量的操作。
dynamic_axes (字典[字符串, 字典[整数, 字符串]] or 字典[字符串, 列表(整数)], 默认值为空字典) –
默认情况下,导出的模型将具有所有输入和输出张量的形状,这些形状设置为与
args
中给出的形状完全匹配。要将张量的轴指定为动态(即仅在运行时才知道),请将dynamic_axes
设置为具有以下模式的字典- 键 (str):输入或输出名称。每个名称也必须在
input_names
中提供,或者 output_names
.
- 键 (str):输入或输出名称。每个名称也必须在
- 值 (字典或列表):如果为字典,则键为轴索引,值为轴名称。如果为
列表,则每个元素都是一个轴索引。
例如
class SumModule(torch.nn.Module): def forward(self, x): return torch.sum(x, dim=1) torch.onnx.export( SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"] )
产生
input { name: "x" ... shape { dim { dim_value: 2 # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_value: 2 # axis 0 ...
而
torch.onnx.export( SumModule(), (torch.ones(2, 2),), "onnx.pb", input_names=["x"], output_names=["sum"], dynamic_axes={ # dict value: manually named axes "x": {0: "my_custom_axis_name"}, # list value: automatic names "sum": [0], } )
产生
input { name: "x" ... shape { dim { dim_param: "my_custom_axis_name" # axis 0 } dim { dim_value: 2 # axis 1 ... output { name: "sum" ... shape { dim { dim_param: "sum_dynamic_axes_1" # axis 0 ...
keep_initializers_as_inputs (布尔值, 默认值为 None) –
如果为 True,则导出图中的所有初始化器(通常对应于参数)也将作为图的输入添加。如果为 False,则初始化器不会作为图的输入添加,只有非参数输入会被添加为输入。这可能允许后端/运行时进行更好的优化(例如常量折叠)。
如果为 True,则不会执行 deduplicate_initializers 传递。这意味着具有重复值的初始化器不会被去重,并将被视为图的独立输入。这允许在导出后在运行时提供不同的输入初始化器。
如果
opset_version < 9
,则初始化器必须是图输入的一部分,并且此参数将被忽略,其行为等同于将此参数设置为 True。如果为 None,则行为将根据以下方式自动选择
- 如果
operator_export_type=OperatorExportTypes.ONNX
,则行为等同于 将此参数设置为 False。
- 如果
否则,行为等同于将此参数设置为 True。
custom_opsets (dict[str, int], 默认空字典) –
具有以下模式的字典
KEY (str): opset 域名称
VALUE (int): opset 版本
如果
model
引用了自定义 opset 但未在此字典中提及,则 opset 版本将设置为 1。仅应通过此参数指示自定义 opset 域名称和版本。export_modules_as_functions (bool 或 set of type of nn.Module, 默认 False) –
标志,用于启用将所有
nn.Module
正向调用作为本地函数导出到 ONNX 中。或者一个集合,用于指示要作为本地函数导出到 ONNX 中的特定类型的模块。此功能需要opset_version
>= 15,否则导出将失败。这是因为opset_version
< 15 意味着 IR 版本 < 8,这意味着不支持本地函数。模块变量将作为函数属性导出。函数属性有两类。1. 注释属性:通过 PEP 526 风格 具有类型注释的类变量将被导出为属性。注释属性不会在 ONNX 局部函数的子图中使用,因为它们不是由 PyTorch JIT 跟踪创建的,但它们可能会被使用者用来确定是否用特定融合内核替换函数。
2. 推断属性:模块内部运算符使用的变量。属性名称将以“inferred::”为前缀。这是为了与从 Python 模块注释中检索到的预定义属性区分开来。推断属性在 ONNX 局部函数的子图中使用。
False
(默认):将nn.Module
前向调用导出为细粒度节点。True
:将所有nn.Module
前向调用导出为局部函数节点。- 一组 nn.Module 类型:将
nn.Module
前向调用导出为局部函数节点, 仅当
nn.Module
的类型在该集合中找到时。
- 一组 nn.Module 类型:将
autograd_inlining (bool, 默认 True) – 用于控制是否内联 autograd 函数的标志。有关更多详细信息,请参阅 https://github.com/pytorch/pytorch/pull/74765。
- 引发
torch.onnx.errors.CheckerError – 如果 ONNX 检查器检测到无效的 ONNX 图。
torch.onnx.errors.UnsupportedOperatorError – 如果 ONNX 图无法导出,因为它使用了导出器不支持的运算符。
torch.onnx.errors.OnnxExporterError – 导出过程中可能发生的其它错误。所有错误都是
errors.OnnxExporterError
的子类。
- torch.onnx.export_to_pretty_string(model, args, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, export_type=None, google_printer=False, opset_version=None, keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True, do_constant_folding=True, dynamic_axes=None)[source]¶
与
export()
相似,但返回 ONNX 模型的文本表示形式。仅以下列出的参数存在差异。所有其他参数与export()
相同。
- torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[source]¶
为自定义算子注册符号函数。
当用户为自定义/contrib 算子注册符号函数时,强烈建议通过 setType API 为该算子添加形状推断,否则导出的图在某些极端情况下可能会出现错误的形状推断。setType 的示例是 test_operators.py 中的 test_aten_embedding_2。
有关示例用法,请参阅模块文档中的“自定义算子”。
- torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)[source]¶
取消注册
symbolic_name
。有关示例用法,请参阅模块文档中的“自定义算子”。
- torch.onnx.select_model_mode_for_export(model, mode)[source]¶
一个上下文管理器,用于将
model
的训练模式临时设置为mode
,并在退出 with 块时将其重置。
- torch.onnx.verification.find_mismatch(model, input_args, do_constant_folding=True, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, options=None)[source]¶
查找原始模型和导出模型之间的所有不匹配项。
实验性。API 可能会发生变化。
此工具有助于调试原始 PyTorch 模型和导出 ONNX 模型之间的不匹配。它对模型图进行二进制搜索,以找到表现出不匹配的最小子图。
- 参数
model (Union[Module, ScriptModule]) – 要导出的模型。
do_constant_folding (bool) – 与
torch.onnx.export()
中的 do_constant_folding 相同。training (TrainingMode) – 与
torch.onnx.export()
中的 training 相同。opset_version (Optional[int]) – 与
torch.onnx.export()
中的 opset_version 相同。keep_initializers_as_inputs (bool) – 与
torch.onnx.export()
中的 keep_initializers_as_inputs 相同。verbose (bool) – 与
torch.onnx.export()
中的 verbose 相同。options (Optional[VerificationOptions]) – 不匹配验证的选项。
- 返回值
包含不匹配信息的 GraphInfo 对象。
- 返回类型
示例
>>> import torch >>> import torch.onnx.verification >>> torch.manual_seed(0) >>> opset_version = 15 >>> # Define a custom symbolic function for aten::relu. >>> # The custom symbolic function is incorrect, which will result in mismatches. >>> def incorrect_relu_symbolic_function(g, self): ... return self >>> torch.onnx.register_custom_op_symbolic( ... "aten::relu", ... incorrect_relu_symbolic_function, ... opset_version=opset_version, ... ) >>> class Model(torch.nn.Module): ... def __init__(self): ... super().__init__() ... self.layers = torch.nn.Sequential( ... torch.nn.Linear(3, 4), ... torch.nn.ReLU(), ... torch.nn.Linear(4, 5), ... torch.nn.ReLU(), ... torch.nn.Linear(5, 6), ... ) ... def forward(self, x): ... return self.layers(x) >>> graph_info = torch.onnx.verification.find_mismatch( ... Model(), ... (torch.randn(2, 3),), ... opset_version=opset_version, ... ) ===================== Mismatch info for graph partition : ====================== ================================ Mismatch error ================================ Tensor-likes are not close! Mismatched elements: 12 / 12 (100.0%) Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) ==================================== Tree: ===================================== 5 X __2 X __1 ✓ id: | id: 0 | id: 00 | | | |__1 X (aten::relu) | id: 01 | |__3 X __1 ✓ id: 1 | id: 10 | |__2 X __1 X (aten::relu) id: 11 | id: 110 | |__1 ✓ id: 111 =========================== Mismatch leaf subgraphs: =========================== ['01', '110'] ============================= Mismatch node kinds: ============================= {'aten::relu': 2}