快捷方式

torch.linalg.lstsq

torch.linalg.lstsq(A, B, rcond=None, *, driver=None)

计算线性方程组最小二乘问题的解。

K\mathbb{K}R\mathbb{R}C\mathbb{C},线性系统 AX=BAX = B(其中 AKm×n,BKm×kA \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k})的**最小二乘问题**定义为

minXKn×kAXBF\min_{X \in \mathbb{K}^{n \times k}} \|AX - B\|_F

其中 F\|-\|_F 表示 Frobenius 范数。

支持 float, double, cfloat 和 cdouble 数据类型的输入。也支持矩阵的批量输入,如果输入是批量矩阵,则输出具有相同的批次维度。

driver 选择将使用的后端函数。对于 CPU 输入,有效值包括 ‘gels’, ‘gelsy’, ‘gelsd, ‘gelss’。选择最佳 CPU 驱动时考虑:

  • 如果 A 条件良好(其条件数不太大),或者您不介意一些精度损失。

    • 对于一般矩阵:‘gelsy’ (带主元旋转的 QR) (默认)

    • 如果 A 是满秩矩阵:‘gels’ (QR)

  • 如果 A 条件不好。

    • ‘gelsd’ (三对角化约简和 SVD)

    • 如果遇到内存问题:‘gelss’ (完全 SVD)。

对于 CUDA 输入,唯一有效的驱动是 ‘gels’,它假定 A 是满秩矩阵。

另请参阅这些驱动的完整描述

rcond 用于确定当 driver 是 (‘gelsy’, ‘gelsd’, ‘gelss’) 之一时,A 中矩阵的有效秩。在此情况下,如果 σi\sigma_i 是按降序排列的 A 的奇异值,则当 σircondσ1\sigma_i \leq \text{rcond} \cdot \sigma_1 时,σi\sigma_i 将被向下舍入为零。如果 rcond= None (默认),则 rcond 将设置为 A 的 dtype 的机器精度乘以 max(m, n)

此函数以包含四个 Tensor 的命名元组 (solution, residuals, rank, singular_values) 形式返回问题的解和一些额外信息。对于形状分别为 (*, m, n)(*, m, k) 的输入 AB,它包含:

  • solution: 最小二乘解。其形状为 (*, n, k)

  • residuals: 解的平方残差,即 AXBF2\|AX - B\|_F^2。其形状为 (*, k)。当 m > nA 中的每个矩阵都是满秩时计算此值,否则返回空 Tensor。如果 A 是批量矩阵,且批次中任何矩阵不是满秩,则返回空 Tensor。此行为在未来的 PyTorch 版本中可能会改变。

  • rank: A 中矩阵的秩 Tensor。其形状与 A 的批次维度相同。当 driver 是 (‘gelsy’, ‘gelsd’, ‘gelss’) 之一时计算此值,否则返回空 Tensor。

  • singular_values: A 中矩阵的奇异值 Tensor。其形状为 (*, min(m, n))。当 driver 是 (‘gelsd’, ‘gelss’) 之一时计算此值,否则返回空 Tensor。

注意

此函数以比单独计算更快且数值更稳定的方式计算 X = A.pinverse() @ B

警告

rcond 的默认值在未来的 PyTorch 版本中可能会改变。因此,建议使用固定值以避免潜在的破坏性更改。

参数
  • A (Tensor) – 形状为 (*, m, n) 的左侧 Tensor,其中 * 表示零个或多个批次维度。

  • B (Tensor) – 形状为 (*, m, k) 的右侧 Tensor,其中 * 表示零个或多个批次维度。

  • rcond (float, optional) – 用于确定 A 的有效秩。如果 rcond= None,则 rcond 将设置为 A 的 dtype 的机器精度乘以 max(m, n)。默认值: None

关键字参数

driver (str, optional) – 将使用的 LAPACK/MAGMA 方法的名称。如果 None,CPU 输入使用 ‘gelsy’,CUDA 输入使用 ‘gels’。默认值: None

返回值

一个命名元组 (solution, residuals, rank, singular_values)

示例

>>> A = torch.randn(1,3,3)
>>> A
tensor([[[-1.0838,  0.0225,  0.2275],
     [ 0.2438,  0.3844,  0.5499],
     [ 0.1175, -0.9102,  2.0870]]])
>>> B = torch.randn(2,3,3)
>>> B
tensor([[[-0.6772,  0.7758,  0.5109],
     [-1.4382,  1.3769,  1.1818],
     [-0.3450,  0.0806,  0.3967]],
    [[-1.3994, -0.1521, -0.1473],
     [ 1.9194,  1.0458,  0.6705],
     [-1.1802, -0.9796,  1.4086]]])
>>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3)
>>> torch.dist(X, torch.linalg.pinv(A) @ B)
tensor(1.5152e-06)

>>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values
>>> torch.dist(S, torch.linalg.svdvals(A))
tensor(2.3842e-07)

>>> A[:, 0].zero_()  # Decrease the rank of A
>>> rank = torch.linalg.lstsq(A, B).rank
>>> rank
tensor([2])

文档

访问 PyTorch 全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源