快捷方式

functorch.combine_state_for_ensemble

functorch.combine_state_for_ensemble(models)func, params, buffers[源代码]

为使用 vmap() 集成准备一个 torch.nn.Modules 列表。

给定一个包含 M 个相同类的 nn.Modules 的列表,将所有参数和缓冲区堆叠在一起以创建 paramsbuffers。结果中的每个参数和缓冲区将具有一个大小为 M 的额外维度。

combine_state_for_ensemble() 还返回 func,它是 models 中的一个模型的功能版本。不能直接运行 func(params, buffers, *args, **kwargs),可能需要使用 vmap(func, ...)(params, buffers, *args, **kwargs)

以下是如何在一个非常简单的模型上进行集成的一个示例

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

fmodel, params, buffers = combine_state_for_ensemble(models)
output = vmap(fmodel, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)

警告

所有一起堆叠的模块必须相同(除了其参数/缓冲区的值)。例如,它们应该处于相同的模式(训练与评估)。

此 API 可能会发生变化——我们正在研究创建集成的更好方法,并希望获得您关于如何改进它的反馈。

警告

我们已将 functorch 集成到 PyTorch 中。作为集成的最后一步,从 PyTorch 2.0 开始,functorch.combine_state_for_ensemble 已弃用,并将在 PyTorch >= 2.3 的未来版本中删除。请改用 torch.func.stack_module_state;有关更多详细信息,请参阅 PyTorch 2.0 发行说明和/或 torch.func 迁移指南 https://pytorch.ac.cn/docs/master/func.migrating.html

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源