注意
点击此处下载完整示例代码
nn.Module 中 load_state_dict
和 tensor 子类别的扩展点¶
创建日期:2024 年 4 月 19 日 | 最后更新:2024 年 4 月 19 日 | 最后验证:未验证
本文介绍了一个新的实用函数 torch.utils.swap_tensors
,以及它已集成到 nn.Module
中的两个新的扩展点
nn.Module.to()
及相关方法nn.Module.load_state_dict()
注意
本文需要 PyTorch 2.3.0 或更高版本。
torch.utils.swap_tensors
¶
torch.utils.swap_tensors
(下文简称 swap_tensors
)是一个实用函数,它接受两个 Python 张量并交换它们。
import torch
import torch.nn as nn
t1 = torch.arange(2)
t2 = torch.arange(3)
print(f"Before swapping, t1: {t1}, t2: {t2}")
torch.utils.swap_tensors(t1, t2)
print(f"After swapping, t1: {t1}, t2: {t2}")
Before swapping, t1: tensor([0, 1]), t2: tensor([0, 1, 2])
After swapping, t1: tensor([0, 1, 2]), t2: tensor([0, 1])
更具体地说,swap_tensors
交换两个张量的 Python __class__
、__dict__
和 __slots__
,以及它们关联的 at::Tensor
。
在 nn.Module
中的应用¶
当模块外部的 Python 对象持有模块参数的引用时,此实用程序与 nn.Module
相关。如果 nn.Module
修改其任何参数时不是在原地操作,则持有参数引用的对象将看不到更改。一个经典的例子是优化器,它持有 nn.Module
参数的引用。这会导致一个静默的正确性问题,即 optimizer.step()
会无错运行,但 nn.Module
的权重不会更新。
mod = torch.nn.Linear(1, 2, bias=False)
optimizer = torch.optim.SGD(mod.parameters())
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
mod.weight = torch.nn.Parameter(2 * mod.weight)
print(f"weight in mod: {mod.weight}")
print(f"weight in optimizer: {optimizer.param_groups[0]['params']}")
weight in mod: Parameter containing:
tensor([[0.7645],
[0.8300]], requires_grad=True)
weight in optimizer: [Parameter containing:
tensor([[0.7645],
[0.8300]], requires_grad=True)]
weight in mod: Parameter containing:
tensor([[1.5291],
[1.6600]], requires_grad=True)
weight in optimizer: [Parameter containing:
tensor([[0.7645],
[0.8300]], requires_grad=True)]
nn.Module.load_state_dict()
¶
根据传递给 load_state_dict()
的关键字参数 assign
的值,有两种方式加载 state_dict
:
assign=False
:保留module.param
的属性,只获取state_dict['param_name']
的值assign=True
:保留state_dict['param_name']
的属性和值。
以前,这些是分别通过原地 copy_
和 __setattr__
实现的。现有实现中,每种方法都有其局限性——assign=False
强制要求 state_dict
中参数的类型必须与模块中参数的类型相同,而 assign=True
强制要求任何持有模块参数引用的对象必须在 nn.Module.load_state_dict()
之后初始化。
现在,我们通过向 load_state_dict()
添加 swap_tensors
路径并引入一个新的扩展点 torch.Tensor.module_load(self, other, assign=False)
来解决这两个约束。当通过上述 __future__
启用 swap_tensors
路径时,我们可以使用 module_load
的 __torch_function__
处理程序,对 state_dict
中的值应用自定义变换。此变换的结果将与模块中的参数交换。
在以下示例中,我们将使用上面定义的 MyQuantizedLinearWeight
子类别,以说明如何在加载 state_dict
时使用这些特性对线性层的权重应用自定义量化方案。
回想一下,如果 self
或 other
(在这种情况下是 param
或 state_dict[param_key]
)是 MyQuantizedLinearWeight
子类别,则会调用 module_load
的 __torch_function__
处理程序。
假设我们期望 state_dict
包含普通张量,并且模块包含 MyQuantizedLinearWeight
参数,我们希望将 state_dict
中的张量转换为子类别。那么我们可以如下定义 torch.Tensor.module_load
的 __torch_function__
处理程序:
@classmethod
def custom_torch_function(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if func is torch.Tensor.module_load:
dest, src = args[0], args[1]
assert type(dest) == cls and type(src) == torch.Tensor
return MyQuantizedLinearWeight(src, dest.scale)
else:
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
MyQuantizedLinearWeight.__torch_function__ = custom_torch_function
首先,让我们在 meta 设备上创建一个模型骨架,以避免具体化存储。我们将模块中的所有权重转换为 MyQuantizedLinearWeight
子类别,同时保留偏差不变。
def fn(m):
if isinstance(m, nn.Linear):
requires_grad = m.weight.requires_grad
m.weight = torch.nn.Parameter(
MyQuantizedLinearWeight(m.weight, 0.5), requires_grad=requires_grad
)
with torch.device("meta"):
m = nn.Linear(3, 5)
m.apply(fn)
然后我们可以加载 state_dict
。请注意,我们使用 assign=True
,因为对于偏差,我们希望保留 state_dict
中张量的属性(例如,我们不希望加载后偏差位于 meta
设备上)。
torch.__future__.set_swap_module_params_on_conversion(True)
print(f"Before: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() before load_state_dict():\n {m.state_dict()}")
state_dict = nn.Linear(3, 5).state_dict()
print(f"state_dict:\n {state_dict}")
m.load_state_dict(state_dict, assign=True)
print(f"After: id(weight)={id(m.weight)}, id(bias)={id(m.bias)}")
print(f"m.state_dict() after load_state_dict():\n {m.state_dict()}")
Before: id(weight)=139963536330320, id(bias)=139963524974128
m.state_dict() before load_state_dict():
OrderedDict([('weight', MyQuantizedLinearWeight(tensor(..., device='meta', size=(5, 3)), scale=0.5)), ('bias', tensor(..., device='meta', size=(5,)))])
state_dict:
OrderedDict([('weight', tensor([[ 0.2430, 0.5155, 0.3337],
[-0.2524, 0.3333, 0.1033],
[ 0.2932, -0.3519, -0.5715],
[-0.2231, -0.4428, 0.4737],
[ 0.1663, 0.2391, 0.1826]])), ('bias', tensor([-0.0100, 0.4518, -0.4102, 0.0364, -0.3941]))])
After: id(weight)=139963536330320, id(bias)=139963524974128
m.state_dict() after load_state_dict():
OrderedDict([('weight', MyQuantizedLinearWeight(tensor([[ 0.2430, 0.5155, 0.3337],
[-0.2524, 0.3333, 0.1033],
[ 0.2932, -0.3519, -0.5715],
[-0.2231, -0.4428, 0.4737],
[ 0.1663, 0.2391, 0.1826]]), scale=0.5)), ('bias', tensor([-0.0100, 0.4518, -0.4102, 0.0364, -0.3941]))])
以上是如何使用 nn.Module.load_state_dict()
中新扩展点的一个玩具示例。我们还可以设想其他场景,例如 state_dict
中包含张量子类别,而模块中是普通 nn.Parameters
/张量,或者两者都是张量子类别。根据用例,我们可以为 module_load
定义 __torch_function__
处理程序,以根据需要应用变换。
结论¶
在本文中,我们学习了 swap_tensors
,了解了在 nn.Module
中保留参数引用的重要性,以及如何使用由 torch.__future__.set_swap_module_params_on_conversion
控制的两个新扩展点。
脚本总运行时间: ( 0 分 0.012 秒)