快捷方式

torch.set_float32_matmul_precision

torch.set_float32_matmul_precision(precision)[源代码][源代码]

设置 float32 矩阵乘法的内部精度。

以较低精度运行 float32 矩阵乘法可以显著提升性能,在某些程序中,精度的损失影响微乎其微。

支持三种设置:

  • “highest”,float32 矩阵乘法使用 float32 数据类型(尾数位 24 位,其中 23 位显式存储)进行内部计算。

  • “high”,如果可用的快速矩阵乘法算法支持,float32 矩阵乘法要么使用 TensorFloat32 数据类型(尾数位 10 位显式存储),要么将每个 float32 数字视为两个 bfloat16 数字之和(尾数位约 16 位,其中 14 位显式存储)。否则,float32 矩阵乘法将按照“highest”精度进行计算。有关 bfloat16 方法的更多信息,请参见下文。

  • “medium”,如果内部使用 bfloat16 数据类型的快速矩阵乘法算法可用,float32 矩阵乘法将使用 bfloat16 数据类型(尾数位 8 位,其中 7 位显式存储)进行内部计算。否则,float32 矩阵乘法将按照“high”精度进行计算。

使用“high”精度时,float32 乘法可能会使用基于 bfloat16 的算法,该算法比简单地截断到较少尾数位(例如 TensorFloat32 的 10 位,bfloat16 显式存储的 7 位)更复杂。有关此算法的完整描述,请参阅 [Henry2019]。在此简要解释一下,第一步是意识到我们可以将单个 float32 数字完美地编码为三个 bfloat16 数字之和(因为 float32 有 23 个尾数位,而 bfloat16 有 7 个显式存储位,并且两者具有相同的指数位数)。这意味着两个 float32 数字的乘积可以精确地表示为九个 bfloat16 数字乘积之和。然后,我们可以通过丢弃其中一些乘积来权衡精度和速度。“high”精度算法特别只保留了三个最重要的乘积,这方便地排除了涉及任一输入最后 8 个尾数位的所有乘积。这意味着我们可以将输入表示为两个 bfloat16 数字之和,而不是三个。由于 bfloat16 乘加融合 (FMA) 指令通常比 float32 指令快 10 倍以上,因此使用 bfloat16 精度进行三次乘法和两次加法比使用 float32 精度进行一次乘法更快。

Henry2019

http://arxiv.org/abs/1904.06376

注意

这不会改变 float32 矩阵乘法的输出数据类型(dtype),它控制的是矩阵乘法的内部计算方式。

注意

这不会改变卷积操作的精度。其他标志,例如 torch.backends.cudnn.allow_tf32,可能会控制卷积操作的精度。

注意

当前,此标志仅影响一种原生设备类型:CUDA。如果设置为“high”或“medium”,则在计算 float32 矩阵乘法时将使用 TensorFloat32 数据类型,这等同于设置 torch.backends.cuda.matmul.allow_tf32 = True。当设置为“highest”(默认值)时,内部计算使用 float32 数据类型,这等同于设置 torch.backends.cuda.matmul.allow_tf32 = False

参数

precision (str) – 可以设置为“highest”(默认)、“high”或“medium”(参见上文)。

文档

获取 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源