使用 autograd.Function 扩展 torch.func¶
您可能希望将 torch.autograd.Function
与 torch.func
变换(如 torch.vmap()
、torch.func.grad()
等)一起使用。
主要有两种用例
您希望调用不包含 PyTorch 操作的代码,并使其与函数变换一起使用。也就是说,
torch.autograd.Function
的 forward/backward/etc 调用来自其他系统(如 C++、CUDA、numpy)的函数。您希望指定自定义梯度规则,例如 JAX 的 custom_vjp/custom_jvp
PyTorch 将这两个概念结合到 torch.autograd.Function
中。
基本用法¶
本指南假设您熟悉 扩展 torch.autograd,其中解释了如何使用 torch.autograd.Function
。
torch.autograd.Function
可以拥有一个接受 ctx 对象的 forward()
方法,或者它可以拥有单独的 forward()
方法(不接受 ctx
)和一个修改 ctx
对象的 setup_context()
静态方法。
只有后者在函数转换中受支持。
forward()
是执行操作的代码,它不应该接受ctx
对象。setup_context(ctx, inputs, output)
是你可以调用ctx
上的方法的代码。在这里,你应该保存用于反向传播的张量(通过调用ctx.save_for_backward(*tensors)
),或者保存非张量(通过将它们分配给ctx
对象)。
因为 setup_context()
只接受 inputs
和 output
,所以唯一可以保存的数量是输入或输出中的对象(例如张量)或从它们派生的数量(例如 Tensor.shape
)。如果你希望保存来自 Function.forward()
的非输入中间激活用于反向传播,那么你需要将其作为 forward()
的输出返回,以便它被传递给 setup_context()
。
根据转换的不同,
为了支持反向模式 AD(
torch.func.grad()
,torch.func.vjp()
),torch.autograd.Function
需要一个backward()
静态方法。为了支持
torch.vmap()
,torch.autograd.Function
需要一个vmap()
静态方法。为了支持
torch.func.jvp()
,torch.autograd.Function
需要一个jvp()
静态方法。为了支持转换的组合(例如
torch.func.jacrev()
,torch.func.jacfwd()
,torch.func.hessian()
) - 你可能需要多个上述方法。
为了使 torch.autograd.Function
可以与函数转换任意组合,我们建议除了 forward()
和 setup_context()
之外的所有其他静态方法都必须是可转换的:也就是说,它们必须仅包含 PyTorch 运算符或调用其他 torch.autograd.Function
(可能调用 C++/CUDA/等)。
让我们来看一些常见用例的例子。
示例 1:autograd.Function 调用另一个系统¶
一个常见情况是 torch.autograd.Function
的 forward() 和 backward() 都调用另一个系统(如 C++、CUDA、numpy、triton)。
import torch
import numpy as np
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
# Note that forward does not take ctx
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
# Any intermediates to be saved in backward must be returned as
# outputs.
return (
# The desired output
torch.tensor(result, device=device),
# intermediate to save for backward
torch.tensor(ind, device=device),
# intermediate to save for backward
torch.tensor(ind_inv, device=device),
)
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
# Note that output is whatever you returned from forward.
# If you returned multiple values, then output is a Tuple of multiple values.
# If you returned a single Tensor, then output is a Tensor.
# If you returned a Tuple with a single Tensor, then output is a
# Tuple with a single Tensor.
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
# Tensors must be saved via ctx.save_for_backward. Please do not
# assign them directly onto the ctx object.
ctx.save_for_backward(ind, ind_inv)
# Non-tensors may be saved by assigning them as attributes on the ctx object.
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
# For the autograd.Function to be arbitrarily composable with function
# transforms, all staticmethod other than forward and setup_context
# must be implemented in a "transformable" way; that is, they must
# only consist of PyTorch operations or autograd.Function.
#
# For example, this allows us to do double backwards and/or compute
# second order gradients.
#
# We've written the backward pass of NumpySort in terms of another
# autograd.Function, NumpyTake.
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
现在,为了更方便地使用 NumpySort
(隐藏我们作为输出返回的中间值,以及允许默认参数和关键字参数),我们创建一个新的函数来调用它
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
这是一个健全性检查
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function 指定自定义梯度规则¶
另一个常见情况是 torch.autograd.Function
使用 PyTorch 运算符实现。PyTorch 能够自动计算 PyTorch 运算符的梯度,但也许我们希望自定义梯度的计算方式。我们可能想要一个与 PyTorch 给出的不同的自定义 backward 的一些原因是
提高数值稳定性
改变 backward 的性能特征
更改边缘情况的处理方式(例如 NaN、inf)
修改梯度(例如梯度裁剪)
以下是一个 torch.autograd.Function
的示例,用于函数 y = x ** 3
,其中我们更改了性能特征(一些通常在反向传播期间发生的计算,计算 dx,发生在正向传播期间)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# In regular PyTorch, if we had just run y = x ** 3, then the backward
# pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
# that computation here in the forward pass instead.
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`.
result = grad_output * dx + grad_dx * 6 * x
return result
现在,为了更方便地使用 NumpySort
(并隐藏我们作为输出返回的中间值),我们创建了一个调用它的新函数
def my_cube(x):
result, _ = MyCube.apply(x)
return result
这是一个计算二阶梯度的健全性检查
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和注意事项¶
警告
请仔细阅读 torch.autograd.Function
与 torch.func 变换的这些限制。我们无法捕获许多这些情况并优雅地出错,因此会导致未定义的行为。
请不要将正在转换的张量、具有 requires_grad=True 的张量或双张量捕获到 torch.autograd.Function
的方法中。完全安全的做法是确保在 torch.autograd.Function
的任何方法内部使用的唯一张量必须直接作为输入传递(或通过 ctx 对象),而不是来自 torch.autograd.Function
的外部。
torch.autograd.Function
不处理 pytrees 中的张量(可能包含或可能不包含张量的任意嵌套 Python 数据结构)。为了让这些张量被 autograd 跟踪,它们必须直接作为参数传递给 torch.autograd.Function
。这与 jax.{custom_vjp, custom_jvp} 形成对比,后者确实接受 pytrees。
请仅使用 save_for_backward()
或 save_for_forward()
来保存张量。请不要将张量或张量集合直接分配到 ctx 对象上 - 这些张量将不会被跟踪
torch.vmap()
支持¶
要将 torch.autograd.Function
与 torch.vmap()
一起使用,您必须:
提供一个
vmap()
静态方法,告诉我们torch.autograd.Function
在torch.vmap()
下的行为通过设置
generate_vmap_rule=True
来让我们自动生成它。
自动生成 vmap 规则¶
如果您的 torch.autograd.Function
满足以下附加约束,那么我们就可以为它生成一个 vmap 规则。如果它不满足约束,或者您希望在 vmap 下有自定义行为,请手动定义一个 vmap 静态方法(参见下一节)。
警告
我们不容易检查以下约束并优雅地出错。违反约束可能会导致未定义的行为。
torch.autograd.Function
的forward()
、backward()
(如果存在)和jvp()
(如果存在)静态方法必须可以通过torch.vmap()
进行转换。也就是说,它们必须只包含 PyTorch 操作(而不是例如 NumPy 或自定义 CUDA 内核)。
示例
class MyCube(torch.autograd.Function):
# Set generate_vmap_rule to True to ask PyTorch to automatically generate
# a vmap rule.
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义 vmap 静态方法¶
如果您的 torch.autograd.Function
调用另一个系统(如 NumPy、C++、CUDA、triton),那么为了使其与 torch.vmap()
或使用它的转换一起工作,您需要手动定义一个 vmap()
静态方法。
根据您要使用的转换和您的用例,您可能不需要为所有 torch.autograd.Function
添加 vmap()
静态方法。
例如,
torch.func.jacrev()
在反向传播过程中执行vmap()
。因此,如果您只对使用torch.func.jacrev()
感兴趣,则只需要使backward()
静态方法可映射。
我们建议确保所有 torch.autograd.Function
都支持 torch.vmap()
,特别是在您编写第三方库并且希望您的 torch.autograd.Function
与所有 torch.func()
转换组合一起工作时。
从概念上讲,vmap 静态方法负责定义 forward()
在 torch.vmap()
下的行为方式。也就是说,它定义了如何将 forward()
转换为在具有额外维度(被 vmap 映射的维度)的输入上运行。这类似于 torch.vmap()
在 PyTorch 操作上的实现方式:对于每个操作,我们定义一个 vmap 规则(有时也称为“批处理规则”)。
以下是定义 vmap()
静态方法的方法
签名为
vmap(info, in_dims: Tuple[Optional[int]], *args)
,其中*args
与forward()
的参数相同。vmap 静态方法负责定义
forward()
在torch.vmap()
下的行为方式。也就是说,给定具有额外维度(由in_dims
指定)的输入,我们如何计算forward()
的批处理版本?对于
args
中的每个参数,in_dims
都有一个对应的Optional[int]
。如果参数不是张量或没有被 vmap 映射,则为None
,否则为一个整数,指定张量的哪个维度被 vmap 映射。info
是一个包含可能有用附加元数据的集合:info.batch_size
指定被 vmap 映射的维度的尺寸,而info.randomness
是传递给torch.vmap()
的randomness
选项。vmap 静态方法的返回值是一个包含
(output, out_dims)
的元组。类似于in_dims
,out_dims
应该与output
具有相同的结构,并包含每个输出的out_dim
,用于指定输出是否具有 vmap 映射的维度以及该维度的索引。
示例
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# The signature of the vmap staticmethod is:
# vmap(info, in_dims: Tuple[Optional[int]], *args)
# where *args is the same as the arguments to `forward`.
@staticmethod
def vmap(info, in_dims, x, dim):
# For every input (x and dim), in_dims stores an Optional[int]
# that is:
# - None if the input is not being vmapped over or if the input
# is not a Tensor
# - an integer if the input is being vmapped over that represents
# the index of the dimension being vmapped over.
x_bdim, _ = in_dims
# A "vmap rule" is the logic of how to perform the operation given
# inputs with one additional dimension. In NumpySort, x has an
# additional dimension (x_bdim). The vmap rule is simply
# to call NumpySort again but pass it a different `dim`.
x = x.movedim(x_bdim, 0)
# Handle negative dims correctly
dim = dim if dim >= 0 else dim + x.dim() - 1
result = NumpySort.apply(x, dim + 1)
# The vmap rule must return a tuple of two things
# 1. the output. Should be the same amount of things
# as returned by the forward().
# 2. one Optional[int] for each output specifying if each output
# is being vmapped over, and if so, the index of the
# dimension being vmapped over.
#
# NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
# dimension being vmapped over to the front of `x`, that appears at
# dimension 0 of all outputs.
# The return is (output, out_dims) -- output is a tuple of 3 Tensors
# and out_dims is a Tuple of 3 Optional[int]
return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# The strategy is: expand {x, ind, ind_inv} to all have the dimension
# being vmapped over.
# Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
# Handle negative dims by wrapping them to be positive
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def maybe_expand_bdim_at_front(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
# If the Tensor doesn't have the dimension being vmapped over,
# expand it out. Otherwise, move it to the front of the Tensor
x = maybe_expand_bdim_at_front(x, x_bdim)
ind = maybe_expand_bdim_at_front(ind, ind_bdim)
ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
# The return is a tuple (output, out_dims). Since output is a Tensor,
# then out_dims is an Optional[int] (instead of being a Tuple).
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap 静态方法应该旨在保留整个 Function
的语义。也就是说,(伪代码)grad(vmap(MyFunc))
应该可以替换为 grad(map(MyFunc))
。
如果你的 autograd.Function 在反向传播中具有任何自定义行为,请牢记这一点。
注意
为 Function
编写自定义 vmap 静态方法是一个合法的用例,PyTorch 可以通过 generate_vmap_rule=True
为其生成 vmap 规则。如果你生成的 vmap 规则没有你想要的语义,你可能希望这样做。
torch.func.jvp()
支持¶
为了支持前向模式 AD,torch.autograd.Function
必须具有 jvp()
静态方法。请参阅 前向模式 AD 以了解更多详细信息。