快捷方式

torch.set_float32_matmul_precision

torch.set_float32_matmul_precision(precision)[source][source]

设置 float32 矩阵乘法的内部精度。

以较低精度运行 float32 矩阵乘法可以显著提高性能,并且在某些程序中,精度损失的影响可以忽略不计。

支持三种设置

  • “最高”,float32 矩阵乘法内部计算使用 float32 数据类型(24 位尾数,显式存储 23 位)。

  • “高”,float32 矩阵乘法要么使用 TensorFloat32 数据类型(显式存储 10 位尾数),要么将每个 float32 数字视为两个 bfloat16 数字的总和(大约 16 位尾数,显式存储 14 位),如果可以使用适当的快速矩阵乘法算法。否则,float32 矩阵乘法的计算方式与精度为“最高”时相同。有关 bfloat16 方法的更多信息,请参见下文。

  • “中”,如果内部使用该数据类型的快速矩阵乘法算法可用,则 float32 矩阵乘法内部计算使用 bfloat16 数据类型(8 位尾数,显式存储 7 位)。否则,float32 矩阵乘法的计算方式与精度为“高”时相同。

当使用“高”精度时,float32 乘法可能会使用基于 bfloat16 的算法,该算法比简单地截断为较小的尾数位数(例如,TensorFloat32 为 10,显式存储的 bfloat16 为 7)更复杂。有关此算法的完整描述,请参阅 [Henry2019]。在此简要解释一下,第一步是意识到我们可以将单个 float32 数字完美编码为三个 bfloat16 数字的总和(因为 float32 有 23 位尾数,而 bfloat16 显式存储 7 位,并且两者都具有相同数量的指数位)。这意味着两个 float32 数字的乘积可以精确地由九个 bfloat16 数字的乘积之和给出。然后,我们可以通过丢弃一些乘积来权衡精度以提高速度。“高”精度算法专门仅保留三个最重要的乘积,这方便地排除了涉及任一输入的最后 8 位尾数的所有乘积。这意味着我们可以将输入表示为两个 bfloat16 数字的总和,而不是三个。由于 bfloat16 融合乘加 (FMA) 指令通常比 float32 指令快 10 倍以上,因此使用 bfloat16 精度进行三次乘法和两次加法比使用 float32 精度进行单次乘法更快。

Henry2019

http://arxiv.org/abs/1904.06376

注意

这不会更改 float32 矩阵乘法的输出 dtype,它控制矩阵乘法的内部计算方式。

注意

这不会更改卷积运算的精度。其他标志,如 torch.backends.cudnn.allow_tf32,可能会控制卷积运算的精度。

注意

此标志目前仅影响一种原生设备类型:CUDA。“高”或“中”设置时,TensorFloat32 数据类型将用于计算 float32 矩阵乘法,等效于设置 torch.backends.cuda.matmul.allow_tf32 = True。“最高”(默认)设置时,float32 数据类型用于内部计算,等效于设置 torch.backends.cuda.matmul.allow_tf32 = False

参数

precision (str) – 可以设置为 “highest”(默认)、“high” 或 “medium”(见上文)。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源