torch.linalg.lstsq¶
- torch.linalg.lstsq(A, B, rcond=None, *, driver=None)¶
计算线性方程组最小二乘问题的解。
设 为 或 ,线性系统 的 **最小二乘问题**,其中 定义为
其中 表示 Frobenius 范数。
支持 float、double、cfloat 和 cdouble 数据类型。也支持矩阵的批处理,如果输入是矩阵的批处理,则输出具有相同的批处理维度。
driver
选择将要使用的后端函数。对于 CPU 输入,有效值为 ‘gels’、‘gelsy’、‘gelsd’、‘gelss’。要选择 CPU 上的最佳驱动程序,请考虑如果
A
条件良好(其 条件数 不太大),或者您不介意一些精度损失。对于一般矩阵:‘gelsy’(带 pivoting 的 QR)(默认)
如果
A
是满秩的:‘gels’(QR)
如果
A
条件不好。‘gelsd’(三对角线约简和 SVD)
但是,如果您遇到内存问题:‘gelss’(完整 SVD)。
对于 CUDA 输入,唯一有效的驱动程序是 ‘gels’,它假定
A
是满秩的。另请参阅 这些驱动程序的完整描述
rcond
用于确定A
中矩阵的有效秩,当driver
为以下之一时:(‘gelsy’、‘gelsd’、‘gelss’)。在这种情况下,如果 是 A 的奇异值,按降序排列,则 将被舍入为零,如果 . 如果rcond
= None(默认),则rcond
设置为A
数据类型的机器精度乘以 max(m, n)。此函数返回问题的解以及一些额外的信息,这些信息包含在四个张量的命名元组 (solution, residuals, rank, singular_values) 中。对于形状分别为 (*, m, n)、(*, m, k) 的输入
A
、B
,它包含solution:最小二乘解。它的形状为 (*, n, k)。
residuals:解的平方残差,即 . 它的形状等于
A
的批处理维度。它在 m > n 且A
中的每个矩阵都是满秩时计算,否则是一个空张量。如果A
是一个矩阵批处理,并且批处理中的任何矩阵不是满秩,则返回一个空张量。此行为可能会在将来的 PyTorch 版本中更改。rank:
A
中矩阵的秩张量。它的形状等于A
的批处理维度。它在driver
为以下之一时计算:(‘gelsy’、‘gelsd’、‘gelss’),否则是一个空张量。singular_values:
A
中矩阵的奇异值张量。它的形状为 (*, min(m, n))。它在driver
为以下之一时计算:(‘gelsd’、‘gelss’),否则是一个空张量。
注意
此函数以比分别执行计算更快且更数值稳定的方式计算 X =
A
.pinverse() @B
。警告
rcond
的默认值可能会在将来的 PyTorch 版本中更改。因此建议使用固定值以避免潜在的重大更改。- 参数
- 关键字参数
driver (str, optional) – 要使用的 LAPACK/MAGMA 方法的名称。如果为 None,则对于 CPU 输入使用 ‘gelsy’,对于 CUDA 输入使用 ‘gels’。默认值:None。
- 返回
一个命名元组 (solution, residuals, rank, singular_values)。
示例
>>> A = torch.randn(1,3,3) >>> A tensor([[[-1.0838, 0.0225, 0.2275], [ 0.2438, 0.3844, 0.5499], [ 0.1175, -0.9102, 2.0870]]]) >>> B = torch.randn(2,3,3) >>> B tensor([[[-0.6772, 0.7758, 0.5109], [-1.4382, 1.3769, 1.1818], [-0.3450, 0.0806, 0.3967]], [[-1.3994, -0.1521, -0.1473], [ 1.9194, 1.0458, 0.6705], [-1.1802, -0.9796, 1.4086]]]) >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) >>> torch.dist(X, torch.linalg.pinv(A) @ B) tensor(1.5152e-06) >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values >>> torch.dist(S, torch.linalg.svdvals(A)) tensor(2.3842e-07) >>> A[:, 0].zero_() # Decrease the rank of A >>> rank = torch.linalg.lstsq(A, B).rank >>> rank tensor([2])