快捷方式

从 functorch 迁移到 torch.func

torch.func,以前称为“functorch”,是 JAX 风格的 PyTorch 可组合函数转换。

functorch 最初是一个在 pytorch/functorch 存储库中的树外库。我们的目标始终是将 functorch 直接上游到 PyTorch,并将其作为核心 PyTorch 库提供。

作为上游的最后一步,我们决定从成为顶级包 (functorch) 迁移到成为 PyTorch 的一部分,以反映函数转换是如何直接集成到 PyTorch 核心中的。从 PyTorch 2.0 开始,我们不建议使用 import functorch,并要求用户迁移到最新的 API,我们将继续维护这些 API。import functorch 将保留一段时间以保持向后兼容性,持续几个版本。

函数转换

以下 API 可以替代以下 functorch API。它们完全向后兼容。

functorch API

PyTorch API (从 PyTorch 2.0 开始)

functorch.vmap

torch.vmap()torch.func.vmap()

functorch.grad

torch.func.grad()

functorch.vjp

torch.func.vjp()

functorch.jvp

torch.func.jvp()

functorch.jacrev

torch.func.jacrev()

functorch.jacfwd

torch.func.jacfwd()

functorch.hessian

torch.func.hessian()

functorch.functionalize

torch.func.functionalize()

此外,如果您正在使用 torch.autograd.functional API,请尝试使用 torch.func 等效项。在许多情况下,torch.func 函数转换更具可组合性和更高性能。

torch.autograd.functional API

torch.func API (从 PyTorch 2.0 开始)

torch.autograd.functional.vjp()

torch.func.grad()torch.func.vjp()

torch.autograd.functional.jvp()

torch.func.jvp()

torch.autograd.functional.jacobian()

torch.func.jacrev()torch.func.jacfwd()

torch.autograd.functional.hessian()

torch.func.hessian()

NN 模块实用程序

我们更改了 API 以将函数转换应用于 NN 模块,使它们更符合 PyTorch 设计理念。新的 API 与以前不同,因此请仔细阅读本节。

functorch.make_functional

torch.func.functional_call()functorch.make_functionalfunctorch.make_functional_with_buffers 的替代方案。但是,它不是直接替代品。

如果您时间紧迫,可以使用 此 gist 中的辅助函数 来模拟 functorch.make_functional 和 functorch.make_functional_with_buffers 的行为。我们建议直接使用 torch.func.functional_call(),因为它是一个更明确且更灵活的 API。

具体来说,functorch.make_functional 返回一个函数化模块和参数。函数化模块接受参数和模型输入作为参数。 torch.func.functional_call() 允许使用新的参数和缓冲区以及输入来调用现有模块的前向传播。

以下是如何使用 functorch 与 torch.func 计算模型参数梯度的示例

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)

以下是如何计算模型参数雅可比矩阵的示例

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

请注意,为了节省内存,您应该只保留参数的单个副本。 model.named_parameters() 不会复制参数。如果在模型训练中,您对模型的参数进行了就地更新,那么您的模型中的 nn.Module 将拥有参数的单个副本,一切都正常。

但是,如果您想将参数保存在字典中并进行非就地更新,那么将存在两个参数副本:字典中的一个和 model 中的一个。在这种情况下,您应该通过 model.to('meta')model 转换为元设备,使其不再保留内存。

functorch.combine_state_for_ensemble

请使用 torch.func.stack_module_state() 而不是 functorch.combine_state_for_ensemble torch.func.stack_module_state() 返回两个字典,一个包含堆叠的参数,另一个包含堆叠的缓冲区,然后可以与 torch.vmap()torch.func.functional_call() 一起用于集成。

例如,以下是如何在一个非常简单的模型上进行集成的示例

import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

functorch.compile

我们不再支持 functorch.compile(也称为 AOTAutograd)作为 PyTorch 中编译的前端;我们将 AOTAutograd 集成到 PyTorch 的编译故事中。如果您是用户,请使用 torch.compile() 替代。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获得针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源