快捷方式

从 functorch 迁移到 torch.func

torch.func,之前称为 “functorch”,是 PyTorch 中类似 JAX 的可组合函数变换。

functorch 最初是 pytorch/functorch 仓库中的一个独立库。我们的目标始终是将 functorch 直接合并到 PyTorch 主干中,并将其作为核心 PyTorch 库提供。

作为合并主干的最后一步,我们决定从一个顶级包(functorch)迁移到成为 PyTorch 的一部分,以反映函数变换是如何直接集成到 PyTorch 核心中的。从 PyTorch 2.0 开始,我们弃用了 import functorch,并要求用户迁移到我们将继续维护的最新 API。import functorch 将保留几个版本,以保持向后兼容性。

函数变换

以下 API 可以直接替换以下 functorch API。它们完全向后兼容。

functorch API

PyTorch API (截至 PyTorch 2.0)

functorch.vmap

torch.vmap()torch.func.vmap()

functorch.grad

torch.func.grad()

functorch.vjp

torch.func.vjp()

functorch.jvp

torch.func.jvp()

functorch.jacrev

torch.func.jacrev()

functorch.jacfwd

torch.func.jacfwd()

functorch.hessian

torch.func.hessian()

functorch.functionalize

torch.func.functionalize()

此外,如果您正在使用 torch.autograd.functional API,请尝试使用相应的 torch.func 替代项。torch.func 函数变换在许多情况下更具可组合性和更高性能。

torch.autograd.functional API

torch.func API (截至 PyTorch 2.0)

torch.autograd.functional.vjp()

torch.func.grad()torch.func.vjp()

torch.autograd.functional.jvp()

torch.func.jvp()

torch.autograd.functional.jacobian()

torch.func.jacrev()torch.func.jacfwd()

torch.autograd.functional.hessian()

torch.func.hessian()

NN 模块工具

我们更改了应用于 NN 模块的函数变换 API,使其更符合 PyTorch 的设计理念。新的 API 有所不同,请仔细阅读本节。

functorch.make_functional

torch.func.functional_call() 替代了 functorch.make_functionalfunctorch.make_functional_with_buffers。然而,它不能完全直接替换。

如果时间紧急,您可以使用此 Gist 中的辅助函数来模拟 functorch.make_functional 和 functorch.make_functional_with_buffers 的行为。我们建议直接使用 torch.func.functional_call(),因为它是一个更明确、更灵活的 API。

具体来说,functorch.make_functional 返回一个函数式模块和参数。该函数式模块接受参数和模型输入作为参数。torch.func.functional_call() 允许使用新的参数、缓冲区和输入调用现有模块的前向传递。

这里有一个使用 functorch 和 torch.func 计算模型参数梯度的示例

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)

这里有一个计算模型参数 Jacobian 矩阵的示例

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

请注意,为了控制内存消耗,只应保留参数的单个副本,这一点很重要。model.named_parameters() 不会复制参数。如果在模型训练中原地更新模型参数,则模型中的 nn.Module 拥有参数的单个副本,一切正常。

但是,如果想将参数保存在字典中并非原地更新,则会有两份参数副本:一份在字典中,另一份在 model 中。在这种情况下,应该通过 model.to('meta')model 转换为 meta 设备,使其不持有内存。

functorch.combine_state_for_ensemble

请使用 torch.func.stack_module_state() 替代 functorch.combine_state_for_ensembletorch.func.stack_module_state() 返回两个字典,一个包含堆叠的参数,另一个包含堆叠的缓冲区,然后可以与 torch.vmap()torch.func.functional_call() 一起用于集成(ensembling)。

例如,这里有一个关于如何对一个非常简单的模型进行集成(ensemble)的示例

import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

functorch.compile

我们不再支持将 functorch.compile(也称为 AOTAutograd)作为 PyTorch 编译的前端;我们已将 AOTAutograd 集成到 PyTorch 的编译体系中。如果您是用户,请转而使用 torch.compile()

文档

查阅 PyTorch 全面的开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源