快捷方式

参数化教程

创建于:2021 年 4 月 19 日 | 最后更新:2024 年 2 月 05 日 | 最后验证:2024 年 11 月 05 日

作者Mario Lezcano

正则化深度学习模型是一项出乎意料的具有挑战性的任务。当应用于深度模型时,由于被优化函数的复杂性,经典的惩罚方法通常会失效。当处理病态模型时,这个问题尤其突出。这些模型的示例包括在长序列和 GAN 上训练的 RNN。近年来,已经提出了许多技术来正则化这些模型并提高其收敛性。在循环模型上,有人提出控制 RNN 循环核的奇异值以使其良好条件。例如,可以通过使循环核正交来实现这一点。正则化循环模型的另一种方法是通过“权重归一化”。这种方法建议将参数的学习与其范数的学习解耦。为此,参数除以其Frobenius 范数,并学习一个单独的参数来编码其范数。对于 GAN,在“谱归一化”的名称下提出了类似的正则化。该方法通过将其参数除以其谱范数而不是 Frobenius 范数来控制网络的 Lipschitz 常数。

所有这些方法都有一个共同的模式:它们都在使用参数之前以适当的方式转换参数。在第一种情况下,它们通过使用将矩阵映射到正交矩阵的函数使其正交。在权重和谱归一化的情况下,它们将原始参数除以其范数。

更一般地,所有这些示例都使用函数来在参数上添加额外的结构。换句话说,它们使用函数来约束参数。

在本教程中,您将学习如何实现和使用此模式来约束您的模型。这样做就像编写您自己的 nn.Module 一样简单。

要求:torch>=1.9.0

手动实现参数化

假设我们想要一个具有对称权重的方形线性层,即权重 X 使得 X = Xᵀ。一种方法是将矩阵的上三角部分复制到其下三角部分

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

def symmetric(X):
    return X.triu() + X.triu(1).transpose(-1, -2)

X = torch.rand(3, 3)
A = symmetric(X)
assert torch.allclose(A, A.T)  # A is symmetric
print(A)                       # Quick visual check
tensor([[0.8823, 0.9150, 0.3829],
        [0.9150, 0.3904, 0.6009],
        [0.3829, 0.6009, 0.9408]])

然后我们可以使用这个想法来实现具有对称权重的线性层

class LinearSymmetric(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(n_features, n_features))

    def forward(self, x):
        A = symmetric(self.weight)
        return x @ A

然后可以将该层用作常规线性层

这种实现虽然正确且自包含,但存在许多问题

  1. 它重新实现了该层。我们必须将线性层实现为 x @ A。对于线性层来说,这不是很成问题,但想象一下必须重新实现 CNN 或 Transformer……

  2. 它没有分离层和参数化。如果参数化更困难,我们将不得不为我们想要在其中使用它的每个层重写其代码。

  3. 它在我们每次使用该层时都会重新计算参数化。如果我们正向传递期间多次使用该层(想象一下 RNN 的循环核),它会在每次调用该层时计算相同的 A

参数化简介

参数化可以解决所有这些问题以及其他问题。

让我们首先使用 torch.nn.utils.parametrize 重新实现上面的代码。我们唯一需要做的是将参数化编写为常规 nn.Module

class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

这就是我们需要做的全部。一旦我们有了这个,我们可以通过执行以下操作将任何常规层转换为对称层

ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)

现在,线性层的矩阵是对称的

A = layer.weight
assert torch.allclose(A, A.T)  # A is symmetric
print(A)                       # Quick visual check
tensor([[ 0.2430,  0.5155,  0.3337],
        [ 0.5155,  0.3333,  0.1033],
        [ 0.3337,  0.1033, -0.5715]], grad_fn=<AddBackward0>)

我们可以对任何其他层做同样的事情。例如,我们可以创建一个具有斜对称内核的 CNN。我们使用类似的参数化,将带有符号反转的上三角部分复制到下三角部分

class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)


cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(cnn, "weight", Skew())
# Print a few kernels
print(cnn.weight[0, 1])
print(cnn.weight[2, 2])
tensor([[ 0.0000,  0.0457, -0.0311],
        [-0.0457,  0.0000, -0.0889],
        [ 0.0311,  0.0889,  0.0000]], grad_fn=<SelectBackward0>)
tensor([[ 0.0000, -0.1314,  0.0626],
        [ 0.1314,  0.0000,  0.1280],
        [-0.0626, -0.1280,  0.0000]], grad_fn=<SelectBackward0>)

检查参数化模块

当模块被参数化时,我们发现模块发生了三个方面的变化

  1. model.weight 现在是一个属性

  2. 它有一个新的 module.parametrizations 属性

  3. 未参数化的权重已移动到 module.parametrizations.weight.original


在参数化 weight 之后,layer.weight 变成了一个 Python 属性。此属性每次我们请求 layer.weight 时都会计算 parametrization(weight),就像我们在上面实现的 LinearSymmetric 中所做的那样。

注册的参数化存储在模块内的 parametrizations 属性下。

layer = nn.Linear(3, 3)
print(f"Unparametrized:\n{layer}")
parametrize.register_parametrization(layer, "weight", Symmetric())
print(f"\nParametrized:\n{layer}")
Unparametrized:
Linear(in_features=3, out_features=3, bias=True)

Parametrized:
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)

parametrizations 属性是一个 nn.ModuleDict,可以像这样访问它

ModuleDict(
  (weight): ParametrizationList(
    (0): Symmetric()
  )
)
ParametrizationList(
  (0): Symmetric()
)

nn.ModuleDict 的每个元素都是一个 ParametrizationList,其行为类似于 nn.Sequential。此列表将允许我们在一个权重上连接参数化。由于这是一个列表,我们可以通过索引访问参数化。这是我们的 Symmetric 参数化所在的位置

Symmetric()

我们注意到的另一件事是,如果我们打印参数,我们会看到参数 weight 已被移动

print(dict(layer.named_parameters()))
{'bias': Parameter containing:
tensor([-0.0730, -0.2283,  0.3217], requires_grad=True), 'parametrizations.weight.original': Parameter containing:
tensor([[-0.4328,  0.3425,  0.4643],
        [ 0.0937, -0.1005, -0.5348],
        [-0.2103,  0.1470,  0.2722]], requires_grad=True)}

它现在位于 layer.parametrizations.weight.original

Parameter containing:
tensor([[-0.4328,  0.3425,  0.4643],
        [ 0.0937, -0.1005, -0.5348],
        [-0.2103,  0.1470,  0.2722]], requires_grad=True)

除了这三个小的差异之外,参数化所做的事情与我们的手动实现完全相同

tensor(0., grad_fn=<DistBackward0>)

参数化是一等公民

由于 layer.parametrizations 是一个 nn.ModuleList,这意味着参数化被正确注册为原始模块的子模块。因此,模块中注册参数的相同规则适用于注册参数化。例如,如果参数化具有参数,则在调用 model = model.cuda() 时,这些参数将从 CPU 移动到 CUDA。

缓存参数化的值

参数化通过上下文管理器 parametrize.cached() 附带内置缓存系统

class NoisyParametrization(nn.Module):
    def forward(self, X):
        print("Computing the Parametrization")
        return X

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NoisyParametrization())
print("Here, layer.weight is recomputed every time we call it")
foo = layer.weight + layer.weight.T
bar = layer.weight.sum()
with parametrize.cached():
    print("Here, it is computed just the first time layer.weight is called")
    foo = layer.weight + layer.weight.T
    bar = layer.weight.sum()
Computing the Parametrization
Here, layer.weight is recomputed every time we call it
Computing the Parametrization
Computing the Parametrization
Computing the Parametrization
Here, it is computed just the first time layer.weight is called
Computing the Parametrization

连接参数化

连接两个参数化就像在同一张量上注册它们一样简单。我们可以使用它从更简单的参数化创建更复杂的参数化。例如,Cayley 映射将斜对称矩阵映射到正行列式的正交矩阵。我们可以连接 Skew 和实现 Cayley 映射的参数化,以获得具有正交权重的层

class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))

    def forward(self, X):
        # (I + X)(I - X)^{-1}
        return torch.linalg.solve(self.Id - X, self.Id + X)

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
parametrize.register_parametrization(layer, "weight", CayleyMap(3))
X = layer.weight
print(torch.dist(X.T @ X, torch.eye(3)))  # X is orthogonal
tensor(2.8527e-07, grad_fn=<DistBackward0>)

这也可以用于修剪参数化模块或重用参数化。例如,矩阵指数将对称矩阵映射到对称正定 (SPD) 矩阵,但矩阵指数也将斜对称矩阵映射到正交矩阵。利用这两个事实,我们可以重用之前的参数化来发挥我们的优势

class MatrixExponential(nn.Module):
    def forward(self, X):
        return torch.matrix_exp(X)

layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", MatrixExponential())
X = layer_orthogonal.weight
print(torch.dist(X.T @ X, torch.eye(3)))         # X is orthogonal

layer_spd = nn.Linear(3, 3)
parametrize.register_parametrization(layer_spd, "weight", Symmetric())
parametrize.register_parametrization(layer_spd, "weight", MatrixExponential())
X = layer_spd.weight
print(torch.dist(X, X.T))                        # X is symmetric
print((torch.linalg.eigvalsh(X) > 0.).all())  # X is positive definite
tensor(1.9066e-07, grad_fn=<DistBackward0>)
tensor(4.2147e-08, grad_fn=<DistBackward0>)
tensor(True)

初始化参数化

参数化附带一种初始化它们的机制。如果我们实现一个签名为 right_inverse 的方法

def right_inverse(self, X: Tensor) -> Tensor

它将在分配给参数化张量时使用。

让我们升级我们的 Skew 类的实现以支持这一点

class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)

    def right_inverse(self, A):
        # We assume that A is skew-symmetric
        # We take the upper-triangular elements, as these are those used in the forward
        return A.triu(1)

我们现在可以初始化一个用 Skew 参数化的层

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
X = torch.rand(3, 3)
X = X - X.T                             # X is now skew-symmetric
layer.weight = X                        # Initialize layer.weight to be X
print(torch.dist(layer.weight, X))      # layer.weight == X
tensor(0., grad_fn=<DistBackward0>)

当我们连接参数化时,此 right_inverse 按预期工作。为了看到这一点,让我们升级 Cayley 参数化以也支持初始化

class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))

    def forward(self, X):
        # Assume X skew-symmetric
        # (I + X)(I - X)^{-1}
        return torch.linalg.solve(self.Id - X, self.Id + X)

    def right_inverse(self, A):
        # Assume A orthogonal
        # See https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
        # (A - I)(A + I)^{-1}
        return torch.linalg.solve(A + self.Id, self.Id - A)

layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", CayleyMap(3))
# Sample an orthogonal matrix with positive determinant
X = torch.empty(3, 3)
nn.init.orthogonal_(X)
if X.det() < 0.:
    X[0].neg_()
layer_orthogonal.weight = X
print(torch.dist(layer_orthogonal.weight, X))  # layer_orthogonal.weight == X
tensor(2.2141, grad_fn=<DistBackward0>)

此初始化步骤可以更简洁地写成

layer_orthogonal.weight = nn.init.orthogonal_(layer_orthogonal.weight)

此方法的名称来自我们通常期望 forward(right_inverse(X)) == X 的事实。这是重写初始值为 X 的初始化后的正向传递应返回值 X 的直接方法。实际上,此约束并未得到强烈执行。事实上,有时,放松这种关系可能是有意义的。例如,考虑以下随机修剪方法的实现

class PruningParametrization(nn.Module):
    def __init__(self, X, p_drop=0.2):
        super().__init__()
        # sample zeros with probability p_drop
        mask = torch.full_like(X, 1.0 - p_drop)
        self.mask = torch.bernoulli(mask)

    def forward(self, X):
        return X * self.mask

    def right_inverse(self, A):
        return A

在这种情况下,对于每个矩阵 A,forward(right_inverse(A)) == A 并不成立。这仅在矩阵 A 的零位置与掩码相同的情况下才成立。即使那样,如果我们将张量分配给修剪后的参数,张量实际上会被修剪也就不足为奇了

layer = nn.Linear(3, 4)
X = torch.rand_like(layer.weight)
print(f"Initialization matrix:\n{X}")
parametrize.register_parametrization(layer, "weight", PruningParametrization(layer.weight))
layer.weight = X
print(f"\nInitialized weight:\n{layer.weight}")
Initialization matrix:
tensor([[0.3513, 0.3546, 0.7670],
        [0.2533, 0.2636, 0.8081],
        [0.0643, 0.5611, 0.9417],
        [0.5857, 0.6360, 0.2088]])

Initialized weight:
tensor([[0.3513, 0.3546, 0.7670],
        [0.2533, 0.0000, 0.8081],
        [0.0643, 0.5611, 0.9417],
        [0.5857, 0.6360, 0.0000]], grad_fn=<MulBackward0>)

移除参数化

我们可以通过使用 parametrize.remove_parametrizations() 从模块中的参数或缓冲区中删除所有参数化

layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight")
print("\nAfter. Weight has skew-symmetric values but it is unconstrained:")
print(layer)
print(layer.weight)
Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[ 0.0669, -0.3112,  0.3017],
        [-0.5464, -0.2233, -0.1125],
        [-0.4906, -0.3671, -0.0942]], requires_grad=True)

Parametrized:
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Skew()
    )
  )
)
tensor([[ 0.0000, -0.3112,  0.3017],
        [ 0.3112,  0.0000, -0.1125],
        [-0.3017,  0.1125,  0.0000]], grad_fn=<SubBackward0>)

After. Weight has skew-symmetric values but it is unconstrained:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[ 0.0000, -0.3112,  0.3017],
        [ 0.3112,  0.0000, -0.1125],
        [-0.3017,  0.1125,  0.0000]], requires_grad=True)

移除参数化时,我们可以选择保留原始参数(即 layer.parametrizations.weight.original 中的参数),而不是通过设置标志 leave_parametrized=False 来保留其参数化版本

layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
print("\nAfter. Same as Before:")
print(layer)
print(layer.weight)
Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[-0.3447, -0.3777,  0.5038],
        [ 0.2042,  0.0153,  0.0781],
        [-0.4640, -0.1928,  0.5558]], requires_grad=True)

Parametrized:
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Skew()
    )
  )
)
tensor([[ 0.0000, -0.3777,  0.5038],
        [ 0.3777,  0.0000,  0.0781],
        [-0.5038, -0.0781,  0.0000]], grad_fn=<SubBackward0>)

After. Same as Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[ 0.0000, -0.3777,  0.5038],
        [ 0.0000,  0.0000,  0.0781],
        [ 0.0000,  0.0000,  0.0000]], requires_grad=True)

脚本的总运行时间: ( 0 分钟 0.055 秒)

图库由 Sphinx-Gallery 生成

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源