使用 autograd.Function 扩展 torch.func¶
因此,您希望将torch.autograd.Function
与torch.func
转换(如torch.vmap()
、torch.func.grad()
等)一起使用。
主要有两个用例
您希望调用不包含 PyTorch 操作的代码并使其与函数转换一起工作。也就是说,
torch.autograd.Function
的 forward/backward/等调用来自其他系统(如 C++、CUDA、numpy)的函数。您希望指定自定义梯度规则,例如 JAX 的custom_vjp/custom_jvp
PyTorch 将这两个概念都组合到torch.autograd.Function
中。
基本用法¶
本指南假设您熟悉扩展 torch.autograd,其中解释了如何使用torch.autograd.Function
。
torch.autograd.Function
可以具有接受 ctx 对象的forward()
,也可以具有单独的forward()
(不接受ctx
)和修改ctx
对象的setup_context()
静态方法。
只有后者受函数转换支持
forward()
是执行操作的代码,它不应接受ctx
对象。setup_context(ctx, inputs, output)
是可以在其中调用ctx
上方法的代码。在这里,您应该保存用于反向传播的张量(通过调用ctx.save_for_backward(*tensors)
),或保存非张量(通过将它们分配给ctx
对象)。
因为setup_context()
仅接受inputs
和output
,所以唯一可以保存的数量是输入或输出中的对象(例如张量)或从中派生的数量(例如Tensor.shape
)。如果您希望保存Function.forward()
的反向传播的非输入中间激活,那么您需要将其作为forward()
的输出返回,以便将其传递给setup_context()
。
根据转换,
要支持反向模式 AD(
torch.func.grad()
、torch.func.vjp()
),torch.autograd.Function
需要backward()
静态方法。要支持
torch.vmap()
,torch.autograd.Function
需要vmap()
静态方法。为了支持
torch.func.jvp()
,torch.autograd.Function
需要一个jvp()
静态方法。为了支持变换的组合(例如
torch.func.jacrev()
、torch.func.jacfwd()
、torch.func.hessian()
)——你可能需要上述方法中的多个。
为了使 torch.autograd.Function
可以与函数变换任意组合,我们建议除了 forward()
和 setup_context()
之外的所有其他静态方法都必须是可变换的:也就是说,它们必须仅由 PyTorch 运算符组成或调用其他 torch.autograd.Function
(这可能调用 C++/CUDA/等)。
让我们来看一些常见用例的示例。
示例 1:autograd.Function 调用另一个系统¶
一个常见的情况是 torch.autograd.Function
的 forward()
和 backward()
方法都调用另一个系统(如 C++、CUDA、NumPy、Triton)。
import torch
import numpy as np
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
# Note that forward does not take ctx
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
# Any intermediates to be saved in backward must be returned as
# outputs.
return (
# The desired output
torch.tensor(result, device=device),
# intermediate to save for backward
torch.tensor(ind, device=device),
# intermediate to save for backward
torch.tensor(ind_inv, device=device),
)
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
# Note that output is whatever you returned from forward.
# If you returned multiple values, then output is a Tuple of multiple values.
# If you returned a single Tensor, then output is a Tensor.
# If you returned a Tuple with a single Tensor, then output is a
# Tuple with a single Tensor.
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
# Tensors must be saved via ctx.save_for_backward. Please do not
# assign them directly onto the ctx object.
ctx.save_for_backward(ind, ind_inv)
# Non-tensors may be saved by assigning them as attributes on the ctx object.
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
# For the autograd.Function to be arbitrarily composable with function
# transforms, all staticmethod other than forward and setup_context
# must be implemented in a "transformable" way; that is, they must
# only consist of PyTorch operations or autograd.Function.
#
# For example, this allows us to do double backwards and/or compute
# second order gradients.
#
# We've written the backward pass of NumpySort in terms of another
# autograd.Function, NumpyTake.
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
现在,为了更容易使用 NumpySort
(隐藏我们作为输出返回的中间结果,并允许默认参数和关键字参数),我们创建一个新的函数来调用它。
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
这是一个健全性检查。
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function 指定自定义梯度规则¶
另一个常见的情况是使用 PyTorch 操作实现的 torch.autograd.Function
。PyTorch 能够自动计算 PyTorch 操作的梯度,但也许我们希望自定义梯度的计算方式。我们可能想要自定义反向传播的原因包括:
提高数值稳定性
改变反向传播的性能特征
改变边缘情况的处理方式(例如 NaN、无穷大)
修改梯度(例如梯度裁剪)
这是一个 torch.autograd.Function
的示例,用于函数 y = x ** 3
,其中我们改变了性能特征(一些通常在反向传播过程中发生的计算,即计算 dx,发生在正向传播过程中)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# In regular PyTorch, if we had just run y = x ** 3, then the backward
# pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
# that computation here in the forward pass instead.
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`.
result = grad_output * dx + grad_dx * 6 * x
return result
现在,为了更容易使用 NumpySort
(并隐藏我们作为输出返回的中间结果),我们创建一个新的函数来调用它。
def my_cube(x):
result, _ = MyCube.apply(x)
return result
这是一个计算二阶梯度的健全性检查。
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和注意事项¶
警告
请仔细阅读 torch.autograd.Function
与 torch.func 变换的这些限制。我们无法捕获许多这些情况并优雅地报错,因此它们会导致未定义的行为。
请不要在 torch.autograd.Function
的方法中捕获正在被变换、具有 requires_grad=True 或为双张量的张量。完全安全的做法是确保在 torch.autograd.Function
的任何方法内部使用的唯一张量必须直接作为输入传递(或通过 ctx 对象),而不是来自 torch.autograd.Function
的外部。
torch.autograd.Function
不处理 pytree 中的张量(可能包含也可能不包含张量的任意嵌套 Python 数据结构)。为了让这些张量被 autograd 跟踪,它们必须直接作为参数传递给 torch.autograd.Function
。这与 jax.{custom_vjp, custom_jvp} 形成对比,后者确实接受 pytree。
请仅使用 save_for_backward()
或 save_for_forward()
来保存张量。请不要将张量或张量集合直接分配到 ctx 对象上——这些张量不会被跟踪。
torch.vmap()
支持¶
要将 torch.autograd.Function
与 torch.vmap()
一起使用,你必须:
提供一个
vmap()
静态方法,告诉我们torch.autograd.Function
在torch.vmap()
下的行为。通过设置
generate_vmap_rule=True
让系统自动生成 vmap 规则。
自动生成 vmap 规则¶
如果你的 torch.autograd.Function
满足以下额外约束,那么我们就可以为它生成一个 vmap 规则。如果它不满足约束或你希望在 vmap 下有自定义行为,请手动定义一个 vmap 静态方法(见下一节)。
警告
我们不容易检查以下约束并优雅地报错。违反约束可能会导致未定义的行为。
torch.autograd.Function
的forward()
、backward()
(如果存在)和jvp()
(如果存在)静态方法必须可以通过torch.vmap()
进行变换。也就是说,它们必须仅由 PyTorch 操作组成(而不是例如 NumPy 或自定义 CUDA 内核)。
示例
class MyCube(torch.autograd.Function):
# Set generate_vmap_rule to True to ask PyTorch to automatically generate
# a vmap rule.
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义 vmap 静态方法¶
如果你的 torch.autograd.Function
调用另一个系统(如 NumPy、C++、CUDA、Triton),那么为了使其能够与 torch.vmap()
或使用它的变换一起工作,你需要手动定义一个 vmap()
静态方法。
根据你想要使用的变换和你的用例,你可能不需要为所有 torch.autograd.Function
添加 vmap()
静态方法。
例如,
torch.func.jacrev()
对反向传播执行vmap()
。因此,如果你只对使用torch.func.jacrev()
感兴趣,那么只有backward()
静态方法需要是可 vmap 的。
我们建议确保所有 torch.autograd.Function
都支持 torch.vmap()
,尤其是在编写第三方库并且希望你的 torch.autograd.Function
能够与所有 torch.func()
变换组合兼容的情况下。
从概念上讲,vmap 静态方法负责定义 forward()
在 torch.vmap()
下的行为方式。也就是说,它定义了如何转换 forward()
以使其能够处理具有额外维度(即被 vmap 遍历的维度)的输入。这类似于 torch.vmap()
在 PyTorch 操作中实现的方式:对于每个操作,我们定义一个 vmap 规则(有时也称为“批处理规则”)。
以下是定义 vmap()
静态方法的方法。
签名为
vmap(info, in_dims: Tuple[Optional[int]], *args)
,其中*args
与forward()
的参数相同。vmap 静态方法负责定义
forward()
在torch.vmap()
下的行为方式。也就是说,给定具有额外维度(由in_dims
指定)的输入,我们如何计算forward()
的批处理版本?对于
args
中的每个参数,in_dims
都有一个对应的Optional[int]
。如果参数不是张量或未被 vmap 遍历,则为None
;否则,它是一个整数,指定张量的哪个维度正在被 vmap 遍历。info
是一个包含其他元数据的集合,这些元数据可能会有所帮助:info.batch_size
指定被 vmap 遍历的维度的尺寸,而info.randomness
是传递给torch.vmap()
的randomness
选项。vmap 静态方法的返回值是一个包含
(output, out_dims)
的元组。与in_dims
类似,out_dims
应该与output
具有相同的结构,并为每个输出包含一个out_dim
,用于指定输出是否具有 vmap 维度以及它在其中的索引。
示例
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# The signature of the vmap staticmethod is:
# vmap(info, in_dims: Tuple[Optional[int]], *args)
# where *args is the same as the arguments to `forward`.
@staticmethod
def vmap(info, in_dims, x, dim):
# For every input (x and dim), in_dims stores an Optional[int]
# that is:
# - None if the input is not being vmapped over or if the input
# is not a Tensor
# - an integer if the input is being vmapped over that represents
# the index of the dimension being vmapped over.
x_bdim, _ = in_dims
# A "vmap rule" is the logic of how to perform the operation given
# inputs with one additional dimension. In NumpySort, x has an
# additional dimension (x_bdim). The vmap rule is simply
# to call NumpySort again but pass it a different `dim`.
x = x.movedim(x_bdim, 0)
# Handle negative dims correctly
dim = dim if dim >= 0 else dim + x.dim() - 1
result = NumpySort.apply(x, dim + 1)
# The vmap rule must return a tuple of two things
# 1. the output. Should be the same amount of things
# as returned by the forward().
# 2. one Optional[int] for each output specifying if each output
# is being vmapped over, and if so, the index of the
# dimension being vmapped over.
#
# NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
# dimension being vmapped over to the front of `x`, that appears at
# dimension 0 of all outputs.
# The return is (output, out_dims) -- output is a tuple of 3 Tensors
# and out_dims is a Tuple of 3 Optional[int]
return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# The strategy is: expand {x, ind, ind_inv} to all have the dimension
# being vmapped over.
# Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
# Handle negative dims by wrapping them to be positive
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def maybe_expand_bdim_at_front(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
# If the Tensor doesn't have the dimension being vmapped over,
# expand it out. Otherwise, move it to the front of the Tensor
x = maybe_expand_bdim_at_front(x, x_bdim)
ind = maybe_expand_bdim_at_front(ind, ind_bdim)
ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
# The return is a tuple (output, out_dims). Since output is a Tensor,
# then out_dims is an Optional[int] (instead of being a Tuple).
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap 静态方法应该旨在保持整个 Function
的语义。也就是说,(伪代码)grad(vmap(MyFunc))
应该可以用 grad(map(MyFunc))
替换。
如果你的 autograd.Function 在反向传播过程中有任何自定义行为,请记住这一点。
注意
为 PyTorch 能够通过 generate_vmap_rule=True
生成 vmap 规则的 Function
编写自定义 vmap 静态方法是一个合理的用例。如果你希望生成的 vmap 规则不具有你想要的语义,则可能需要这样做。
torch.func.jvp()
支持¶
为了支持前向模式自动微分,torch.autograd.Function
必须具有一个 jvp()
静态方法。有关详细信息,请参阅 前向模式自动微分。