快捷方式

JitScalarType

class torch.onnx.JitScalarType(value)

torch 中定义的标量类型。

使用 JitScalarType 将 torch 和 JIT 标量类型转换为 ONNX 标量类型。

示例

>>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type()
TensorProtoDataType.FLOAT
>>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type()
TensorProtoDataType.FLOAT
>>> JitScalarType.from_dtype(torch.get_default_dtype).onnx_type()
TensorProtoDataType.FLOAT
dtype()[source]

将 JitScalarType 转换为 torch dtype。

返回类型

dtype

classmethod from_dtype(dtype)[source]

将 torch dtype 转换为 JitScalarType。

注意:当 dtype 来自 torch._C.Value.type() 调用时,请勿使用此 API。

在某些情况下,形状信息不存在时,可能会引发“RuntimeError: INTERNAL ASSERT FAILED at “../aten/src/ATen/core/jit_type_base.h”。请改用 from_value API,它更安全。

参数

dtype (torch.dtype | None) – 用于创建 JitScalarType 的 torch.dtype

返回

JitScalarType

引发

OnnxExporterError – 如果 dtype 不是有效的 torch.dtype 或如果它为 None。

返回类型

JitScalarType

classmethod from_onnx_type(onnx_type)[source]

将 ONNX 数据类型转换为 JitScalarType。

参数

onnx_type (int | _C_onnx.TensorProtoDataType | None) – 用于创建 JitScalarType 的 torch._C._onnx.TensorProtoDataType

返回

JitScalarType

引发

OnnxExporterError – 如果 dtype 不是有效的 torch.dtype 或如果它为 None。

返回类型

JitScalarType

classmethod from_value(value, default=None)[source]

从值中的标量类型创建 JitScalarType。

参数
  • value (None | torch._C.Value | torch.Tensor) – 用于获取标量类型的对象。

  • default – 如果无法从 value 中获取有效标量,则返回的 JitScalarType

返回

JitScalarType。

引发
  • OnnxExporterError – 如果 value 没有有效的标量类型并且 default 为 None。

  • SymbolicValueError – 当 value.type() 的信息为空并且 default 为 None 时

返回类型

JitScalarType

onnx_compatible()[source]

返回此 JitScalarType 是否与 ONNX 兼容。

返回类型

bool

onnx_type()[source]

将 JitScalarType 转换为 ONNX 数据类型。

返回类型

TensorProtoDataType

scalar_name()[source]

将 JitScalarType 转换为 JIT 标量类型名称。

返回类型

Literal[‘Byte’, ‘Char’, ‘Double’, ‘Float’, ‘Half’, ‘Int’, ‘Long’, ‘Short’, ‘Bool’, ‘ComplexHalf’, ‘ComplexFloat’, ‘ComplexDouble’, ‘QInt8’, ‘QUInt8’, ‘QInt32’, ‘BFloat16’, ‘Float8E5M2’, ‘Float8E4M3FN’, ‘Float8E5M2FNUZ’, ‘Float8E4M3FNUZ’, ‘Undefined’]

torch_name()[source]

将 JitScalarType 转换为 PyTorch 类型名称。

返回类型

Literal[‘bool’, ‘uint8_t’, ‘int8_t’, ‘double’, ‘float’, ‘half’, ‘int’, ‘int64_t’, ‘int16_t’, ‘complex32’, ‘complex64’, ‘complex128’, ‘qint8’, ‘quint8’, ‘qint32’, ‘bfloat16’, ‘float8_e5m2’, ‘float8_e4m3fn’, ‘float8_e5m2fnuz’, ‘float8_e4m3fnuz’]

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源