注意
点击此处下载完整的示例代码
torch.vmap¶
创建于:2020 年 10 月 26 日 | 最后更新:2021 年 9 月 1 日 | 最后验证:未验证
本教程介绍 torch.vmap,PyTorch 操作的自动向量化器。torch.vmap 是一个原型功能,无法处理许多用例;但是,我们希望收集其用例以告知设计。如果您正在考虑使用 torch.vmap 或认为它对某些事情非常有用,请通过 https://github.com/pytorch/pytorch/issues/42368 联系我们。
那么,什么是 vmap?¶
vmap 是一个高阶函数。它接受一个函数 func 并返回一个新函数,该函数将 func 映射到输入的某些维度上。它深受 JAX 的 vmap 的启发。
从语义上讲,vmap 将“map”推送到 func 调用的 PyTorch 操作中,从而有效地向量化这些操作。
import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap
vmap 的第一个用例是简化代码中批量维度的处理。可以编写一个在示例上运行的函数 func,然后使用 vmap(func) 将其提升为可以处理批量示例的函数。func 然而受到许多限制
它必须是函数式的(不能在其中改变 Python 数据结构),但 PyTorch 的原地操作除外。
批量示例必须以张量形式提供。这意味着 vmap 无法开箱即用地处理可变长度序列。
vmap 的一个用例是计算批量点积。PyTorch 不提供批量的 torch.dot API;与其徒劳地翻阅文档,不如使用 vmap 构建一个新函数
torch.dot # [D], [D] -> []
batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)
vmap 可以帮助隐藏批量维度,从而简化模型编写体验。
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)
# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).
def model(feature_vec):
# Very simple linear model with activation
return feature_vec.dot(weights).relu()
examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)
vmap 还可以帮助向量化以前难以或不可能批处理的计算。这使我们进入第二个用例:批量梯度计算。
PyTorch autograd 引擎计算 vjps(向量-雅可比积)。使用 vmap,我们可以计算(批量向量)- 雅可比积。
这方面的一个例子是计算完整的雅可比矩阵(这也适用于计算完整的海森矩阵)。对于某个函数 f: R^N -> R^N 计算完整的雅可比矩阵通常需要 N 次调用 autograd.grad,每个雅可比矩阵行一次。
# Setup
N = 5
def f(x):
return x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)
# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)
# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
return torch.autograd.grad(y, x, v)[0]
jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)
vmap 的第三个主要用例是计算每样本梯度。这是 vmap 原型目前无法高效处理的事情。我们不确定计算每样本梯度的 API 应该是什么,但如果您有任何想法,请在 https://github.com/pytorch/pytorch/issues/7786 中评论。
def model(sample, weight):
# do something...
return torch.dot(sample, weight)
def grad_sample(sample):
return torch.autograd.functional.vjp(lambda weight: model(sample), weight)[1]
# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.
# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
脚本的总运行时间:(0 分钟 0.000 秒)