• 教程 >
  • nn.Module 中 load_state_dict 和 tensor 子类别的扩展点
快捷方式

nn.Module 中 load_state_dict 和 tensor 子类别的扩展点

创建日期:2024 年 4 月 19 日 | 最后更新:2024 年 4 月 19 日 | 最后验证:未验证

作者: Mikayla Gawarecki

本文介绍了一个新的实用函数 torch.utils.swap_tensors,以及它已集成到 nn.Module 中的两个新的扩展点

  • nn.Module.to() 及相关方法

  • nn.Module.load_state_dict()

注意

本文需要 PyTorch 2.3.0 或更高版本。

torch.utils.swap_tensors

torch.utils.swap_tensors(下文简称 swap_tensors)是一个实用函数,它接受两个 Python 张量并交换它们。

import torch
import torch.nn as nn
t1 = torch.arange(2)
t2 = torch.arange(3)
print(f"Before swapping, t1: {t1}, t2: {t2}")
torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1: {t1}, t2: {t2}")
Before swapping, t1: tensor([0, 1]), t2: tensor([0, 1, 2])
After swapping, t1: tensor([0, 1, 2]), t2: tensor([0, 1])

更具体地说,swap_tensors 交换两个张量的 Python __class____dict____slots__,以及它们关联的 at::Tensor

nn.Module 中的应用

当模块外部的 Python 对象持有模块参数的引用时,此实用程序与 nn.Module 相关。如果 nn.Module 修改其任何参数时不是在原地操作,则持有参数引用的对象将看不到更改。一个经典的例子是优化器,它持有 nn.Module 参数的引用。这会导致一个静默的正确性问题,即 optimizer.step() 会无错运行,但 nn.Module 的权重不会更新。

mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
mod.weight = torch.nn.Parameter(2 * mod.weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
weight in mod: Parameter containing:
tensor([[0.7645],
        [0.8300]], requires_grad=True)
weight in optimizer: [Parameter containing:
tensor([[0.7645],
        [0.8300]], requires_grad=True)]
weight in mod: Parameter containing:
tensor([[1.5291],
        [1.6600]], requires_grad=True)
weight in optimizer: [Parameter containing:
tensor([[0.7645],
        [0.8300]], requires_grad=True)]

nn.Module.load_state_dict()

根据传递给 load_state_dict() 的关键字参数 assign 的值,有两种方式加载 state_dict

  • assign=False:保留 module.param 的属性,只获取 state_dict['param_name'] 的值

  • assign=True:保留 state_dict['param_name'] 的属性和值。

以前,这些是分别通过原地 copy___setattr__ 实现的。现有实现中,每种方法都有其局限性——assign=False 强制要求 state_dict 中参数的类型必须与模块中参数的类型相同,而 assign=True 强制要求任何持有模块参数引用的对象必须在 nn.Module.load_state_dict() 之后初始化。

现在,我们通过向 load_state_dict() 添加 swap_tensors 路径并引入一个新的扩展点 torch.Tensor.module_load(self, other, assign=False) 来解决这两个约束。当通过上述 __future__ 启用 swap_tensors 路径时,我们可以使用 module_load__torch_function__ 处理程序,对 state_dict 中的值应用自定义变换。此变换的结果将与模块中的参数交换。

在以下示例中,我们将使用上面定义的 MyQuantizedLinearWeight 子类别,以说明如何在加载 state_dict 时使用这些特性对线性层的权重应用自定义量化方案。

回想一下,如果 selfother(在这种情况下是 paramstate_dict[param_key])是 MyQuantizedLinearWeight 子类别,则会调用 module_load__torch_function__ 处理程序。

假设我们期望 state_dict 包含普通张量,并且模块包含 MyQuantizedLinearWeight 参数,我们希望将 state_dict 中的张量转换为子类别。那么我们可以如下定义 torch.Tensor.module_load__torch_function__ 处理程序:

@classmethod
def custom_torch_function(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    if func is torch.Tensor.module_load:
        dest, src = args[0], args[1]
        assert type(dest) == cls and type(src) == torch.Tensor
        return MyQuantizedLinearWeight(src, dest.scale)
    else:
        with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)

MyQuantizedLinearWeight.__torch_function__ = custom_torch_function

首先,让我们在 meta 设备上创建一个模型骨架,以避免具体化存储。我们将模块中的所有权重转换为 MyQuantizedLinearWeight 子类别,同时保留偏差不变。

def fn(m):
    if isinstance(m, nn.Linear):
        requires_grad = m.weight.requires_grad
        m.weight = torch.nn.Parameter(
                    MyQuantizedLinearWeight(m.weight, 0.5), requires_grad=requires_grad
                   )

with torch.device("meta"):
    m = nn.Linear(3, 5)
    m.apply(fn)

然后我们可以加载 state_dict。请注意,我们使用 assign=True,因为对于偏差,我们希望保留 state_dict 中张量的属性(例如,我们不希望加载后偏差位于 meta 设备上)。

torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() before load_state_dict():\n {m.state_dict()}")
state_dict = nn.Linear(3, 5).state_dict()
print(f"state_dict:\n {state_dict}")
m.load_state_dict(state_dict, assign=True)
print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")
Before: id(weight)=139963536330320, id(bias)=139963524974128
m.state_dict() before load_state_dict():
 OrderedDict([('weight', MyQuantizedLinearWeight(tensor(..., device='meta', size=(5, 3)), scale=0.5)), ('bias', tensor(..., device='meta', size=(5,)))])
state_dict:
 OrderedDict([('weight', tensor([[ 0.2430,  0.5155,  0.3337],
        [-0.2524,  0.3333,  0.1033],
        [ 0.2932, -0.3519, -0.5715],
        [-0.2231, -0.4428,  0.4737],
        [ 0.1663,  0.2391,  0.1826]])), ('bias', tensor([-0.0100,  0.4518, -0.4102,  0.0364, -0.3941]))])
After: id(weight)=139963536330320, id(bias)=139963524974128
m.state_dict() after load_state_dict():
 OrderedDict([('weight', MyQuantizedLinearWeight(tensor([[ 0.2430,  0.5155,  0.3337],
        [-0.2524,  0.3333,  0.1033],
        [ 0.2932, -0.3519, -0.5715],
        [-0.2231, -0.4428,  0.4737],
        [ 0.1663,  0.2391,  0.1826]]), scale=0.5)), ('bias', tensor([-0.0100,  0.4518, -0.4102,  0.0364, -0.3941]))])

以上是如何使用 nn.Module.load_state_dict() 中新扩展点的一个玩具示例。我们还可以设想其他场景,例如 state_dict 中包含张量子类别,而模块中是普通 nn.Parameters/张量,或者两者都是张量子类别。根据用例,我们可以为 module_load 定义 __torch_function__ 处理程序,以根据需要应用变换。

结论

在本文中,我们学习了 swap_tensors,了解了在 nn.Module 中保留参数引用的重要性,以及如何使用由 torch.__future__.set_swap_module_params_on_conversion 控制的两个新扩展点。

脚本总运行时间: ( 0 分 0.012 秒)

由 Sphinx-Gallery 生成

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获取问题解答

查看资源