快捷方式

torch.nn.utils.skip_init

torch.nn.utils.skip_init(module_cls, *args, **kwargs)[源代码][源代码]

给定一个模块类对象以及 args / kwargs,在不初始化参数 / 缓冲区的情况下实例化该模块。

如果初始化速度慢,或者如果将执行自定义初始化,从而使得默认初始化不必要,这可能很有用。由于此函数的实现方式,存在一些注意事项:

1. 模块的构造函数必须接受一个 device 参数,该参数会传递给在构建过程中创建的任何参数或缓冲区。

2. 除了初始化(即来自 torch.nn.init 的函数)外,模块的构造函数不得对参数执行任何计算。

如果满足这些条件,则可以实例化参数 / 缓冲区值未初始化的模块,就像使用 torch.empty() 创建一样。

参数
  • module_cls – 类对象;应该是 torch.nn.Module 的子类。

  • args – 传递给模块构造函数的 args。

  • kwargs – 传递给模块构造函数的 kwargs。

返回

实例化后参数 / 缓冲区未初始化的模块。

示例

>>> import torch
>>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
>>> m.weight
Parameter containing:
tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
       requires_grad=True)
>>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
>>> m2.weight
Parameter containing:
tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
          4.5915e-41]], requires_grad=True)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源