torch.lu¶
- torch.lu(*args, **kwargs)[源代码]¶
计算矩阵或批次矩阵
A
的 LU 分解。返回一个包含A
的 LU 分解和主元的元组。如果pivot
设置为True
,则执行主元选择。警告
torch.lu()
已被torch.linalg.lu_factor()
和torch.linalg.lu_factor_ex()
弃用。torch.lu()
将在未来的 PyTorch 版本中移除。LU, pivots, info = torch.lu(A, compute_pivots)
应替换为LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)
应替换为LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
注意
批次中每个矩阵返回的置换矩阵由一个 1-indexed (起始索引为 1) 的向量表示,其大小为
min(A.shape[-2], A.shape[-1])
。pivots[i] == j
表示在算法的第i
步,第i
行与第j-1
行进行了置换。对于 CPU,不支持
pivot
=False
的 LU 分解,尝试这样做将抛出错误。但是,对于 CUDA,支持pivot
=False
的 LU 分解。如果
get_infos
为True
,此函数不会检查分解是否成功,因为分解的状态存在于返回元组的第三个元素中。对于 CUDA 设备上大小小于或等于 32 的批次方阵,由于 MAGMA 库中的错误 (参见 magma issue 13),会为奇异矩阵重复执行 LU 分解。
L
、U
和P
可以使用torch.lu_unpack()
派生得到。
警告
此函数的梯度仅在
A
是满秩时为有限值。这是因为 LU 分解仅在满秩矩阵处可微。此外,如果A
接近非满秩,则梯度将由于依赖于 和 的计算而导致数值不稳定。- 参数
- 返回
一个 tensor 元组,包含
factorization (Tensor):分解结果,大小为
pivots (IntTensor):主元,大小为 。
pivots
存储了所有中间的行转置。可以通过对i = 0, ..., pivots.size(-1) - 1
应用swap(perm[i], perm[pivots[i] - 1])
来重构最终的置换perm
,其中perm
最初是 个元素的单位置换(这与torch.lu_unpack()
所做的事情基本相同)。infos (IntTensor, 可选):如果
get_infos
为True
,这是一个大小为 的 tensor,其中非零值表示矩阵或每个 mini-batch 的分解是成功还是失败。
- 返回类型
(Tensor, IntTensor, IntTensor (可选))
示例
>>> A = torch.randn(2, 3, 3) >>> A_LU, pivots = torch.lu(A) >>> A_LU tensor([[[ 1.3506, 2.5558, -0.0816], [ 0.1684, 1.1551, 0.1940], [ 0.1193, 0.6189, -0.5497]], [[ 0.4526, 1.2526, -0.3285], [-0.7988, 0.7175, -0.9701], [ 0.2634, -0.9255, -0.3459]]]) >>> pivots tensor([[ 3, 3, 3], [ 3, 3, 3]], dtype=torch.int32) >>> A_LU, pivots, info = torch.lu(A, get_infos=True) >>> if info.nonzero().size(0) == 0: ... print('LU factorization succeeded for all samples!') LU factorization succeeded for all samples!