torch.linalg.householder_product¶
- torch.linalg.householder_product(A, tau, *, out=None) Tensor ¶
计算 Householder 矩阵乘积的前 n 列。
设 为 或 ,设 为一个矩阵,其列向量为 (对于 ,且 )。记 为将 的前 个分量置零、第 个分量设为 `1` 所得的向量。对于一个向量 (其中 ),本函数计算以下矩阵的前 列
其中 是 m 维单位矩阵,$b^{\text{H}}$ 在 $b$ 是复数时表示共轭转置,在 $b$ 是实数时表示转置。输出矩阵的大小与输入矩阵
A
相同。有关更多详细信息,请参见 Representation of Orthogonal or Unitary Matrices。
支持 float、double、cfloat 和 cdouble 数据类型的输入。也支持批处理矩阵输入,如果输入是批处理矩阵,则输出具有相同的批处理维度。
另请参阅
torch.geqrf()
可与本函数结合使用,从qr()
分解中形成 Q 矩阵。torch.ormqr()
是一个相关函数,用于计算 Householder 矩阵乘积与另一个矩阵的矩阵乘法。但是,该函数不支持自动求导。警告
只有当 时,梯度计算才是良好定义的。如果未满足此条件,不会抛出错误,但生成的梯度可能包含 NaN。
- 参数
- 关键字参数
out (Tensor, 可选) – 输出张量。如果为 None 则忽略。默认值:None。
- 抛出
RuntimeError – 如果
A
不满足 m >= n 的要求,或tau
不满足 n >= k 的要求。
示例
>>> A = torch.randn(2, 2) >>> h, tau = torch.geqrf(A) >>> Q = torch.linalg.householder_product(h, tau) >>> torch.dist(Q, torch.linalg.qr(A).Q) tensor(0.) >>> h = torch.randn(3, 2, 2, dtype=torch.complex128) >>> tau = torch.randn(3, 1, dtype=torch.complex128) >>> Q = torch.linalg.householder_product(h, tau) >>> Q tensor([[[ 1.8034+0.4184j, 0.2588-1.0174j], [-0.6853+0.7953j, 2.0790+0.5620j]], [[ 1.4581+1.6989j, -1.5360+0.1193j], [ 1.3877-0.6691j, 1.3512+1.3024j]], [[ 1.4766+0.5783j, 0.0361+0.6587j], [ 0.6396+0.1612j, 1.3693+0.4481j]]], dtype=torch.complex128)