torch.set_float32_matmul_precision¶
- torch.set_float32_matmul_precision(precision)[source][source]¶
设置 float32 矩阵乘法的内部精度。
以较低精度运行 float32 矩阵乘法可以显著提高性能,并且在某些程序中,精度损失的影响可以忽略不计。
支持三种设置
“最高”,float32 矩阵乘法内部计算使用 float32 数据类型(24 位尾数,显式存储 23 位)。
“高”,float32 矩阵乘法要么使用 TensorFloat32 数据类型(显式存储 10 位尾数),要么将每个 float32 数字视为两个 bfloat16 数字的总和(大约 16 位尾数,显式存储 14 位),如果可以使用适当的快速矩阵乘法算法。否则,float32 矩阵乘法的计算方式与精度为“最高”时相同。有关 bfloat16 方法的更多信息,请参见下文。
“中”,如果内部使用该数据类型的快速矩阵乘法算法可用,则 float32 矩阵乘法内部计算使用 bfloat16 数据类型(8 位尾数,显式存储 7 位)。否则,float32 矩阵乘法的计算方式与精度为“高”时相同。
当使用“高”精度时,float32 乘法可能会使用基于 bfloat16 的算法,该算法比简单地截断为较小的尾数位数(例如,TensorFloat32 为 10,显式存储的 bfloat16 为 7)更复杂。有关此算法的完整描述,请参阅 [Henry2019]。在此简要解释一下,第一步是意识到我们可以将单个 float32 数字完美编码为三个 bfloat16 数字的总和(因为 float32 有 23 位尾数,而 bfloat16 显式存储 7 位,并且两者都具有相同数量的指数位)。这意味着两个 float32 数字的乘积可以精确地由九个 bfloat16 数字的乘积之和给出。然后,我们可以通过丢弃一些乘积来权衡精度以提高速度。“高”精度算法专门仅保留三个最重要的乘积,这方便地排除了涉及任一输入的最后 8 位尾数的所有乘积。这意味着我们可以将输入表示为两个 bfloat16 数字的总和,而不是三个。由于 bfloat16 融合乘加 (FMA) 指令通常比 float32 指令快 10 倍以上,因此使用 bfloat16 精度进行三次乘法和两次加法比使用 float32 精度进行单次乘法更快。
注意
这不会更改 float32 矩阵乘法的输出 dtype,它控制矩阵乘法的内部计算方式。
注意
这不会更改卷积运算的精度。其他标志,如 torch.backends.cudnn.allow_tf32,可能会控制卷积运算的精度。
注意
此标志目前仅影响一种原生设备类型:CUDA。“高”或“中”设置时,TensorFloat32 数据类型将用于计算 float32 矩阵乘法,等效于设置 torch.backends.cuda.matmul.allow_tf32 = True。“最高”(默认)设置时,float32 数据类型用于内部计算,等效于设置 torch.backends.cuda.matmul.allow_tf32 = False。
- 参数
precision (str) – 可以设置为 “highest”(默认)、“high” 或 “medium”(见上文)。