使用 autograd.Function 扩展 torch.func¶
您可能希望将 torch.autograd.Function
与 torch.func
变换(如 torch.vmap()
、torch.func.grad()
等)一起使用。
主要有两个用例
您希望调用不包含 PyTorch 算子的代码并使其与函数变换一起工作。也就是说,
torch.autograd.Function
的 forward/backward/等方法调用其他系统(如 C++、CUDA、numpy)中的函数。您希望指定自定义梯度规则,例如 JAX 的 custom_vjp/custom_jvp。
PyTorch 将这两个概念结合到 torch.autograd.Function
中。
基本用法¶
本指南假定您熟悉 扩展 torch.autograd,其中解释了如何使用 torch.autograd.Function
。
torch.autograd.Function
可以有一个接受 ctx 对象的 forward()
方法,或者有一个不接受 ctx
的独立的 forward()
方法和一个修改 ctx
对象的 setup_context()
静态方法。
函数变换只支持后者
forward()
是执行操作的代码,它不应该接受ctx
对象。setup_context(ctx, inputs, output)
是您可以在ctx
对象上调用方法的地方。您应该在这里保存用于反向传播的张量(通过调用ctx.save_for_backward(*tensors)
),或者保存非张量对象(通过将它们赋值给ctx
对象)。
因为 setup_context()
只接受 inputs
和 output
,所以可以保存的数量只能是 inputs 或 outputs 中的对象(如张量),或是从它们派生的数量(如 Tensor.shape
)。如果您希望保存 Function.forward()
中非输入的中间激活用于反向传播,则需要将其作为 forward()
的输出返回,以便传递给 setup_context()
。
根据不同的变换,
为了支持反向模式 AD(
torch.func.grad()
、torch.func.vjp()
),torch.autograd.Function
需要一个backward()
静态方法。为了支持
torch.vmap()
,torch.autograd.Function
需要一个vmap()
静态方法。为了支持
torch.func.jvp()
,torch.autograd.Function
需要一个jvp()
静态方法。为了支持变换的组合(如
torch.func.jacrev()
、torch.func.jacfwd()
、torch.func.hessian()
),您可能需要上述方法中的多个。
为了使 torch.autograd.Function
能够与函数变换任意组合,我们建议除了 forward()
和 setup_context()
之外的所有其他静态方法都必须是可变换的:也就是说,它们必须仅由 PyTorch 算子组成或调用其他 torch.autograd.Function
(这些 Function 可能调用 C++/CUDA/等)。
让我们来看一些常见用例的示例。
示例 1:autograd.Function 调用其他系统¶
一个常见情况是,torch.autograd.Function
的 forward()
和 backward()
都调用其他系统(如 C++、CUDA、numpy、triton)中的函数。
import torch
import numpy as np
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
# Note that forward does not take ctx
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
# Any intermediates to be saved in backward must be returned as
# outputs.
return (
# The desired output
torch.tensor(result, device=device),
# intermediate to save for backward
torch.tensor(ind, device=device),
# intermediate to save for backward
torch.tensor(ind_inv, device=device),
)
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
# Note that output is whatever you returned from forward.
# If you returned multiple values, then output is a Tuple of multiple values.
# If you returned a single Tensor, then output is a Tensor.
# If you returned a Tuple with a single Tensor, then output is a
# Tuple with a single Tensor.
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
# Tensors must be saved via ctx.save_for_backward. Please do not
# assign them directly onto the ctx object.
ctx.save_for_backward(ind, ind_inv)
# Non-tensors may be saved by assigning them as attributes on the ctx object.
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
# For the autograd.Function to be arbitrarily composable with function
# transforms, all staticmethod other than forward and setup_context
# must be implemented in a "transformable" way; that is, they must
# only consist of PyTorch operations or autograd.Function.
#
# For example, this allows us to do double backwards and/or compute
# second order gradients.
#
# We've written the backward pass of NumpySort in terms of another
# autograd.Function, NumpyTake.
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
现在,为了更方便地使用 NumpySort
(以隐藏我们作为输出返回的中间结果,并允许使用默认参数和关键字参数),我们创建一个新函数来调用它。
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
以下是一个健全性检查
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function 指定自定义梯度规则¶
另一种常见情况是,torch.autograd.Function
是用 PyTorch 算子实现的。PyTorch 能够自动计算 PyTorch 算子的梯度,但也许我们希望自定义梯度计算方式。我们可能希望自定义与 PyTorch 提供的不同反向传播的一些原因包括:
提高数值稳定性
改变反向传播的性能特性
改变边缘情况的处理方式(例如 NaN、Inf)
修改梯度(例如梯度裁剪)
以下是一个函数 y = x ** 3
的 torch.autograd.Function
示例,其中我们改变了性能特性(一些通常发生在反向传播中计算 dx 的计算,被放在了前向传播中)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# In regular PyTorch, if we had just run y = x ** 3, then the backward
# pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
# that computation here in the forward pass instead.
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`.
result = grad_output * dx + grad_dx * 6 * x
return result
现在,为了更方便地使用 NumpySort
(并隐藏我们作为输出返回的中间结果),我们创建一个新函数来调用它。
def my_cube(x):
result, _ = MyCube.apply(x)
return result
以下是计算二阶梯度的健全性检查。
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制与陷阱¶
警告
请仔细阅读 torch.autograd.Function
与 torch.func 变换结合使用时的这些限制。我们无法捕获许多此类情况并优雅地报错,因此它们可能导致未定义行为。
请不要将正在进行变换、设置了 requires_grad=True
或属于 dual tensors 的张量捕获到 torch.autograd.Function
的方法中。完全安全的方法是确保在 torch.autograd.Function
的任何方法中使用的张量必须直接作为输入(或通过 ctx 对象)传递,而不是来自 torch.autograd.Function
外部。
torch.autograd.Function
不处理 pytrees(任意嵌套的 Python 数据结构,可能包含或不包含张量)中的张量。为了让这些张量被 autograd 跟踪,它们必须直接作为参数传递给 torch.autograd.Function
。这与 jax.{custom_vjp, custom_jvp} 不同,后者接受 pytrees。
请仅使用 save_for_backward()
或 save_for_forward()
来保存张量。请不要直接将张量或张量集合赋值给 ctx 对象 - 这些张量不会被跟踪。
torch.vmap()
支持¶
要将 torch.autograd.Function
与 torch.vmap()
一起使用,您必须选择以下方法之一:
提供一个
vmap()
静态方法,用于说明torch.autograd.Function
在torch.vmap()
下的行为。通过设置
generate_vmap_rule=True
让我们自动生成它。
自动生成 vmap 规则¶
如果您的 torch.autograd.Function
满足以下附加约束,则我们可以为其生成 vmap 规则。如果它不满足约束或者您希望在 vmap 下有自定义行为,请手动定义一个 vmap 静态方法(参见下一节)。
警告
我们无法轻松检查以下约束并优雅地报错。违反约束可能导致未定义行为。
torch.autograd.Function
的forward()
方法、backward()
方法(如果存在)和jvp()
方法(如果存在)必须可以通过torch.vmap()
进行变换。也就是说,它们必须仅由 PyTorch 算子组成(而不是 NumPy 或自定义 CUDA 内核等)。
示例
class MyCube(torch.autograd.Function):
# Set generate_vmap_rule to True to ask PyTorch to automatically generate
# a vmap rule.
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义 vmap 静态方法¶
如果您的 torch.autograd.Function
调用其他系统(如 NumPy、C++、CUDA、triton)中的函数,那么为了使其与 torch.vmap()
或使用它的变换一起工作,您需要手动定义一个 vmap()
静态方法。
根据您想要使用的变换和您的用例,您可能不需要为所有的 torch.autograd.Function
添加 vmap()
静态方法。
例如,
torch.func.jacrev()
在反向传播过程中执行vmap()
。因此,如果您只对使用torch.func.jacrev()
感兴趣,则只需要使backward()
静态方法支持 vmap 即可。
不过,我们建议确保您所有的 torch.autograd.Function
都支持 torch.vmap()
,特别是如果您正在编写第三方库,并且希望您的 torch.autograd.Function
能与 torch.func()
的所有变换组合一起工作。
从概念上讲,vmap 静态方法负责定义 forward()
在 torch.vmap()
下的行为方式。也就是说,它定义了如何变换 forward()
,使其在具有额外维度(即进行 vmap 的维度)的输入上运行。这类似于 torch.vmap()
在 PyTorch 算子上的实现方式:对于每个算子,我们定义一个 vmap 规则(有时也称为“批处理规则”)。
以下是如何定义 vmap()
静态方法:
签名为
vmap(info, in_dims: Tuple[Optional[int]], *args)
,其中*args
与forward()
的参数相同。vmap 静态方法负责定义
forward()
在torch.vmap()
下的行为方式。也就是说,给定具有额外维度(由in_dims
指定)的输入,我们如何计算forward()
的批处理版本?对于
args
中的每个参数,in_dims
都有一个对应的Optional[int]
。如果参数不是张量或未进行 vmap 变换,则该值为None
;否则,它是一个整数,指定了张量正在进行 vmap 变换的维度。info
是一个包含可能有用附加元数据的集合:info.batch_size
指定进行 vmap 变换的维度大小,而info.randomness
是传递给torch.vmap()
的randomness
选项。vmap 静态方法的返回是一个包含
(output, out_dims)
的元组。与in_dims
类似,out_dims
应与output
具有相同的结构,并且每个输出包含一个out_dim
,用于指定输出是否包含 vmap 维度以及该维度的索引位置。
示例
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# The signature of the vmap staticmethod is:
# vmap(info, in_dims: Tuple[Optional[int]], *args)
# where *args is the same as the arguments to `forward`.
@staticmethod
def vmap(info, in_dims, x, dim):
# For every input (x and dim), in_dims stores an Optional[int]
# that is:
# - None if the input is not being vmapped over or if the input
# is not a Tensor
# - an integer if the input is being vmapped over that represents
# the index of the dimension being vmapped over.
x_bdim, _ = in_dims
# A "vmap rule" is the logic of how to perform the operation given
# inputs with one additional dimension. In NumpySort, x has an
# additional dimension (x_bdim). The vmap rule is simply
# to call NumpySort again but pass it a different `dim`.
x = x.movedim(x_bdim, 0)
# Handle negative dims correctly
dim = dim if dim >= 0 else dim + x.dim() - 1
result = NumpySort.apply(x, dim + 1)
# The vmap rule must return a tuple of two things
# 1. the output. Should be the same amount of things
# as returned by the forward().
# 2. one Optional[int] for each output specifying if each output
# is being vmapped over, and if so, the index of the
# dimension being vmapped over.
#
# NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
# dimension being vmapped over to the front of `x`, that appears at
# dimension 0 of all outputs.
# The return is (output, out_dims) -- output is a tuple of 3 Tensors
# and out_dims is a Tuple of 3 Optional[int]
return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# The strategy is: expand {x, ind, ind_inv} to all have the dimension
# being vmapped over.
# Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
# Handle negative dims by wrapping them to be positive
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def maybe_expand_bdim_at_front(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
# If the Tensor doesn't have the dimension being vmapped over,
# expand it out. Otherwise, move it to the front of the Tensor
x = maybe_expand_bdim_at_front(x, x_bdim)
ind = maybe_expand_bdim_at_front(ind, ind_bdim)
ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
# The return is a tuple (output, out_dims). Since output is a Tensor,
# then out_dims is an Optional[int] (instead of being a Tuple).
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap 静态方法应旨在保留整个 Function
的语义。也就是说,(伪代码) grad(vmap(MyFunc))
应该可以替换为 grad(map(MyFunc))
。
如果您的 autograd.Function 在反向传播中包含任何自定义行为,请记住这一点。
注意
为一个 Function
编写自定义 vmap 静态方法是合法的用例,即使 PyTorch 可以通过 generate_vmap_rule=True
为其生成 vmap 规则。如果生成的 vmap 规则不具备您所需的语义,您可能希望这样做。
torch.func.jvp()
支持¶
为了支持前向模式自动微分 (AD),torch.autograd.Function
必须包含一个 jvp()
静态方法。详细信息请参阅前向模式 AD。