用户体验限制¶
torch.func,类似于 JAX,对可转换的内容有限制。通常,JAX 的限制在于转换仅适用于纯函数:即输出完全由输入决定且不涉及副作用(如修改)的函数。
我们也有类似的保证:我们的转换适用于纯函数。但是,我们确实支持某些就地操作。一方面,编写与函数转换兼容的代码可能需要更改您编写 PyTorch 代码的方式,另一方面,您可能会发现我们的转换让您可以表达以前在 PyTorch 中难以表达的内容。
一般限制¶
所有 torch.func 转换都存在一个限制,即函数不应向全局变量赋值。相反,函数的所有输出都必须从函数返回。此限制源于 torch.func 的实现方式:每个转换都会将张量输入包装在特殊的 torch.func 张量子类中,以促进转换。
因此,不要使用以下方式
import torch
from torch.func import grad
# Don't do this
intermediate = None
def f(x):
global intermediate
intermediate = x.sin()
z = intermediate.sin()
return z
x = torch.randn([])
grad_x = grad(f)(x)
请重写 f
以返回 intermediate
def f(x):
intermediate = x.sin()
z = intermediate.sin()
return z, intermediate
grad_x, intermediate = grad(f, has_aux=True)(x)
torch.autograd API¶
如果您尝试在由 vmap()
或 torch.func 的其中一个 AD 转换 (vjp()
、jvp()
、jacrev()
、jacfwd()
) 转换的函数内部使用 torch.autograd
API(如 torch.autograd.grad
或 torch.autograd.backward
),则转换可能无法对其进行转换。如果无法转换,您将收到错误消息。
这是 PyTorch 的 AD 支持实现方式中的一个基本设计限制,也是我们设计 torch.func 库的原因。请改用 torch.autograd
API 的 torch.func 等效项: - torch.autograd.grad
、Tensor.backward
-> torch.func.vjp
或 torch.func.grad
- torch.autograd.functional.jvp
-> torch.func.jvp
- torch.autograd.functional.jacobian
-> torch.func.jacrev
或 torch.func.jacfwd
- torch.autograd.functional.hessian
-> torch.func.hessian
vmap 限制¶
注意
vmap()
是我们限制最严格的转换。与梯度相关的转换 (grad()
、vjp()
、jvp()
) 没有这些限制。jacfwd()
(以及 hessian()
,它是使用 jacfwd()
实现的)是 vmap()
和 jvp()
的组合,因此它也存在这些限制。
vmap(func)
是一种转换,它返回一个函数,该函数将 func
映射到每个输入张量的一些新维度上。vmap 的心理模型是它类似于运行 for 循环:对于纯函数(即在没有副作用的情况下),vmap(f)(x)
等效于
torch.stack([f(x_i) for x_i in x.unbind(0)])
变异:Python 数据结构的任意变异¶
在存在副作用的情况下,vmap()
不会再像运行 for 循环一样。例如,以下函数
def f(x, list):
list.pop()
print("hello!")
return x.sum(0)
x = torch.randn(3, 1)
lst = [0, 1, 2, 3]
result = vmap(f, in_dims=(0, None))(x, lst)
将只打印一次“hello!”,并且只从 lst
中弹出单个元素。
vmap()
只执行一次 f
,因此所有副作用也只发生一次。
这是 vmap 实现方式的结果。torch.func 有一个特殊的内部 BatchedTensor 类。vmap(f)(*inputs)
获取所有 Tensor 输入,将其转换为 BatchedTensor,然后调用 f(*batched_tensor_inputs)
。BatchedTensor 覆盖了 PyTorch API,以便为每个 PyTorch 运算符产生批量(即矢量化)行为。
变异:就地 PyTorch 运算¶
您可能因为收到有关 vmap 不兼容的就地运算的错误而来到这里。vmap()
如果遇到不受支持的 PyTorch 就地运算,则会引发错误,否则会成功。不受支持的运算指的是会导致将元素更多的张量写入元素更少的张量的运算。以下是一个示例,说明这种情况是如何发生的
def f(x, y):
x.add_(y)
return x
x = torch.randn(1)
y = torch.randn(3, 1) # When vmapped over, looks like it has shape [1]
# Raises an error because `x` has fewer elements than `y`.
vmap(f, in_dims=(None, 0))(x, y)
x
是一个具有一个元素的张量,y
是一个具有三个元素的张量。x + y
有三个元素(由于广播),但尝试将三个元素写回 x
(它只有一个元素),由于尝试将三个元素写入只有一个元素的张量,因此会引发错误。
如果要写入的张量在 vmap()
下被批处理(即对其进行 vmap),则不会出现问题。
def f(x, y):
x.add_(y)
return x
x = torch.randn(3, 1)
y = torch.randn(3, 1)
expected = x + y
# Does not raise an error because x is being vmapped over.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)
对此的一个常见解决方法是用其“new_*”等效项替换对工厂函数的调用。例如
用
Tensor.new_zeros()
替换torch.zeros()
用
Tensor.new_empty()
替换torch.empty()
要了解为什么这样做会有帮助,请考虑以下情况。
def diag_embed(vec):
assert vec.dim() == 1
result = torch.zeros(vec.shape[0], vec.shape[0])
result.diagonal().copy_(vec)
return result
vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
vmap(diag_embed)(vecs)
在 vmap()
内部,result
是一个形状为 [3, 3] 的张量。但是,虽然 vec
看起来形状为 [3],但 vec
实际上具有底层形状 [2, 3]。无法将 vec
复制到 result.diagonal()
中,后者形状为 [3],因为它有太多元素。
def diag_embed(vec):
assert vec.dim() == 1
result = vec.new_zeros(vec.shape[0], vec.shape[0])
result.diagonal().copy_(vec)
return result
vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
vmap(diag_embed)(vecs)
用 Tensor.new_zeros()
替换 torch.zeros()
会使 result
具有形状为 [2, 3, 3] 的底层张量,因此现在可以将 vec
(具有底层形状 [2, 3])复制到 result.diagonal()
中。
变异:out= PyTorch 运算¶
vmap()
不支持 PyTorch 运算中的 out=
关键字参数。如果在代码中遇到该参数,它将优雅地出错。
这不是一个根本性的限制;从理论上讲,我们将来可以支持它,但我们目前选择不这样做。
数据相关的 Python 控制流¶
我们尚不支持在数据相关的控制流上使用 vmap
。数据相关的控制流是指 if 语句、while 循环或 for 循环的条件是一个正在进行 vmap
的张量。例如,以下操作将引发错误消息
def relu(x):
if x > 0:
return x
return 0
x = torch.randn(3)
vmap(relu)(x)
但是,任何不依赖于 vmap
张量中值的控制流都将起作用
def custom_dot(x):
if x.dim() == 1:
return torch.dot(x, x)
return (x * x).sum()
x = torch.randn(3)
vmap(custom_dot)(x)
JAX 支持使用特殊的控制流运算符(例如 jax.lax.cond
、jax.lax.while_loop
)在 数据相关的控制流 上进行转换。我们正在研究将这些运算符的等效项添加到 PyTorch 中。
数据相关的运算(.item())¶
我们不支持(将来也不会支持)在调用 .item()
对张量进行操作的用户定义函数上使用 vmap。例如,以下操作将引发错误消息
def f(x):
return x.item()
x = torch.randn(3)
vmap(f)(x)
请尝试重写您的代码,不要使用 .item()
调用。
您也可能会遇到有关使用 .item()
的错误消息,但您可能没有使用它。在这些情况下,PyTorch 可能在内部调用了 .item()
- 请在 GitHub 上提交问题,我们将修复 PyTorch 内部。
动态形状运算(nonzero 及其相关运算)¶
vmap(f)
要求应用于输入中每个“示例”的 f
返回具有相同形状的张量。不支持诸如 torch.nonzero
、torch.is_nonzero
之类的运算,因此会引发错误。
要了解原因,请考虑以下示例
xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)
torch.nonzero(xs[0])
返回形状为 2 的张量;但 torch.nonzero(xs[1])
返回形状为 1 的张量。我们无法构造单个张量作为输出;输出需要是不规则张量(而 PyTorch 还没有不规则张量的概念)。
随机性¶
用户在调用随机运算时的意图可能不清楚。具体来说,一些用户可能希望随机行为在批次之间保持相同,而另一些用户可能希望它在批次之间有所不同。为了解决这个问题,vmap
带有一个随机性标志。
该标志只能传递给 vmap,并且可以取三个值:“error”、“different”或“same”,默认为 error。在“error”模式下,对随机函数的任何调用都会产生一个错误,要求用户根据其用例使用其他两个标志之一。
在“different”随机性下,批次中的元素会产生不同的随机值。例如,
def add_noise(x):
y = torch.randn(()) # y will be different across the batch
return x + y
x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x) # we get 3 different values
在“same”随机性下,批次中的元素会产生相同的随机值。例如,
def add_noise(x):
y = torch.randn(()) # y will be the same across the batch
return x + y
x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x) # we get the same value, repeated 3 times
警告
我们的系统仅确定 PyTorch 运算符的随机性行为,无法控制其他库(如 numpy)的行为。这类似于 JAX 在其解决方案中存在的限制
注意
使用任一类型的支持随机性进行多次 vmap 调用不会产生相同的结果。与标准 PyTorch 一样,用户可以通过在 vmap 外部使用 torch.manual_seed()
或使用生成器来获得随机性可重复性。
注意
最后,我们的随机性与 JAX 不同,因为我们没有使用无状态 PRNG,部分原因是 PyTorch 不完全支持无状态 PRNG。相反,我们引入了一个标志系统,以允许我们看到的最常见的随机性形式。如果您的用例不适合这些随机性形式,请提交问题。