torch.nn.utils.parametrizations.spectral_norm¶
- torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[source][source]¶
将谱归一化应用于给定模块中的参数。
当应用于向量时,它简化为
谱归一化通过降低模型的 Lipschitz 常数,稳定了生成对抗网络 (GAN) 中判别器(评论家)的训练。每次访问权重时,通过执行一次幂迭代法来近似计算 。如果权重张量的维度大于 2,则在幂迭代方法中将其重塑为 2D 以获得谱范数。
参见 用于生成对抗网络的谱归一化 。
注意
此函数是使用
register_parametrization()
中的 parametrization 功能实现的。它是torch.nn.utils.spectral_norm()
的重新实现。注意
注册此约束后,将估计与最大奇异值相关联的奇异向量,而不是随机采样。然后,当在 training 模式下访问模块中的张量时,通过执行
n_power_iterations
次幂迭代法来更新这些奇异向量。注意
如果 _SpectralNorm 模块(即 module.parametrization.weight[idx])在移除时处于训练模式,它将执行另一次幂迭代。如果您想避免这次迭代,请在移除前将模块设置为评估模式。
- 参数
- 返回值
注册了新 parametrization 的原始模块
- 返回类型
示例
>>> snm = spectral_norm(nn.Linear(20, 40)) >>> snm ParametrizedLinear( in_features=20, out_features=40, bias=True (parametrizations): ModuleDict( (weight): ParametrizationList( (0): _SpectralNorm() ) ) ) >>> torch.linalg.matrix_norm(snm.weight, 2) tensor(1.0081, grad_fn=<AmaxBackward0>)