torch.set_float32_matmul_precision¶
- torch.set_float32_matmul_precision(precision)[源代码]¶
设置 float32 矩阵乘法的内部精度。
以较低精度运行 float32 矩阵乘法可能会显著提高性能,并且在某些程序中,精度损失的影响可以忽略不计。
支持三种设置
“highest”,float32 矩阵乘法在内部计算中使用 float32 数据类型(24 位尾数,其中 23 位显式存储)。
“high”,float32 矩阵乘法使用 TensorFloat32 数据类型(显式存储 10 位尾数)或将每个 float32 数字视为两个 bfloat16 数字的总和(大约 16 位尾数,其中 14 位显式存储),如果相应的快速矩阵乘法算法可用。否则,float32 矩阵乘法将按“highest”精度计算。有关 bfloat16 方法的更多信息,请参见下文。
“medium”,float32 矩阵乘法在内部计算中使用 bfloat16 数据类型(8 位尾数,其中 7 位显式存储),如果使用该数据类型的快速矩阵乘法算法可用。否则,float32 矩阵乘法将按“high”精度计算。
使用“high”精度时,float32 乘法可能会使用比简单地截断到某些较小的数字尾数位(例如 TensorFloat32 为 10 位,显式存储的 bfloat16 为 7 位)更复杂的基于 bfloat16 的算法。有关此算法的完整描述,请参阅 [Henry2019]。简单解释一下,第一步是认识到我们可以将单个 float32 数字完美地编码为三个 bfloat16 数字的总和(因为 float32 有 23 位尾数,而 bfloat16 有 7 位显式存储,并且两者具有相同的指数位数)。这意味着两个 float32 数字的乘积可以由九个 bfloat16 数字乘积的总和精确给出。然后,我们可以通过删除一些乘积来以准确性换取速度。“high”精度算法专门保留三个最重要的乘积,这恰好排除了所有涉及任一输入最后 8 位尾数的乘积。这意味着我们可以将输入表示为两个 bfloat16 数字的总和,而不是三个。由于 bfloat16 融合乘加 (FMA) 指令通常比 float32 指令快 10 倍以上,因此使用 bfloat16 精度执行三个乘法和 2 次加法比使用 float32 精度执行一次乘法更快。
注意
这不会更改 float32 矩阵乘法的输出数据类型,它控制如何执行矩阵乘法的内部计算。
注意
这不会更改卷积运算的精度。其他标志(如 torch.backends.cudnn.allow_tf32)可能会控制卷积运算的精度。
注意
此标志目前仅影响一种原生设备类型:CUDA。如果设置了“high”或“medium”,则在计算 float32 矩阵乘法时将使用 TensorFloat32 数据类型,这等效于设置 torch.backends.cuda.matmul.allow_tf32 = True。当设置“highest”(默认值)时,内部计算将使用 float32 数据类型,这等效于设置 torch.backends.cuda.matmul.allow_tf32 = False。
- 参数
precision (str) – 可以设置为“highest”(默认值)、“high”或“medium”(请参见上文)。