快捷方式

ParametrizationList

class torch.nn.utils.parametrize.ParametrizationList(modules, original, unsafe=False)[source][source]

一个顺序容器,用于保存和管理参数化的 torch.nn.Module 的原始参数或缓冲区。

module[tensor_name] 已使用 register_parametrization() 参数化时,它是 module.parametrizations[tensor_name] 的类型。

如果第一个注册的参数化具有返回一个张量的 right_inverse 或不具有 right_inverse(在这种情况下,我们假设 right_inverse 是恒等式),它将在名称 original 下保存张量。如果它具有返回多个张量的 right_inverse,这些张量将注册为 original0original1、……

警告

此类在内部由 register_parametrization() 使用。此处记录它是为了完整性。用户不得实例化它。

参数
  • modules (sequence) – 表示参数化的模块序列

  • original (ParameterTensor) – 参数化的参数或缓冲区

  • unsafe (bool) – 一个布尔标志,表示参数化是否可能更改张量的 dtype 和形状。默认值:False 警告:注册时不会检查参数化的一致性。启用此标志的风险由您自行承担。

right_inverse(value)[source][source]

以相反的注册顺序调用参数化的 right_inverse 方法。

然后,如果 right_inverse 输出一个张量,则将结果存储在 self.original 中;如果输出多个张量,则存储在 self.original0self.original1、……中。

参数

value (Tensor) – 用于初始化模块的值

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源