使用 autograd.Function 扩展 torch.func¶
因此,您想将 torch.autograd.Function
与 torch.func
变换(如 torch.vmap()
、torch.func.grad()
等)一起使用。
主要有两个用例
您希望调用不包含 PyTorch 运算的代码,并使其与函数变换一起使用。也就是说,
torch.autograd.Function
的 forward/backward/etc 调用来自其他系统(如 C++、CUDA、numpy)的函数。您希望指定自定义梯度规则,如 JAX 的 custom_vjp/custom_jvp
PyTorch 将这两个概念结合到 torch.autograd.Function
中。
基本用法¶
本指南假设您熟悉 扩展 torch.autograd,其中解释了如何使用 torch.autograd.Function
。
torch.autograd.Function
可以具有接受 ctx 对象的 forward()
,也可以具有单独的 forward()
(不接受 ctx
)和修改 ctx
对象的 setup_context()
静态方法。
函数变换仅支持后者
forward()
是执行运算的代码,它不应该接受ctx
对象。setup_context(ctx, inputs, output)
是可以在其中调用ctx
方法的代码。在这里,您应该保存用于反向传播的张量(通过调用ctx.save_for_backward(*tensors)
),或保存非张量(通过将它们赋值给ctx
对象)。
因为 setup_context()
仅接受 inputs
和 output
,所以唯一可以保存的数量是输入或输出中的对象(如张量)或从它们派生的数量(如 Tensor.shape
)。如果您希望从 Function.forward()
中保存非输入中间激活以用于反向传播,则需要将其作为 forward()
的输出返回,以便将其传递给 setup_context()
。
根据变换的不同,
为了支持反向模式自动微分(
torch.func.grad()
、torch.func.vjp()
),torch.autograd.Function
需要一个backward()
静态方法。为了支持
torch.vmap()
,torch.autograd.Function
需要一个vmap()
静态方法。为了支持
torch.func.jvp()
,torch.autograd.Function
需要一个jvp()
静态方法。为了支持变换的组合(例如
torch.func.jacrev()
、torch.func.jacfwd()
、torch.func.hessian()
),您可能需要以上多个方法。
为了使 torch.autograd.Function
能够与函数变换任意组合,我们建议除了 forward()
和 setup_context()
之外的所有其他静态方法都必须是可变换的:也就是说,它们必须只包含 PyTorch 运算符或调用其他 torch.autograd.Function
(可以调用 C++/CUDA/etc)。
让我们来看一些常见用例的例子。
示例 1:autograd.Function 调用另一个系统¶
一个常见的情况是 torch.autograd.Function
的 forward() 和 backward() 都调用另一个系统(例如 C++、CUDA、numpy、triton)。
import torch
import numpy as np
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
# Note that forward does not take ctx
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
# Any intermediates to be saved in backward must be returned as
# outputs.
return (
# The desired output
torch.tensor(result, device=device),
# intermediate to save for backward
torch.tensor(ind, device=device),
# intermediate to save for backward
torch.tensor(ind_inv, device=device),
)
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
# Note that output is whatever you returned from forward.
# If you returned multiple values, then output is a Tuple of multiple values.
# If you returned a single Tensor, then output is a Tensor.
# If you returned a Tuple with a single Tensor, then output is a
# Tuple with a single Tensor.
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
# Tensors must be saved via ctx.save_for_backward. Please do not
# assign them directly onto the ctx object.
ctx.save_for_backward(ind, ind_inv)
# Non-tensors may be saved by assigning them as attributes on the ctx object.
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
# For the autograd.Function to be arbitrarily composable with function
# transforms, all staticmethod other than forward and setup_context
# must be implemented in a "transformable" way; that is, they must
# only consist of PyTorch operations or autograd.Function.
#
# For example, this allows us to do double backwards and/or compute
# second order gradients.
#
# We've written the backward pass of NumpySort in terms of another
# autograd.Function, NumpyTake.
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
现在,为了更容易使用 NumpySort
(隐藏我们作为输出返回的中间结果,并允许使用默认的 args 和 kwargs),我们创建一个新的函数来调用它
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
这是一个完整性检查
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function 指定自定义梯度规则¶
另一种常见的情况是使用 PyTorch 运算实现的 torch.autograd.Function
。PyTorch 能够自动计算 PyTorch 运算的梯度,但我们可能希望自定义梯度的计算方式。我们可能需要一个不同于 PyTorch 提供的自定义 backward 的一些原因是
提高数值稳定性
改变 backward 的性能特征
改变边缘情况的处理方式(例如 nan、inf)
修改梯度(例如梯度裁剪)
下面是一个函数 y = x ** 3
的 torch.autograd.Function
示例,我们更改了性能特征(一些通常在 backward 传递期间发生的计算,计算 dx,发生在 forward 传递中)。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# In regular PyTorch, if we had just run y = x ** 3, then the backward
# pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
# that computation here in the forward pass instead.
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`.
result = grad_output * dx + grad_dx * 6 * x
return result
现在,为了更容易使用 NumpySort
(并隐藏我们作为输出返回的中间结果),我们创建一个新的函数来调用它
def my_cube(x):
result, _ = MyCube.apply(x)
return result
这是一个计算二阶梯度的完整性检查
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和注意事项¶
警告
请仔细阅读 torch.autograd.Function
与 torch.func 变换的这些限制。我们无法捕获许多此类情况并优雅地报错,因此它们会导致未定义的行为。
请不要将正在转换、具有 requires_grad=True 或为对偶张量的张量捕获到 torch.autograd.Function
的方法中。完全安全的做法是确保在 torch.autograd.Function
的任何方法中使用的唯一张量必须是作为输入直接传递的(或通过 ctx 对象),而不是来自 torch.autograd.Function
之外。
torch.autograd.Function
不处理 pytrees(可能包含也可能不包含张量的任意嵌套 Python 数据结构)中的张量。为了让 autograd 跟踪这些张量,必须将它们作为参数直接传递给 torch.autograd.Function
。这与 jax.{custom_vjp、custom_jvp} 相反,后者接受 pytrees。
请仅使用 save_for_backward()
或 save_for_forward()
来保存张量。请不要将张量或张量集合直接分配给 ctx 对象——这些张量将不会被跟踪
torch.vmap()
支持¶
要将 torch.autograd.Function
与 torch.vmap()
一起使用,您必须
提供一个
vmap()
静态方法,该方法告诉我们torch.autograd.Function
在torch.vmap()
下的行为通过设置
generate_vmap_rule=True
要求我们自动生成它。
自动生成 vmap 规则¶
如果您的 torch.autograd.Function
满足以下附加约束,那么我们就能够为其生成 vmap 规则。如果它不满足约束,或者您希望在 vmap 下有自定义行为,请手动定义一个 vmap 静态方法(请参阅下一节)。
警告
我们无法轻松检查以下约束并优雅地报错。违反约束可能会导致未定义的行为。
torch.autograd.Function
的forward()
、backward()
(如果存在)和jvp()
(如果存在)静态方法必须可以通过torch.vmap()
进行变换。也就是说,它们必须只包含 PyTorch 运算(而不是例如 NumPy 或自定义 CUDA 内核)。
示例
class MyCube(torch.autograd.Function):
# Set generate_vmap_rule to True to ask PyTorch to automatically generate
# a vmap rule.
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义 vmap 静态方法¶
如果您的 torch.autograd.Function
调用另一个系统(如 NumPy、C++、CUDA、triton),那么要使其与 torch.vmap()
或使用它的变换一起使用,您需要手动定义一个 vmap()
静态方法。
根据您要使用的变换和您的用例,您可能不需要向所有 torch.autograd.Function
添加 vmap()
静态方法
例如,
torch.func.jacrev()
在反向传递中执行vmap()
。因此,如果您只对使用torch.func.jacrev()
感兴趣,则只需要backward()
静态方法是可 vmap 的。
不过,我们确实建议确保所有 torch.autograd.Function
都支持 torch.vmap()
,尤其是在您编写第三方库并且希望您的 torch.autograd.Function
可以与 torch.func()
转换的所有组合一起使用时。
从概念上讲,vmap 静态方法负责定义 forward()
在 torch.vmap()
下的行为方式。也就是说,它定义了如何转换 forward()
以运行具有额外维度(被 vmap 映射的维度)的输入。这类似于 torch.vmap()
在 PyTorch 操作上的实现方式:对于每个操作,我们定义一个 vmap 规则(有时也称为“批处理规则”)。
以下是定义 vmap()
静态方法的方法
签名是
vmap(info, in_dims: Tuple[Optional[int]], *args)
,其中*args
与forward()
的参数相同。vmap 静态方法负责定义
forward()
在torch.vmap()
下的行为方式。也就是说,给定具有额外维度(由in_dims
指定)的输入,我们如何计算forward()
的批处理版本?对于
args
中的每个参数,in_dims
都有一个对应的Optional[int]
。如果参数不是张量或参数没有被 vmap 映射,则为None
,否则,它是一个整数,指定正在被 vmap 映射的张量的维度。info
是可能会有所帮助的附加元数据集合:info.batch_size
指定被 vmap 映射的维度的长度,而info.randomness
是传递给torch.vmap()
的randomness
选项。vmap 静态方法的返回值是一个
(output, out_dims)
元组。与in_dims
类似,out_dims
的结构应与output
相同,并且每个输出包含一个out_dim
,用于指定输出是否具有 vmap 映射的维度及其索引。
示例
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# The signature of the vmap staticmethod is:
# vmap(info, in_dims: Tuple[Optional[int]], *args)
# where *args is the same as the arguments to `forward`.
@staticmethod
def vmap(info, in_dims, x, dim):
# For every input (x and dim), in_dims stores an Optional[int]
# that is:
# - None if the input is not being vmapped over or if the input
# is not a Tensor
# - an integer if the input is being vmapped over that represents
# the index of the dimension being vmapped over.
x_bdim, _ = in_dims
# A "vmap rule" is the logic of how to perform the operation given
# inputs with one additional dimension. In NumpySort, x has an
# additional dimension (x_bdim). The vmap rule is simply
# to call NumpySort again but pass it a different `dim`.
x = x.movedim(x_bdim, 0)
# Handle negative dims correctly
dim = dim if dim >= 0 else dim + x.dim() - 1
result = NumpySort.apply(x, dim + 1)
# The vmap rule must return a tuple of two things
# 1. the output. Should be the same amount of things
# as returned by the forward().
# 2. one Optional[int] for each output specifying if each output
# is being vmapped over, and if so, the index of the
# dimension being vmapped over.
#
# NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
# dimension being vmapped over to the front of `x`, that appears at
# dimension 0 of all outputs.
# The return is (output, out_dims) -- output is a tuple of 3 Tensors
# and out_dims is a Tuple of 3 Optional[int]
return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# The strategy is: expand {x, ind, ind_inv} to all have the dimension
# being vmapped over.
# Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
# Handle negative dims by wrapping them to be positive
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def maybe_expand_bdim_at_front(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
# If the Tensor doesn't have the dimension being vmapped over,
# expand it out. Otherwise, move it to the front of the Tensor
x = maybe_expand_bdim_at_front(x, x_bdim)
ind = maybe_expand_bdim_at_front(ind, ind_bdim)
ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
# The return is a tuple (output, out_dims). Since output is a Tensor,
# then out_dims is an Optional[int] (instead of being a Tuple).
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap 静态方法应旨在保留整个 Function
的语义。也就是说,(伪代码)grad(vmap(MyFunc))
应该可以用 grad(map(MyFunc))
替换。
如果您的 autograd.Function 在反向传递中具有任何自定义行为,请记住这一点。
注意
为 PyTorch 能够通过 generate_vmap_rule=True
为其生成 vmap 规则的 Function
编写自定义 vmap 静态方法是一种合法的用例。如果您希望生成的 vmap 规则不具备您想要的语义,则可以这样做。
torch.func.jvp()
支持¶
为了支持前向模式 AD,torch.autograd.Function
必须有一个 jvp()
静态方法。有关详细信息,请参阅 前向模式 AD。