快捷方式

ParametrizationList

class torch.nn.utils.parametrize.ParametrizationList(modules, original, unsafe=False)[source][source]

一个序列容器,用于持有和管理一个参数化 torch.nn.Module 的原始参数或缓冲区。

module[tensor_name] 使用 register_parametrization() 进行参数化后,它就是 module.parametrizations[tensor_name] 的类型。

如果第一个注册的参数化具有返回一个张量的 right_inverse,或者没有 right_inverse(在这种情况下,我们假定 right_inverse 是恒等函数),它将以 original 的名称持有该张量。如果它具有返回多个张量的 right_inverse,这些张量将注册为 original0, original1, …

警告

此类由 register_parametrization() 内部使用。在此处提供文档是为了完整性。用户不应实例化此类。

参数
  • modules (sequence) – 表示参数化的模块序列

  • original (Parameter or Tensor) – 被参数化的参数或缓冲区

  • unsafe (bool) – 一个布尔标志,表示参数化是否可能改变张量的数据类型 (dtype) 和形状 (shape)。默认值: False 警告: 注册时不会检查参数化的一致性。启用此标志风险自担。

right_inverse(value)[source][source]

按相反的注册顺序调用参数化的 right_inverse 方法。

然后,如果 right_inverse 输出一个张量,它将结果存储在 self.original 中;如果输出多个张量,则存储在 self.original0, self.original1, … 中。

参数

value (Tensor) – 用于初始化模块的值

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源