模型集成¶
此示例说明如何使用 vmap 对模型集成进行矢量化。
什么是模型集成?¶
模型集成将多个模型的预测结果结合在一起。传统上,这是通过分别对每个模型运行一些输入,然后组合预测结果来完成的。但是,如果运行的是具有相同架构的模型,则可以使用 vmap
将它们组合在一起。vmap
是一种函数转换,可以跨输入张量的维度映射函数。它的一个用例是消除 for 循环并通过矢量化加速它们。
让我们演示如何使用多个简单的 MLP 集成来完成此操作。
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
torch.manual_seed(0);
# Here's a simple MLP
class SimpleMLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = x.flatten(1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
让我们生成一批虚拟数据,并假设我们正在使用 MNIST 数据集。因此,虚拟图像为 28 x 28,我们有一个大小为 64 的小批次。此外,假设我们想要组合 10 个不同模型的预测结果。
device = 'cuda'
num_models = 10
data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)
models = [SimpleMLP().to(device) for _ in range(num_models)]
我们有几个生成预测结果的选项。也许我们想为每个模型提供不同的随机数据小批次。或者,也许我们想通过每个模型运行相同的数据小批次(例如,如果我们正在测试不同模型初始化的影响)。
选项 1:每个模型使用不同的 minibatch
minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]
选项 2:使用相同的 minibatch
minibatch = data[0]
predictions2 = [model(minibatch) for model in models]
使用 vmap 对集成进行矢量化¶
让我们使用 vmap 加速 for 循环。我们必须首先准备好在 vmap 中使用的模型。
首先,让我们通过堆叠每个参数将模型的状态组合在一起。例如,model[i].fc1.weight
的形状为 [784, 128]
;我们将堆叠 10 个模型中的每个模型的 .fc1.weight 以生成一个大小为 [10, 784, 128]
的大权重。
functorch 提供了“combine_state_for_ensemble”便捷函数来执行此操作。它返回模型的无状态版本 (fmodel) 以及堆叠的参数和缓冲区。
from functorch import combine_state_for_ensemble
fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params];
选项 1:使用每个模型的不同小批次获取预测结果。
默认情况下,vmap 会跨传递给函数的所有输入的第一维映射函数。使用 combine_state_for_ensemble 后,每个参数和缓冲区在前面都会有一个大小为“num_models”的额外维度,并且 minibatches 会有一个大小为“num_models”的维度。
print([p.size(0) for p in params]) # show the leading 'num_models' dimension
assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'
[10, 10, 10, 10, 10, 10]
from functorch import vmap
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
# verify the vmap predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
选项 2:使用相同的数据小批次获取预测结果。
vmap 有一个 in_dims 参数,用于指定要映射的维度。通过使用 None
,我们告诉 vmap 我们希望 10 个模型都应用相同的小批次。
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
快速说明:vmap 可以转换的函数类型存在一些限制。最佳转换函数是纯函数:输出仅由输入确定的函数,并且没有任何副作用(例如,突变)。vmap 无法处理任意 Python 数据结构的突变,但它可以处理许多就地 PyTorch 操作。
性能¶
想知道性能数据?以下是 Google Colab 上的数据。
from torch.utils.benchmark import Timer
without_vmap = Timer(
stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
globals=globals())
with_vmap = Timer(
stmt="vmap(fmodel)(params, buffers, minibatches)",
globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fe22c58b3d0>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
3.25 ms
1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fe22c50c450>
vmap(fmodel)(params, buffers, minibatches)
879.28 us
1 measurement, 100 runs , 1 thread
使用 vmap 可以大幅提升速度!
通常,使用 vmap 进行矢量化应该比在 for 循环中运行函数更快,并且与手动批处理具有竞争力。但也有一些例外情况,例如,如果我们尚未为特定操作实现 vmap 规则,或者底层内核没有针对旧硬件(GPU)进行优化。如果您遇到任何这些情况,请通过在我们的 GitHub 上提交问题告知我们!