functorch.combine_state_for_ensemble¶
-
functorch.
combine_state_for_ensemble
(models) → func, params, buffers[源代码]¶ 为使用
vmap()
集成准备一个 torch.nn.Modules 列表。给定一个包含
M
个相同类的nn.Modules
的列表,将所有参数和缓冲区堆叠在一起以创建params
和buffers
。结果中的每个参数和缓冲区将具有一个大小为M
的额外维度。combine_state_for_ensemble()
还返回func
,它是models
中的一个模型的功能版本。不能直接运行func(params, buffers, *args, **kwargs)
,可能需要使用vmap(func, ...)(params, buffers, *args, **kwargs)
以下是如何在一个非常简单的模型上进行集成的一个示例
num_models = 5 batch_size = 64 in_features, out_features = 3, 3 models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] data = torch.randn(batch_size, 3) fmodel, params, buffers = combine_state_for_ensemble(models) output = vmap(fmodel, (0, 0, None))(params, buffers, data) assert output.shape == (num_models, batch_size, out_features)
警告
所有一起堆叠的模块必须相同(除了其参数/缓冲区的值)。例如,它们应该处于相同的模式(训练与评估)。
此 API 可能会发生变化——我们正在研究创建集成的更好方法,并希望获得您关于如何改进它的反馈。
警告
我们已将 functorch 集成到 PyTorch 中。作为集成的最后一步,从 PyTorch 2.0 开始,functorch.combine_state_for_ensemble 已弃用,并将在 PyTorch >= 2.3 的未来版本中删除。请改用 torch.func.stack_module_state;有关更多详细信息,请参阅 PyTorch 2.0 发行说明和/或 torch.func 迁移指南 https://pytorch.ac.cn/docs/master/func.migrating.html