快捷方式

torch.triangular_solve

torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)

求解具有方形上三角或下三角可逆矩阵 AA 和多个右侧项 bb 的方程组。

符号表示为求解 AX=bAX = b,并假设 AA 是方形上三角矩阵(如果 upper= False 则为下三角矩阵),且对角线上没有零。

torch.triangular_solve(b, A) 可以接受 2D 输入 b, A 或 2D 矩阵的批次输入。如果输入是批次,则返回批次的输出 X

如果 A 的对角线包含零或非常接近零的元素,并且 unitriangular= False(默认值),或者输入矩阵是病态的,则结果可能包含 NaN

支持 float, double, cfloat 和 cdouble 数据类型的输入。

警告

torch.triangular_solve() 已被弃用,推荐使用 torch.linalg.solve_triangular(),并将在未来的 PyTorch 版本中移除。torch.linalg.solve_triangular() 的参数顺序相反,并且不返回其中一个输入的副本。

X = torch.triangular_solve(B, A).solution 应该替换为

X = torch.linalg.solve_triangular(A, B)
参数
  • b (Tensor) – 多个右侧项,形状为 (,m,k)(*, m, k),其中 * 是零个或多个批处理维度

  • A (Tensor) – 输入的三角系数矩阵,形状为 (,m,m)(*, m, m),其中 * 是零个或多个批处理维度

  • upper (bool, optional) – 表示 AA 是上三角矩阵还是下三角矩阵。默认值:True

  • transpose (bool, optional) – 求解 op(A)X = b,其中如果此标志为 True,则 op(A) = A^T;如果为 False,则 op(A) = A。默认值:False

  • unitriangular (bool, optional) – 表示 AA 是否为单位三角矩阵。如果为 True,则假设 AA 的对角线元素为 1,并且不参考 AA 中的值。默认值:False

关键字参数

out ((Tensor, Tensor), optional) – 用于写入输出的两个 Tensor 的元组。如果为 None 则忽略。默认值:None

返回值

一个 namedtuple (solution, cloned_coefficient),其中 cloned_coefficientAA 的副本,而 solution 是方程组 AX=bAX = b 的解 XX(或根据关键字参数决定的方程组的变体)。

示例

>>> A = torch.randn(2, 2).triu()
>>> A
tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]])
>>> b = torch.randn(2, 3)
>>> b
tensor([[-0.0210,  2.3513, -1.5492],
        [ 1.5429,  0.7403, -1.0243]])
>>> torch.triangular_solve(b, A)
torch.return_types.triangular_solve(
solution=tensor([[ 1.7841,  2.9046, -2.5405],
        [ 1.9320,  0.9270, -1.2826]]),
cloned_coefficient=tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]]))

文档

查阅 PyTorch 开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源