从 functorch 迁移到 torch.func¶
torch.func,以前称为“functorch”,是 PyTorch 的 JAX 式 可组合函数转换。
functorch 最初是 pytorch/functorch 存储库中的一个树外库。我们的目标始终是将 functorch 直接上游到 PyTorch 并将其作为 PyTorch 核心库提供。
作为上游的最后一步,我们决定从作为顶级包 (functorch
) 迁移到成为 PyTorch 的一部分,以反映函数转换如何直接集成到 PyTorch 核心。从 PyTorch 2.0 开始,我们弃用 import functorch
并要求用户迁移到最新的 API,我们将继续维护这些 API。import functorch
将保留一段时间以维护向后兼容性。
函数转换¶
以下 API 是对以下 functorch API 的直接替换。它们完全向后兼容。
functorch API |
PyTorch API(从 PyTorch 2.0 开始) |
---|---|
functorch.vmap |
|
functorch.grad |
|
functorch.vjp |
|
functorch.jvp |
|
functorch.jacrev |
|
functorch.jacfwd |
|
functorch.hessian |
|
functorch.functionalize |
此外,如果您使用的是 torch.autograd.functional API,请尝试使用 torch.func
等效项。在许多情况下,torch.func
函数转换更具可组合性并且性能更高。
torch.autograd.functional API |
torch.func API(从 PyTorch 2.0 开始) |
---|---|
NN 模块实用程序¶
我们已经更改了 API 以对 NN 模块应用函数转换,以使它们更符合 PyTorch 设计理念。新 API 不同,因此请仔细阅读本节。
functorch.make_functional¶
torch.func.functional_call()
是 functorch.make_functional 和 functorch.make_functional_with_buffers 的替代品。但是,它不是直接替换。
如果您时间紧迫,可以使用 此 gist 中的辅助函数 来模拟 functorch.make_functional 和 functorch.make_functional_with_buffers 的行为。我们建议直接使用 torch.func.functional_call()
,因为它是一个更明确且更灵活的 API。
具体来说,functorch.make_functional 返回一个函数化模块和参数。函数化模块接受参数和模型输入作为参数。 torch.func.functional_call()
允许使用新的参数、缓冲区和输入来调用现有模块的前向传递。
以下是如何使用 functorch 和 torch.func
计算模型参数梯度的示例
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
def compute_loss(params, inputs, targets):
prediction = fmodel(params, inputs)
return torch.nn.functional.mse_loss(prediction, targets)
grads = functorch.grad(compute_loss)(params, inputs, targets)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
def compute_loss(params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)
grads = torch.func.grad(compute_loss)(params, inputs, targets)
以下是如何计算模型参数雅可比矩阵的示例
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
请注意,为了节省内存,您应该只保留参数的单个副本。 model.named_parameters()
不会复制参数。如果您在模型训练中就地更新模型的参数,那么作为模型的 nn.Module
拥有参数的单个副本,一切正常。
但是,如果您想将参数保存在字典中并进行非就地更新,那么将有两个参数副本:字典中的一个和 model
中的一个。在这种情况下,您应该通过将 model
转换为元设备(使用 model.to('meta')
)来阻止它占用内存。
functorch.combine_state_for_ensemble¶
请使用 torch.func.stack_module_state()
而不是 functorch.combine_state_for_ensemble torch.func.stack_module_state()
返回两个字典,一个包含堆叠的参数,另一个包含堆叠的缓冲区,这些字典可以与 torch.vmap()
和 torch.func.functional_call()
一起用于集成。
例如,以下是一个在非常简单的模型上进行集成的示例
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy
# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')
params, buffers = torch.func.stack_module_state(models)
# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
return torch.func.functional_call(base_model, (params, buffers), (data,))
output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
functorch.compile¶
我们不再支持 functorch.compile(也称为 AOTAutograd)作为 PyTorch 中编译的前端;我们已将 AOTAutograd 集成到 PyTorch 的编译流程中。如果您是用户,请使用 torch.compile()
替代。