注意
点击 此处 下载完整示例代码
模型集成¶
创建日期:2023 年 3 月 15 日 | 最后更新:2024 年 1 月 16 日 | 最后验证:2024 年 11 月 05 日
本教程演示了如何使用 torch.vmap
对模型集成进行向量化。
什么是模型集成?¶
模型集成将多个模型的预测结果结合在一起。传统上,这是通过分别对每个模型运行输入,然后组合预测结果来完成的。然而,如果运行的是具有相同架构的模型,则可能可以使用 torch.vmap
将它们组合在一起。vmap
是一种函数变换,它将函数映射到输入张量的维度上。其用例之一是消除 for 循环并通过向量化加速。
让我们使用简单的 MLP 集成来演示如何实现这一点。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
# Here's a simple MLP
class SimpleMLP(nn.Module):
def __init__(self):
super(SimpleMLP, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = x.flatten(1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
让我们生成一批虚拟数据,并假设我们正在处理 MNIST 数据集。因此,虚拟图像是 28x28,我们有一个大小为 64 的 minibatch。此外,假设我们要组合来自 10 个不同模型的预测。
device = 'cuda'
num_models = 10
data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)
models = [SimpleMLP().to(device) for _ in range(num_models)]
我们有几种生成预测的选项。也许我们想为每个模型提供不同的随机 minibatch 数据。或者,也许我们想让每个模型都运行相同的 minibatch 数据(例如,如果我们正在测试不同模型初始化对结果的影响)。
选项 1:每个模型使用不同的 minibatch
minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]
选项 2:使用相同的 minibatch
使用 vmap
对集成进行向量化¶
让我们使用 vmap
来加速 for 循环。我们首先需要为使用 vmap
准备模型。
首先,通过堆叠每个参数来组合模型的状态。例如,model[i].fc1.weight
的形状是 [784, 128]
;我们将堆叠 10 个模型的 .fc1.weight
以生成形状为 [10, 784, 128]
的大权重。
PyTorch 提供了便捷函数 torch.func.stack_module_state
来完成此操作。
from torch.func import stack_module_state
params, buffers = stack_module_state(models)
接下来,我们需要定义一个要通过 vmap
进行映射的函数。该函数应该在给定参数、缓冲区和输入的情况下,使用这些参数、缓冲区和输入来运行模型。我们将使用 torch.func.functional_call
来提供帮助。
from torch.func import functional_call
import copy
# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')
def fmodel(params, buffers, x):
return functional_call(base_model, (params, buffers), (x,))
选项 1:为每个模型使用不同的 minibatch 获取预测。
默认情况下,vmap
会将函数映射到传递给函数的**所有输入**的第一个维度上。在使用 stack_module_state
后,params
和缓冲区中的每一个都在最前面增加了一个大小为“num_models”的维度,minibatches 也有一个大小为“num_models”的维度。
print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension
assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'
from torch import vmap
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
# verify the ``vmap`` predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
[10, 10, 10, 10, 10, 10]
选项 2:使用相同的 minibatch 数据获取预测。
vmap
有一个 in_dims
参数,用于指定要映射哪些维度。通过使用 None
,我们告诉 vmap
我们希望相同的 minibatch 应用于所有 10 个模型。
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
快速注意:vmap
可以转换的函数类型存在一些限制。最适合转换的函数是纯函数:即输出仅由输入决定且没有副作用(例如修改)的函数。vmap
无法处理任意 Python 数据结构的修改,但能够处理许多原地 PyTorch 操作。
性能¶
对性能数据感到好奇吗?这是数据看起来的样子。
from torch.utils.benchmark import Timer
without_vmap = Timer(
stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
globals=globals())
with_vmap = Timer(
stmt="vmap(fmodel)(params, buffers, minibatches)",
globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fbbc7af4940>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
1.43 ms
1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fbbc7ae01f0>
vmap(fmodel)(params, buffers, minibatches)
484.76 us
1 measurement, 100 runs , 1 thread
使用 vmap
带来了巨大的加速!
总的来说,使用 vmap
进行向量化应该比在 for 循环中运行函数更快,并且与手动批处理具有竞争力。不过也有一些例外情况,例如我们尚未为特定操作实现 vmap
规则,或者底层内核尚未针对较旧的硬件(GPU)进行优化。如果你遇到任何此类情况,请通过在 GitHub 上提交问题告知我们。
脚本总运行时间: ( 0 分钟 0.951 秒)