快捷方式

UX 限制

functorch 与 JAX 一样,对可转换的内容有一些限制。一般来说,JAX 的限制是变换只能用于纯函数:即输出完全由输入决定且不涉及副作用(如修改)的函数。

我们也有类似的保证:我们的变换适用于纯函数。但是,我们确实支持某些就地操作。一方面,编写与 functorch 变换兼容的代码可能需要更改您编写 PyTorch 代码的方式,另一方面,您可能会发现我们的变换可以让您表达以前在 PyTorch 中难以表达的内容。

一般限制

所有 functorch 变换都存在一个限制,即函数不应向全局变量赋值。相反,函数的所有输出都必须从函数中返回。此限制来自 functorch 的实现方式:每个变换都会将 Tensor 输入包装在特殊的 functorch Tensor 子类中,以促进变换。

因此,不要使用以下方式

import torch
from functorch import grad

# Don't do this
intermediate = None

def f(x):
  global intermediate
  intermediate = x.sin()
  z = intermediate.sin()
  return z

x = torch.randn([])
grad_x = grad(f)(x)

请重写 f 以返回 intermediate

def f(x):
  intermediate = x.sin()
  z = intermediate.sin()
  return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd API

如果您尝试在被 vmap() 或 functorch 的一个 AD 变换(vjp()jvp()jacrev()jacfwd())转换的函数内部使用 torch.autograd API(如 torch.autograd.gradtorch.autograd.backward),则变换可能无法对其进行转换。如果无法转换,您将收到错误消息。

这是 PyTorch 的 AD 支持实现方式中的一个基本设计限制,也是我们设计 functorch 库的原因。请改用 torch.autograd API 的 functorch 等效项: - torch.autograd.gradTensor.backward -> functorch.vjpfunctorch.grad - torch.autograd.functional.jvp -> functorch.jvp - torch.autograd.functional.jacobian -> functorch.jacrevfunctorch.jacfwd - torch.autograd.functional.hessian -> functorch.hessian

vmap 限制

注意

vmap() 是我们限制最严格的变换。与梯度相关的变换(grad()vjp()jvp())没有这些限制。jacfwd()(以及 hessian(),它是使用 jacfwd() 实现的)是 vmap()jvp() 的组合,因此它也存在这些限制。

vmap(func) 是一种变换,它返回一个函数,该函数将 func 映射到每个输入 Tensor 的某个新维度上。vmap 的思维模型是它就像运行一个 for 循环:对于纯函数(即在没有副作用的情况下),vmap(f)(x) 等效于

torch.stack([f(x_i) for x_i in x.unbind(0)])

修改:Python 数据结构的任意修改

在存在副作用的情况下,vmap() 不再像运行 for 循环那样工作。例如,以下函数

def f(x, list):
  list.pop()
  print("hello!")
  return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]

result = vmap(f, in_dims=(0, None))(x, lst)

将打印一次“hello!”,并且只从 lst 中弹出单个元素。

vmap() 执行 f 一次,因此所有副作用只发生一次。

这是 vmap 实现方式的结果。functorch 有一个特殊的内部 BatchedTensor 类。vmap(f)(*inputs) 获取所有 Tensor 输入,将其转换为 BatchedTensor,并调用 f(*batched_tensor_inputs)。BatchedTensor 重写了 PyTorch API,以针对每个 PyTorch 运算符产生批量(即向量化)行为。

修改:就地 PyTorch 运算

vmap() 如果遇到不受支持的 PyTorch 就地运算,将引发错误,否则将成功。不受支持的操作是指会导致将具有更多元素的 Tensor 写入具有较少元素的 Tensor 的操作。以下是如何发生这种情况的示例

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(1)
y = torch.randn(3)

# Raises an error because `y` has fewer elements than `x`.
vmap(f, in_dims=(None, 0))(x, y)

x 是一个具有单个元素的 Tensor,y 是一个具有三个元素的 Tensor。x + y 有三个元素(由于广播),但尝试将三个元素写回 x(它只有一个元素),由于尝试将三个元素写入只有一个元素的 Tensor 而引发错误。

如果要写入的 Tensor 具有相同数量(或更多)的元素,则不会出现问题

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(3)
y = torch.randn(3)
expected = x + y

# Does not raise an error because x and y have the same number of elements.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)

修改:out= PyTorch 运算

vmap() 不支持 PyTorch 运算中的 out= 关键字参数。如果在您的代码中遇到它,它将优雅地出错。

这不是一个根本性的限制;我们理论上可以在将来支持它,但我们目前选择不这样做。

数据相关的 Python 控制流

我们尚不支持对数据相关的控制流进行 vmap。数据相关的控制流是指 if 语句、while 循环或 for 循环的条件是一个正在进行 vmap 的 Tensor。例如,以下操作将引发错误消息

def relu(x):
  if x > 0:
    return x
  return 0

x = torch.randn(3)
vmap(relu)(x)

但是,任何不依赖于vmap张量中值的控制流都将正常工作。

def custom_dot(x):
  if x.dim() == 1:
    return torch.dot(x, x)
  return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)

JAX 支持使用特殊的控制流运算符(例如 jax.lax.condjax.lax.while_loop)对数据相关控制流进行转换。我们正在研究向 functorch 添加这些运算符的等价物(在GitHub 上提交 issue 以表达您的支持!)。

数据相关的操作 (.item())

我们不支持(也永远不会支持)在调用张量上 .item() 的用户定义函数上使用 vmap。例如,以下操作将引发错误消息

def f(x):
  return x.item()

x = torch.randn(3)
vmap(f)(x)

请尝试重写您的代码,避免使用 .item() 调用。

您也可能遇到有关使用 .item() 的错误消息,但您可能并没有使用它。在这些情况下,PyTorch 内部可能会调用 .item() - 请在 GitHub 上提交 issue,我们将修复 PyTorch 内部问题。

动态形状操作(nonzero 及其相关操作)

vmap(f) 要求应用于输入中每个“示例”的 f 返回具有相同形状的张量。诸如 torch.nonzerotorch.is_nonzero 等操作不受支持,因此将导致错误。

要了解原因,请考虑以下示例

xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)

torch.nonzero(xs[0]) 返回形状为 2 的张量;但 torch.nonzero(xs[1]) 返回形状为 1 的张量。我们无法构建单个张量作为输出;输出需要是不规则张量(而 PyTorch 尚未具有不规则张量的概念)。

随机性

用户调用随机操作时的意图可能不明确。具体来说,一些用户可能希望随机行为在批次之间保持一致,而另一些用户可能希望它在批次之间有所不同。为了解决这个问题,vmap 提供了一个随机性标志。

该标志只能传递给 vmap,并且可以取三个值:“error”、“different” 或“same”,默认为 error。在“error”模式下,对随机函数的任何调用都会产生错误,要求用户根据其用例使用其他两个标志之一。

在“different”随机性下,批次中的元素生成不同的随机值。例如,

def add_noise(x):
  y = torch.randn(())  # y will be different across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

在“same”随机性下,批次中的元素生成相同的随机值。例如,

def add_noise(x):
  y = torch.randn(())  # y will be the same across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

警告

我们的系统仅确定 PyTorch 运算符的随机性行为,无法控制其他库(如 numpy)的行为。这与 JAX 在其解决方案中的局限性类似。

注意

使用任何一种支持的随机性类型的多次 vmap 调用不会产生相同的结果。与标准 PyTorch 一样,用户可以通过在 vmap 外部使用 torch.manual_seed() 或使用生成器来获得随机性可重复性。

注意

最后,我们的随机性与 JAX 不同,因为我们没有使用无状态 PRNG,部分原因是 PyTorch 并不完全支持无状态 PRNG。相反,我们引入了一个标志系统,以允许我们看到的随机性的最常见形式。如果您的用例不适合这些随机性形式,请提交 issue。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源