• 文档 >
  • 使用 autograd.Function 扩展 torch.func
快捷方式

使用 autograd.Function 扩展 torch.func

因此,您希望将 torch.autograd.Functiontorch.func 变换(如 torch.vmap()torch.func.grad() 等)一起使用。

主要有两种使用场景

  • 您希望调用不包含 PyTorch 操作的代码,并使其与函数变换一起工作。也就是说,torch.autograd.Function 的 forward/backward/etc 调用来自其他系统(如 C++、CUDA、numpy)的函数。

  • 您希望指定自定义梯度规则,如 JAX 的 custom_vjp/custom_jvp

PyTorch 将这两个概念结合到 torch.autograd.Function 中。

基本用法

本指南假定您熟悉 扩展 torch.autograd,其中介绍了如何使用 torch.autograd.Function

torch.autograd.Function 可以有一个接受 ctx 对象的 forward(),也可以有单独的 forward()(不接受 ctx)和一个修改 ctx 对象的 setup_context() 静态方法。

函数变换仅支持后者

  • forward() 是执行操作的代码,它不应接受 ctx 对象。

  • setup_context(ctx, inputs, output) 是您可以在其中调用 ctx 方法的代码。您应该在此处保存用于反向传播的张量(通过调用 ctx.save_for_backward(*tensors)),或保存非张量(通过将它们分配给 ctx 对象)。

由于 setup_context() 仅接受 inputsoutput,因此可以保存的唯一量是输入或输出中的对象(例如张量)或从中导出的量(例如 Tensor.shape)。如果您希望从 Function.forward() 保存非输入中间激活值以进行反向传播,则需要将其作为 forward() 的输出返回,以便将其传递给 setup_context()

根据变换的不同,

为了使 torch.autograd.Function 可以与函数变换任意组合,我们建议除 forward()setup_context() 之外的所有其他静态方法都必须是可变换的:也就是说,它们必须仅由 PyTorch 运算符组成,或者调用其他 torch.autograd.Function(可能调用 C++/CUDA/etc)。

让我们回顾一些常见用例的示例。

示例 1:autograd.Function 调用另一个系统

一个常见的情况是 torch.autograd.Function 的 forward() 和 backward() 都调用另一个系统(如 C++、CUDA、numpy、triton)。

import torch
import numpy as np

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    # Note that forward does not take ctx
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        # Any intermediates to be saved in backward must be returned as
        # outputs.
        return (
            # The desired output
            torch.tensor(result, device=device),
            # intermediate to save for backward
            torch.tensor(ind, device=device),
            # intermediate to save for backward
            torch.tensor(ind_inv, device=device),
        )

    # setup_context is responsible for calling methods and/or assigning to
    # the ctx object. Please do not do additional compute (e.g. add
    # Tensors together) in setup_context.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        # Note that output is whatever you returned from forward.
        # If you returned multiple values, then output is a Tuple of multiple values.
        # If you returned a single Tensor, then output is a Tensor.
        # If you returned a Tuple with a single Tensor, then output is a
        # Tuple with a single Tensor.
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        # Tensors must be saved via ctx.save_for_backward. Please do not
        # assign them directly onto the ctx object.
        ctx.save_for_backward(ind, ind_inv)
        # Non-tensors may be saved by assigning them as attributes on the ctx object.
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        # For the autograd.Function to be arbitrarily composable with function
        # transforms, all staticmethod other than forward and setup_context
        # must be implemented in a "transformable" way; that is, they must
        # only consist of PyTorch operations or autograd.Function.
        #
        # For example, this allows us to do double backwards and/or compute
        # second order gradients.
        #
        # We've written the backward pass of NumpySort in terms of another
        # autograd.Function, NumpyTake.
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

现在,为了更轻松地使用 NumpySort(隐藏我们作为输出返回的中间值,并允许默认参数和关键字参数),我们创建一个新函数来调用它

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

这是一个健全性检查

x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))

示例 2:autograd.Function 指定自定义梯度规则

另一个常见情况是使用 PyTorch 操作实现的 torch.autograd.Function。PyTorch 能够自动计算 PyTorch 操作的梯度,但我们可能希望自定义梯度的计算方式。我们可能需要与 PyTorch 提供的反向传播不同的自定义反向传播的一些原因如下:

  • 提高数值稳定性

  • 更改反向传播的性能特征

  • 更改边缘情况的处理方式(例如,nans,inf)

  • 修改梯度(例如,梯度裁剪)

这是一个函数 y = x ** 3torch.autograd.Function 示例,我们在此处更改了性能特征(通常在反向传播期间发生的某些计算,即计算 dx,在正向传播中发生)。

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        # In regular PyTorch, if we had just run y = x ** 3, then the backward
        # pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
        # that computation here in the forward pass instead.
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`.
        result = grad_output * dx + grad_dx * 6 * x
        return result

现在,为了更轻松地使用 NumpySort(并隐藏我们作为输出返回的中间值),我们创建一个新函数来调用它

def my_cube(x):
    result, _ = MyCube.apply(x)
    return result

这是一个计算二阶梯度的健全性检查

x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)

局限性和注意事项

警告

请仔细阅读 torch.autograd.Function 与 torch.func 变换一起使用的这些局限性。我们无法捕获许多这些情况并优雅地报错,因此它们将导致未定义的行为。

请勿将正在变换的、具有 requires_grad=True 或 dual tensors 的张量捕获到 torch.autograd.Function 的方法中。完全安全的方式是确保 torch.autograd.Function 的任何方法内部使用的唯一张量必须直接作为输入传递(或通过 ctx 对象传递),而不是来自 torch.autograd.Function 外部。

torch.autograd.Function 不处理 pytrees(可能包含或不包含张量的任意嵌套 Python 数据结构)中的张量。为了使这些张量被 autograd 跟踪,它们必须直接作为参数传递给 torch.autograd.Function。这与 jax.{custom_vjp, custom_jvp} 不同,后者接受 pytrees。

请仅使用 save_for_backward()save_for_forward() 来保存张量。请勿将张量或张量集合直接分配到 ctx 对象上 - 这些张量将不会被跟踪

torch.vmap() 支持

要将 torch.autograd.Functiontorch.vmap() 一起使用,您必须执行以下操作之一:

自动生成 vmap 规则

如果您的 torch.autograd.Function 满足以下附加约束,那么我们能够为其生成 vmap 规则。如果它不满足约束,或者如果您想要 vmap 下的自定义行为,请手动定义 vmap 静态方法(请参阅下一节)。

警告

我们无法轻易检查以下约束并优雅地报错。违反约束可能会导致未定义的行为。

示例

class MyCube(torch.autograd.Function):
    # Set generate_vmap_rule to True to ask PyTorch to automatically generate
    # a vmap rule.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        result = grad_output * dx + grad_dx * 6 * x
        return result

def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)

定义 vmap 静态方法

如果您的 torch.autograd.Function 调用另一个系统(如 NumPy、C++、CUDA、triton),那么为了使其与 torch.vmap() 或使用它的变换一起工作,您需要手动定义 vmap() 静态方法。

根据您要使用的变换和您的用例,您可能不需要为所有 torch.autograd.Function 添加 vmap() 静态方法

我们建议确保您的所有 torch.autograd.Function 都支持 torch.vmap(),特别是当您编写第三方库并且希望您的 torch.autograd.Functiontorch.func() 变换的所有组合一起工作时。

从概念上讲,vmap 静态方法负责定义 forward()torch.vmap() 下应如何表现。也就是说,它定义了如何转换 forward() 以在具有附加维度(要进行 vmap 的维度)的输入上运行。这类似于 torch.vmap() 在 PyTorch 操作上的实现方式:对于每个操作,我们定义一个 vmap 规则(有时也称为“批处理规则”)。

以下是如何定义 vmap() 静态方法

  • 签名是 vmap(info, in_dims: Tuple[Optional[int]], *args),其中 *argsforward() 的参数相同。

  • vmap 静态方法负责定义 forward()torch.vmap() 下应如何表现。也就是说,给定具有附加维度(由 in_dims 指定)的输入,我们如何计算 forward() 的批处理版本?

  • 对于 args 中的每个 arg,in_dims 都有一个对应的 Optional[int]。如果 arg 不是张量或 arg 未进行 vmap,则为 None,否则,它是一个整数,指定张量的哪个维度正在进行 vmap。

  • info 是可能有用的其他元数据的集合:info.batch_size 指定要进行 vmap 的维度的大小,而 info.randomness 是传递给 torch.vmap()randomness 选项。

  • vmap 静态方法的返回值是一个 (output, out_dims) 元组。与 in_dims 类似,out_dims 应与 output 具有相同的结构,并且每个输出都包含一个 out_dim,用于指定输出是否具有 vmap 维度以及其在哪个索引中。

示例

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    # The signature of the vmap staticmethod is:
    # vmap(info, in_dims: Tuple[Optional[int]], *args)
    # where *args is the same as the arguments to `forward`.
    @staticmethod
    def vmap(info, in_dims, x, dim):
        # For every input (x and dim), in_dims stores an Optional[int]
        # that is:
        # - None if the input is not being vmapped over or if the input
        #   is not a Tensor
        # - an integer if the input is being vmapped over that represents
        #   the index of the dimension being vmapped over.
        x_bdim, _ = in_dims

        # A "vmap rule" is the logic of how to perform the operation given
        # inputs with one additional dimension. In NumpySort, x has an
        # additional dimension (x_bdim). The vmap rule is simply
        # to call NumpySort again but pass it a different `dim`.
        x = x.movedim(x_bdim, 0)
        # Handle negative dims correctly
        dim = dim if dim >= 0 else dim + x.dim() - 1
        result = NumpySort.apply(x, dim + 1)

        # The vmap rule must return a tuple of two things
        # 1. the output. Should be the same amount of things
        #    as returned by the forward().
        # 2. one Optional[int] for each output specifying if each output
        # is being vmapped over, and if so, the index of the
        # dimension being vmapped over.
        #
        # NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
        # dimension being vmapped over to the front of `x`, that appears at
        # dimension 0 of all outputs.
        # The return is (output, out_dims) -- output is a tuple of 3 Tensors
        # and out_dims is a Tuple of 3 Optional[int]
        return NumpySort.apply(x, dim + 1), (0, 0, 0)

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

        # The strategy is: expand {x, ind, ind_inv} to all have the dimension
        # being vmapped over.
        # Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).

        # Handle negative dims by wrapping them to be positive
        logical_dim = x.dim() if x_bdim is None else x_bdim - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def maybe_expand_bdim_at_front(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        # If the Tensor doesn't have the dimension being vmapped over,
        # expand it out. Otherwise, move it to the front of the Tensor
        x = maybe_expand_bdim_at_front(x, x_bdim)
        ind = maybe_expand_bdim_at_front(ind, ind_bdim)
        ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)

        # The return is a tuple (output, out_dims). Since output is a Tensor,
        # then out_dims is an Optional[int] (instead of being a Tuple).
        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))

注意

vmap 静态方法应旨在保留整个 Function 的语义。也就是说,(伪代码)grad(vmap(MyFunc)) 应可替换为 grad(map(MyFunc))

如果您的 autograd.Function 在反向传播中具有任何自定义行为,请牢记这一点。

注意

为 PyTorch 能够通过 generate_vmap_rule=True 生成 vmap 规则的 Function 编写自定义 vmap 静态方法是合法的用例。如果您希望生成的 vmap 规则不具有您要查找的语义,则可能希望这样做。

torch.func.jvp() 支持

为了支持前向模式 AD,torch.autograd.Function 必须具有 jvp() 静态方法。请参阅 前向模式 AD 以获取详细信息。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源