快捷方式

torch.lu

torch.lu(*args, **kwargs)[源代码]

计算矩阵或批次矩阵 A 的 LU 分解。返回一个包含 A 的 LU 分解和主元的元组。如果 pivot 设置为 True,则执行主元选择。

警告

torch.lu() 已被 torch.linalg.lu_factor()torch.linalg.lu_factor_ex() 弃用。torch.lu() 将在未来的 PyTorch 版本中移除。LU, pivots, info = torch.lu(A, compute_pivots) 应替换为

LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True) 应替换为

LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

注意

  • 批次中每个矩阵返回的置换矩阵由一个 1-indexed (起始索引为 1) 的向量表示,其大小为 min(A.shape[-2], A.shape[-1])pivots[i] == j 表示在算法的第 i 步,第 i 行与第 j-1 行进行了置换。

  • 对于 CPU,不支持 pivot = False 的 LU 分解,尝试这样做将抛出错误。但是,对于 CUDA,支持 pivot = False 的 LU 分解。

  • 如果 get_infosTrue,此函数不会检查分解是否成功,因为分解的状态存在于返回元组的第三个元素中。

  • 对于 CUDA 设备上大小小于或等于 32 的批次方阵,由于 MAGMA 库中的错误 (参见 magma issue 13),会为奇异矩阵重复执行 LU 分解。

  • LUP 可以使用 torch.lu_unpack() 派生得到。

警告

此函数的梯度仅在 A 是满秩时为有限值。这是因为 LU 分解仅在满秩矩阵处可微。此外,如果 A 接近非满秩,则梯度将由于依赖于 L1L^{-1}U1U^{-1} 的计算而导致数值不稳定。

参数
  • A (Tensor) – 要分解的 tensor,大小为 (,m,n)(*, m, n)

  • pivot (bool, 可选) – 控制是否进行主元选择。默认值:True

  • get_infos (bool, 可选) – 如果设置为 True,则返回一个 info IntTensor。默认值:False

  • out (tuple, 可选) – 可选的输出元组。如果 get_infosTrue,则元组中的元素为 Tensor、IntTensor 和 IntTensor。如果 get_infosFalse,则元组中的元素为 Tensor、IntTensor。默认值:None

返回

一个 tensor 元组,包含

  • factorization (Tensor):分解结果,大小为 (,m,n)(*, m, n)

  • pivots (IntTensor):主元,大小为 (,min(m,n))(*, \text{min}(m, n))pivots 存储了所有中间的行转置。可以通过对 i = 0, ..., pivots.size(-1) - 1 应用 swap(perm[i], perm[pivots[i] - 1]) 来重构最终的置换 perm,其中 perm 最初是 mm 个元素的单位置换(这与 torch.lu_unpack() 所做的事情基本相同)。

  • infos (IntTensor, 可选):如果 get_infosTrue,这是一个大小为 ()(*) 的 tensor,其中非零值表示矩阵或每个 mini-batch 的分解是成功还是失败。

返回类型

(Tensor, IntTensor, IntTensor (可选))

示例

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源