• 教程 >
  • 跳过模块参数初始化
快捷方式

跳过模块参数初始化

简介

创建模块时,其可学习参数会根据与模块类型关联的默认初始化方案进行初始化。例如,torch.nn.Linear 模块的 weight 参数会从 uniform(-1/sqrt(in_features), 1/sqrt(in_features)) 分布初始化。如果需要其他初始化方案,则传统上需要在模块实例化后重新初始化参数。

from torch import nn

# Initializes weight from the default distribution: uniform(-1/sqrt(10), 1/sqrt(10)).
m = nn.Linear(10, 5)

# Re-initialize weight from a different distribution.
nn.init.orthogonal_(m.weight)

在这种情况下,构造期间完成的初始化是浪费的计算,如果 weight 参数很大,则可能很麻烦。

跳过初始化

现在可以跳过模块构造期间的参数初始化,从而避免浪费计算。可以使用 torch.nn.utils.skip_init() 函数轻松实现这一点。

from torch import nn
from torch.nn.utils import skip_init

m = skip_init(nn.Linear, 10, 5)

# Example: Do custom, non-default parameter initialization.
nn.init.orthogonal_(m.weight)

这可以应用于满足下面更新模块以支持跳过初始化部分中描述的条件的任何模块。请注意,torch.nn 提供的所有模块都满足这些条件,因此支持跳过初始化。

更新模块以支持跳过初始化

由于 torch.nn.utils.skip_init() 的实现方式(请参阅下面的实现细节部分),模块必须满足两个要求才能与该函数兼容。只需遵守这些要求,您就可以选择加入对自定义模块的参数初始化跳过功能。

1. 模块必须在其构造函数中接受一个名为 device 的关键字参数,该参数会传递给构造过程中创建的任何参数或缓冲区。

2. 模块在其构造函数中不得对参数或缓冲区执行任何计算,除了初始化(即来自 torch.nn.init 的函数)。

以下示例演示了一个模块,它通过将 device 关键字参数传递给任何创建的参数、缓冲区或子模块来更新以支持该参数。

import torch
from torch import nn

class MyModule(torch.nn.Module):
  def __init__(self, foo, bar, device=None):
    super().__init__()

    # ==== Case 1: Module creates parameters directly. ====
    # Pass device along to any created parameters.
    self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
    self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))

    # To ensure support for the meta device, avoid using ops except those in
    # torch.nn.init on parameters in your module's constructor.
    with torch.no_grad():
        nn.init.kaiming_uniform_(self.param1)
        nn.init.uniform_(self.param2)


    # ==== Case 2: Module creates submodules. ====
    # Pass device along recursively. All submodules will need to support
    # them as well; this is the case for all torch.nn provided modules.
    self.fc = nn.Linear(bar, 5, device=device)

    # This also works with containers.
    self.linears = nn.Sequential(
        nn.Linear(5, 5, device=device),
        nn.Linear(5, 1, device=device)
    )


    # ==== Case 3: Module creates buffers. ====
    # Pass device along during buffer tensor creation.
    self.register_buffer('some_buffer', torch.ones(7, device=device))

...

实现细节

在幕后,torch.nn.utils.skip_init() 函数是根据两步模式实现的。

# 1. Initialize module on the meta device; all torch.nn.init ops have
# no-op behavior on the meta device.
m = nn.Linear(10, 5, device='meta')

# 2. Materialize an uninitialized (empty) form of the module on the CPU device.
# The result of this is a module instance with uninitialized parameters.
m.to_empty(device='cpu')

它通过将模块实例化到一个“元”设备上,该设备具有张量形状信息,但不分配任何存储空间。torch.nn.init 操作专门为这个元设备实现,因此它们具有无操作行为。这会导致参数初始化逻辑基本上被跳过。

请注意,此模式仅适用于在构造过程中正确支持 device 关键字参数的模块,如 更新模块以支持跳过初始化 中所述。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源