注意
单击 此处 下载完整的示例代码
torch.vmap¶
本教程介绍 torch.vmap,它是 PyTorch 运算的自动矢量化器。torch.vmap 是一个原型功能,无法处理许多用例;但是,我们希望收集有关它的用例以告知设计。如果您正在考虑使用 torch.vmap 或认为它对于某些事情非常酷,请与我们联系 https://github.com/pytorch/pytorch/issues/42368.
那么,什么是 vmap?¶
vmap 是一个高阶函数。它接受一个函数 func 并返回一个新函数,该函数将 func 映射到输入的某个维度。它深受 JAX 的 vmap 启发。
从语义上讲,vmap 将“map”推入 func 调用的 PyTorch 运算,有效地矢量化这些运算。
import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap
vmap 的第一个用例是简化代码中的批处理维度处理。可以编写一个在示例上运行的函数 func,然后使用 vmap(func) 将其提升为一个可以接受示例批次的函数。然而,func 受到许多限制
它必须是函数式的(不能在其中修改 Python 数据结构),除了就地 PyTorch 运算。
示例批次必须以张量形式提供。这意味着 vmap 无法开箱即用地处理可变长度序列。
使用 vmap 的一个例子是计算批次点积。PyTorch 没有提供批次 torch.dot API;与其在文档中徒劳地搜索,不如使用 vmap 来构建一个新函数
torch.dot # [D], [D] -> []
batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)
vmap 有助于隐藏批次维度,从而简化模型编写体验。
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)
# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).
def model(feature_vec):
# Very simple linear model with activation
return feature_vec.dot(weights).relu()
examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)
vmap 还可以帮助矢量化以前难以或无法批处理的计算。这引出了我们的第二个用例:批次梯度计算。
PyTorch 自动梯度引擎计算 vjps(向量-雅可比积)。使用 vmap,我们可以计算(批次向量) - 雅可比积。
一个例子是计算完整的雅可比矩阵(这也可以应用于计算完整的 Hessian 矩阵)。对于某些函数 f: R^N -> R^N,计算完整的雅可比矩阵通常需要 N 次调用 autograd.grad,每个雅可比行一次。
# Setup
N = 5
def f(x):
return x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)
# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)
# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
return torch.autograd.grad(y, x, v)[0]
jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)
vmap 的第三个主要用例是计算每个样本的梯度。这是 vmap 原型目前无法高效处理的事情。我们不确定计算每个样本梯度的 API 应该是什么,但如果您有任何想法,请在 https://github.com/pytorch/pytorch/issues/7786 中发表评论。
def model(sample, weight):
# do something...
return torch.dot(sample, weight)
def grad_sample(sample):
return torch.autograd.functional.vjp(lambda weight: model(sample), weight)[1]
# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.
# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
脚本总运行时间:(0 分钟 0.000 秒)