数值精度¶
在现代计算机中,浮点数使用 IEEE 754 标准表示。有关浮点算术和 IEEE 754 标准的更多详细信息,请参阅 浮点算术。特别是,请注意浮点数提供有限的精度(单精度浮点数约 7 位小数,双精度浮点数约 16 位小数),并且浮点加法和乘法不具有结合性,因此操作顺序会影响结果。正因为如此,PyTorch 不保证对于数学上相同的浮点计算会产生逐位相同的结果。同样,不保证在不同 PyTorch 版本、单个提交或不同平台之间产生逐位相同的结果。特别是,即使输入完全相同,并且即使控制了随机性来源,CPU 和 GPU 结果也可能不同。
批量计算或切片计算¶
PyTorch 中的许多操作支持批量计算,即对输入批次中的元素执行相同的操作。torch.mm()
和 torch.bmm()
就是这样的例子。尽管从数学上讲,批量计算可以通过循环遍历批次元素并将必要的数学运算应用于单个批次元素来实现,但出于效率原因,我们并未这样做,通常会对整个批次执行计算。在这种情况下,我们调用的数学库以及 PyTorch 内部操作实现与非批量计算相比,可能会产生略微不同的结果。特别是,设 A
和 B
是维度适合批量矩阵乘法的 3D 张量。那么 (A@B)[0]
(批量结果的第一个元素)不保证与 A[0]@B[0]
(输入批次的第一个元素的矩阵乘积)逐位相同,尽管数学上它们是相同的计算。
同样,应用于张量切片的操作不保证与应用于完整张量结果的切片产生相同的结果。例如,设 A
是一个 2 维张量。A.sum(-1)[0]
不保证与 A[:,0].sum()
逐位相同。
极值¶
当输入包含较大的值,导致中间结果可能超出所用数据类型的范围时,即使最终结果在原始数据类型中是可表示的,最终结果也可能溢出。例如:
import torch
a=torch.tensor([1e20, 1e20]) # fp32 type by default
a.norm() # produces tensor(inf)
a.double().norm() # produces tensor(1.4142e+20, dtype=torch.float64), representable in fp32
线性代数 (torch.linalg
)¶
非有限值¶
torch.linalg
使用的外部库(后端)不保证输入包含非有限值(如 inf
或 NaN
)时的行为。因此,PyTorch 也不提供此保证。操作可能返回包含非有限值的张量,或引发异常,甚至导致段错误(segfault)。
在调用这些函数之前,考虑使用 torch.isfinite()
来检测这种情况。
linalg 中的极值¶
torch.linalg
中的函数相比其他 PyTorch 函数更容易遇到 极值 问题。
求解器 和 求逆器 假定输入矩阵 A
是可逆的。如果它接近不可逆(例如,如果它具有非常小的奇异值),则这些算法可能会静默返回不正确的结果。这些矩阵被称为病态矩阵。如果提供病态输入,这些函数的结果在使用相同输入但在不同设备上或使用不同的 driver
关键字后端时可能会有所不同。
svd
、eig
和 eigh
等谱运算在其输入具有接近的奇异值时,也可能返回不正确的结果(并且它们的梯度可能为无穷大)。这是因为用于计算这些分解的算法对于这些输入难以收敛。
在 float64
中运行计算(NumPy 默认这样做)通常会有所帮助,但这并不能在所有情况下解决这些问题。通过 torch.linalg.svdvals()
分析输入的奇异值或通过 torch.linalg.cond()
分析其条件数可能有助于检测这些问题。
Nvidia Ampere(及后续)设备上的 TensorFloat-32(TF32)¶
在 Ampere(及后续)Nvidia GPU 上,PyTorch 可以使用 TensorFloat32 (TF32) 来加速计算密集型操作,特别是矩阵乘法和卷积。当使用 TF32 Tensor Core 执行操作时,只读取输入尾数的头 10 位。这可能会降低精度并产生令人惊讶的结果(例如,矩阵乘以单位矩阵可能产生与输入不同的结果)。默认情况下,矩阵乘法的 TF32 Tensor Core 是禁用的,而卷积的 TF32 Tensor Core 是启用的,尽管大多数神经网络工作负载在使用 TF32 时与使用 fp32 具有相同的收敛行为。如果您的网络不需要完整的 float32 精度,我们建议使用 torch.backends.cuda.matmul.allow_tf32 = True
启用矩阵乘法的 TF32 Tensor Core。如果您的网络对于矩阵乘法和卷积都需要完整的 float32 精度,则可以使用 torch.backends.cudnn.allow_tf32 = False
禁用卷积的 TF32 Tensor Core。
更多信息请参阅 TensorFloat32。
FP16 和 BF16 GEMM 的降精度归约¶
半精度 GEMM 操作通常在单精度下进行中间累加(归约),以提高数值精度并增强抗溢出能力。为了性能,某些 GPU 架构,特别是更新的架构,允许将中间累加结果截断到较低精度(例如半精度)几次。这种改变通常对模型收敛是良性的,尽管它可能导致意外结果(例如,最终结果应在半精度范围内但出现 inf
值)。如果降精度归约有问题,可以使用 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
将其关闭。
BF16 GEMM 操作也存在类似的标志,并且默认是开启的。如果 BF16 降精度归约有问题,可以使用 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
将其关闭。
更多信息请参阅 allow_fp16_reduced_precision_reduction 和 allow_bf16_reduced_precision_reduction。
缩放点积注意力 (SDPA) 中 FP16 和 BF16 的降精度归约¶
一个朴素的 SDPA 数学后端在使用 FP16/BF16 输入时,由于使用低精度中间缓冲区,可能会累积显著的数值误差。为了缓解这个问题,默认行为现在涉及将 FP16/BF16 输入上转换为 FP32。计算在 FP32/TF32 中执行,然后将最终的 FP32 结果下转换回 FP16/BF16。这将提高使用 FP16/BF16 输入的数学后端最终输出的数值精度,但会增加内存使用量,并可能导致数学后端的性能下降,因为计算从 FP16/BF16 BMM 转移到 FP32/TF32 BMM/Matmul。
对于偏好速度而选择降精度归约的场景,可以使用以下设置启用它们:torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
AMD Instinct MI200 设备上的 FP16 和 BF16 GEMM 和卷积降精度¶
在 AMD Instinct MI200 GPU 上,FP16 和 BF16 的 V_DOT2 和 MFMA 矩阵指令会将输入和输出的次正规值刷新为零。FP32 和 FP64 的 MFMA 矩阵指令不会将输入和输出的次正规值刷新为零。受影响的指令仅由 rocBLAS (GEMM) 和 MIOpen (卷积) 内核使用;所有其他 PyTorch 操作都不会遇到这种行为。所有其他支持的 AMD GPU 都不会遇到这种行为。
rocBLAS 和 MIOpen 为受影响的 FP16 操作提供了备用实现。BF16 操作没有提供备用实现;BF16 数值的动态范围比 FP16 数值更大,遇到次正规值的可能性较低。对于 FP16 的备用实现,FP16 输入值被转换为中间 BF16 值,然后在累加 FP32 操作后转换回 FP16 输出。通过这种方式,输入和输出类型保持不变。
使用 FP16 精度进行训练时,某些模型可能会因为 FP16 次正规值刷新为零而无法收敛。次正规值更常出现在训练的反向传播过程中,即梯度计算期间。PyTorch 默认在反向传播过程中使用 rocBLAS 和 MIOpen 的备用实现。可以使用环境变量 ROCBLAS_INTERNAL_FP16_ALT_IMPL 和 MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL 覆盖默认行为。这些环境变量的行为如下:
正向传播 |
反向传播 |
|
---|---|---|
环境变量未设置 |
原始 |
备用 |
环境变量设置为 1 |
备用 |
备用 |
环境变量设置为 0 |
原始 |
原始 |
以下是可能使用 rocBLAS 的操作列表
torch.addbmm
torch.addmm
torch.baddbmm
torch.bmm
torch.mm
torch.nn.GRUCell
torch.nn.LSTMCell
torch.nn.Linear
torch.sparse.addmm
以下 torch._C._ConvBackend 实现
slowNd
slowNd_transposed
slowNd_dilated
slowNd_dilated_transposed
以下是可能使用 MIOpen 的操作列表
torch.nn.Conv[Transpose]Nd
以下 torch._C._ConvBackend 实现
ConvBackend::Miopen
ConvBackend::MiopenDepthwise
ConvBackend::MiopenTranspose