快捷方式

torch.nn.utils.parametrizations.orthogonal

torch.nn.utils.parametrizations.orthogonal(module, name='weight', orthogonal_map=None, *, use_trivialization=True)[源代码]

将正交或酉参数化应用于矩阵或一批矩阵。

K\mathbb{K}R\mathbb{R}C\mathbb{C},参数化矩阵 QKm×nQ \in \mathbb{K}^{m \times n} 为 **正交**,因为

QHQ=Inif mnQQH=Imif m<n\begin{align*} Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} \end{align*}

其中 QHQ^{\text{H}} 是共轭转置,当 QQ 为复数时;是转置,当 QQ 为实数时。 In\mathrm{I}_nn 维单位矩阵。简单来说, QQ 的列向量在 mnm \geq n 时是正交的,否则是行向量是正交的。

如果张量维度超过二维,我们将其视为形状为 (…, m, n) 的矩阵批次。

矩阵 QQ 可以通过三种不同的 orthogonal_map 来参数化,这些参数化方式基于原始张量

  • "matrix_exp"/"cayley": matrix_exp() Q=exp(A)Q = \exp(A)Cayley 映射 Q=(In+A/2)(InA/2)1Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1} 应用于斜对称矩阵 AA 以生成一个正交矩阵。

  • "householder": 计算 Householder 反射器的乘积 (householder_product()).

"matrix_exp"/"cayley" 通常使参数化权重比 "householder" 收敛更快,但对于非常瘦或非常宽的矩阵,它们的计算速度较慢。

如果 use_trivialization=True(默认),参数化将实现“动态平凡化框架”,其中一个额外的矩阵 BKn×nB \in \mathbb{K}^{n \times n} 存储在 module.parametrizations.weight[0].base 下。这有助于参数化层的收敛,但会消耗一些额外的内存。请参见 Trivializations for Gradient-Based Optimization on Manifolds

QQ 的初始值:如果原始张量没有参数化并且 use_trivialization=True(默认),则 QQ 的初始值是原始张量的值,如果它是正交的(或在复数情况下是酉的),否则通过 QR 分解进行正交化(请参见 torch.linalg.qr())。当它没有参数化并且 orthogonal_map="householder" 即使 use_trivialization=False 时也是如此。否则,初始值是应用于原始张量的所有已注册参数化的组合的结果。

注意

此函数是使用 register_parametrization() 中的参数化功能实现的。

参数
  • module (nn.Module) – 要注册参数化的模块。

  • name (str, optional) – 要使其正交的张量的名称。默认值:"weight"

  • orthogonal_map (str, optional) – 以下之一:"matrix_exp", "cayley", "householder"。默认值:如果矩阵是方形或复数,则为 "matrix_exp",否则为 "householder"

  • use_trivialization (bool, optional) – 是否使用动态平凡化框架。默认值:True

返回值

原始模块,其中向指定权重注册了正交参数化

返回类型

模块

示例

>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
in_features=20, out_features=40, bias=True
(parametrizations): ModuleDict(
    (weight): ParametrizationList(
    (0): _Orthogonal()
    )
)
)
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源