torch.set_default_dtype¶
- torch.set_default_dtype(d, /)[源代码][源代码]¶
将默认浮点 dtype 设置为
d
。支持浮点 dtype 作为输入。其他 dtype 将导致 torch 引发异常。当 PyTorch 初始化时,其默认浮点 dtype 为 torch.float32,而 set_default_dtype(torch.float64) 的目的是为了方便类似 NumPy 的类型推断。默认浮点 dtype 用于
隐式确定默认复数 dtype。当默认浮点类型为 float16 时,默认复数 dtype 为 complex32。对于 float32,默认复数 dtype 为 complex64。对于 float64,则为 complex128。对于 bfloat16,由于 bfloat16 没有对应的复数类型,因此会引发异常。
推断使用 Python 浮点数或复数 Python 数字构造的张量的 dtype。请参阅以下示例。
确定布尔张量和整数张量与 Python 浮点数和复数 Python 数字之间类型提升的结果。
- 参数
d (
torch.dtype
) – 要设置为默认值的浮点 dtype。
示例
>>> # initial default for floating point is torch.float32 >>> # Python floats are interpreted as float32 >>> torch.tensor([1.2, 3]).dtype torch.float32 >>> # initial default for floating point is torch.complex64 >>> # Complex Python numbers are interpreted as complex64 >>> torch.tensor([1.2, 3j]).dtype torch.complex64
>>> torch.set_default_dtype(torch.float64) >>> # Python floats are now interpreted as float64 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor torch.float64 >>> # Complex Python numbers are now interpreted as complex128 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor torch.complex128
>>> torch.set_default_dtype(torch.float16) >>> # Python floats are now interpreted as float16 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor torch.float16 >>> # Complex Python numbers are now interpreted as complex128 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor torch.complex32