快捷方式

torch.onnx.verification

ONNX 验证模块提供了一套工具,用于验证 ONNX 模型的正确性。

torch.onnx.verification.verify_onnx_program(onnx_program, args=None, kwargs=None, compare_intermediates=False)[源码]

通过将 ONNX 模型的值与来自 ExportedProgram 的预期值进行比较来验证 ONNX 模型。

参数
  • onnx_program (_onnx_program.ONNXProgram) – 要验证的 ONNX program。

  • args (tuple[Any, ...] | None) – 模型的输入参数。

  • kwargs (dict[str, Any] | None) – 模型的关键字参数。

  • compare_intermediates (bool) – 是否验证中间值。这将花费更长时间,因此默认禁用。

返回

包含每个值的验证信息的 VerificationInfo 对象。

返回类型

list[VerificationInfo]

class torch.onnx.verification.VerificationInfo(name, max_abs_diff, max_rel_diff, abs_diff_hist, rel_diff_hist, expected_dtype, actual_dtype)

ONNX program 中某个值的验证信息。

此类包含最大绝对差值、最大相对差值以及预期值与实际值之间绝对差值和相对差值的直方图。它还包括预期和实际数据类型。

直方图表示为张量的元组,其中第一个张量是直方图计数,第二个张量是 bin 边界。

变量
  • name (str) – 值的名称(输出或中间值)。

  • max_abs_diff (float) – 预期值与实际值之间的最大绝对差值。

  • max_rel_diff (float) – 预期值与实际值之间的最大相对差值。

  • abs_diff_hist (tuple[torch.Tensor, torch.Tensor]) – 表示绝对差值直方图的张量元组。第一个张量是直方图计数,第二个张量是 bin 边界。

  • rel_diff_hist (tuple[torch.Tensor, torch.Tensor]) – 表示相对差值直方图的张量元组。第一个张量是直方图计数,第二个张量是 bin 边界。

  • expected_dtype (torch.dtype) – 预期值的数据类型。

  • actual_dtype (torch.dtype) – 实际值的数据类型。

classmethod from_tensors(name, expected, actual)[源码][源码]

从两个张量创建一个 VerificationInfo 对象。

参数
返回

VerificationInfo 对象。

返回类型

VerificationInfo

torch.onnx.verification.verify(model, input_args, input_kwargs=None, do_constant_folding=True, dynamic_axes=None, input_names=None, output_names=None, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, fixed_batch_size=False, use_external_data=False, additional_test_inputs=None, options=None)[源码][源码]

验证模型导出到 ONNX 是否与原始 PyTorch 模型一致。

自版本 2.7 起已弃用: 考虑使用 torch.onnx.export(..., dynamo=True) 并使用返回的 ONNXProgram 来测试 ONNX 模型。

参数
触发
  • AssertionError – 如果 ONNX 模型和 PyTorch 模型的输出在指定精度内不相等。

  • ValueError – 如果提供的参数无效。

已弃用

以下类和函数已弃用。

class torch.onnx.verification.check_export_model_diff[源码][源码]
class torch.onnx.verification.GraphInfo[源码][源码]
class torch.onnx.verification.GraphInfoPrettyPrinter[源码][源码]
class torch.onnx.verification.OnnxBackend[源码][源码]
class torch.onnx.verification.OnnxTestCaseRepro[源码][源码]
class torch.onnx.verification.VerificationOptions[源码][源码]
torch.onnx.verification.find_mismatch()[源码][源码]
torch.onnx.verification.verify_aten_graph()[源码][源码]

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源