从 functorch 迁移到 torch.func¶
torch.func,之前称为 “functorch”,是 PyTorch 中类似 JAX 的可组合函数变换。
functorch 最初是 pytorch/functorch 仓库中的一个独立库。我们的目标始终是将 functorch 直接合并到 PyTorch 主干中,并将其作为核心 PyTorch 库提供。
作为合并主干的最后一步,我们决定从一个顶级包(functorch)迁移到成为 PyTorch 的一部分,以反映函数变换是如何直接集成到 PyTorch 核心中的。从 PyTorch 2.0 开始,我们弃用了 import functorch,并要求用户迁移到我们将继续维护的最新 API。import functorch 将保留几个版本,以保持向后兼容性。
函数变换¶
以下 API 可以直接替换以下 functorch API。它们完全向后兼容。
functorch API |
PyTorch API (截至 PyTorch 2.0) |
|---|---|
functorch.vmap |
|
functorch.grad |
|
functorch.vjp |
|
functorch.jvp |
|
functorch.jacrev |
|
functorch.jacfwd |
|
functorch.hessian |
|
functorch.functionalize |
此外,如果您正在使用 torch.autograd.functional API,请尝试使用相应的 torch.func 替代项。torch.func 函数变换在许多情况下更具可组合性和更高性能。
torch.autograd.functional API |
torch.func API (截至 PyTorch 2.0) |
|---|---|
NN 模块工具¶
我们更改了应用于 NN 模块的函数变换 API,使其更符合 PyTorch 的设计理念。新的 API 有所不同,请仔细阅读本节。
functorch.make_functional¶
torch.func.functional_call() 替代了 functorch.make_functional 和 functorch.make_functional_with_buffers。然而,它不能完全直接替换。
如果时间紧急,您可以使用此 Gist 中的辅助函数来模拟 functorch.make_functional 和 functorch.make_functional_with_buffers 的行为。我们建议直接使用 torch.func.functional_call(),因为它是一个更明确、更灵活的 API。
具体来说,functorch.make_functional 返回一个函数式模块和参数。该函数式模块接受参数和模型输入作为参数。torch.func.functional_call() 允许使用新的参数、缓冲区和输入调用现有模块的前向传递。
这里有一个使用 functorch 和 torch.func 计算模型参数梯度的示例
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
def compute_loss(params, inputs, targets):
prediction = fmodel(params, inputs)
return torch.nn.functional.mse_loss(prediction, targets)
grads = functorch.grad(compute_loss)(params, inputs, targets)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
def compute_loss(params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)
grads = torch.func.grad(compute_loss)(params, inputs, targets)
这里有一个计算模型参数 Jacobian 矩阵的示例
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
请注意,为了控制内存消耗,只应保留参数的单个副本,这一点很重要。model.named_parameters() 不会复制参数。如果在模型训练中原地更新模型参数,则模型中的 nn.Module 拥有参数的单个副本,一切正常。
但是,如果想将参数保存在字典中并非原地更新,则会有两份参数副本:一份在字典中,另一份在 model 中。在这种情况下,应该通过 model.to('meta') 将 model 转换为 meta 设备,使其不持有内存。
functorch.combine_state_for_ensemble¶
请使用 torch.func.stack_module_state() 替代 functorch.combine_state_for_ensemble。torch.func.stack_module_state() 返回两个字典,一个包含堆叠的参数,另一个包含堆叠的缓冲区,然后可以与 torch.vmap() 和 torch.func.functional_call() 一起用于集成(ensembling)。
例如,这里有一个关于如何对一个非常简单的模型进行集成(ensemble)的示例
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy
# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')
params, buffers = torch.func.stack_module_state(models)
# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
return torch.func.functional_call(base_model, (params, buffers), (data,))
output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
functorch.compile¶
我们不再支持将 functorch.compile(也称为 AOTAutograd)作为 PyTorch 编译的前端;我们已将 AOTAutograd 集成到 PyTorch 的编译体系中。如果您是用户,请转而使用 torch.compile()。