快捷方式

参数化列表

class torch.nn.utils.parametrize.ParametrizationList(modules, original, unsafe=False)[source]

一个顺序容器,它持有并管理已参数化的 torch.nn.Module 的原始参数或缓冲区。

它是 module.parametrizations[tensor_name] 的类型,其中 module[tensor_name] 已使用 register_parametrization() 参数化。

如果第一个注册的参数化具有一个返回一个张量的 right_inverse 或者没有 right_inverse(在这种情况下,我们假设 right_inverse 是恒等式),它将在名称为 original 下持有该张量。如果它具有一个返回多个张量的 right_inverse,这些张量将分别注册为 original0original1 等。

警告

此类在 register_parametrization() 内部使用。为了完整性,此处对其进行了文档记录。用户不应实例化此类。

参数
  • modules (sequence) – 表示参数化的模块序列

  • original (ParameterTensor) – 已参数化的参数或缓冲区

  • unsafe (bool) – 一个布尔标志,表示参数化是否可能会更改张量的 dtype 和形状。默认值:False 警告:注册时不会检查参数化的一致性。请自行承担使用此标志的风险。

right_inverse(value)[source]

以逆注册顺序调用参数化的 right_inverse 方法。

然后,如果 right_inverse 输出一个张量,它将结果存储在 self.original 中;如果它输出多个张量,则将结果存储在 self.original0self.original1 等中。

参数

value (Tensor) – 用于初始化模块的值

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源