快捷方式

torch.func.stack_module_state

torch.func.stack_module_state(models) params, buffers[source]

准备用于通过 vmap() 进行集成的 torch.nn.Module 列表。

给定一个包含 M 个相同类的 nn.Modules 的列表,返回两个字典,它们将所有模块的参数和缓冲区堆叠在一起,并按名称索引。堆叠的参数是可优化的(即,它们是 autograd 历史记录中的新叶节点,与原始参数无关,可以直接传递给优化器)。

以下是如何对一个非常简单的模型进行集成的示例

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

def wrapper(params, buffers, data):
    return torch.func.functional_call(models[0], (params, buffers), data)

params, buffers = stack_module_state(models)
output = vmap(wrapper, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)

当存在子模块时,这将遵循状态字典命名约定

import torch.nn as nn
class Foo(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = 4
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.l2(self.l1(x))

num_models = 5
in_features, out_features = 3, 3
models = [Foo(in_features, out_features) for i in range(num_models)]
params, buffers = stack_module_state(models)
print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"

警告

所有堆叠在一起的模块必须相同(除了它们的参数/缓冲区的值)。例如,它们应该处于相同的模式(训练与评估)。

返回类型

元组[字典[字符串, 任意类型], 字典[字符串, 任意类型]]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源