torch.onnx¶
概述¶
开放神经网络交换 (ONNX) 是一种用于表示机器学习模型的开放标准格式。 torch.onnx
模块捕获来自原生 PyTorch torch.nn.Module
模型的计算图,并将其转换为 ONNX 图。
导出的模型可以被任何支持 ONNX 的众多 运行时 使用,包括微软的 ONNX Runtime。
您可以使用两种 ONNX 导出器 API,如下所示
基于 TorchDynamo 的 ONNX 导出器¶
基于 TorchDynamo 的 ONNX 导出器是 PyTorch 2.0 及更高版本最新的(也是 Beta 版)导出器
TorchDynamo 引擎被用来连接到 Python 的帧评估 API,并动态地将其字节码重写为 FX 图。生成的 FX 图将在最终转换为 ONNX 图之前进行优化。
这种方法的主要优势在于,FX 图 是使用字节码分析捕获的,它保留了模型的动态特性,而不是使用传统的静态跟踪技术。
基于 TorchScript 的 ONNX 导出器¶
基于 TorchScript 的 ONNX 导出器从 PyTorch 1.2.0 开始可用
TorchScript 被用来跟踪(通过 torch.jit.trace()
)模型并捕获一个静态计算图。
因此,生成的图有一些限制
它不记录任何控制流,例如 if 语句或循环;
不处理
training
和eval
模式之间的细微差别;不真正处理动态输入
为了尝试支持静态跟踪的限制,导出器还支持 TorchScript 脚本(通过 torch.jit.script()
),这增加了对数据依赖控制流的支持,例如。但是,TorchScript 本身是 Python 语言的一个子集,因此并非 Python 中的所有功能都受支持,例如就地操作。