functorch¶
functorch 是一个 类似 JAX 的 可组合函数变换,用于 PyTorch。
警告
我们已将 functorch 集成到 PyTorch 中。作为集成的最后一步,从 PyTorch 2.0 开始,functorch API 已弃用。请改用 torch.func API,并查看 迁移指南 和 文档 以获取更多详细信息。
什么是可组合函数变换?¶
“函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同数量的新函数。
functorch 具有自动微分变换 (
grad(f)
返回一个计算f
梯度的函数)、向量化/批处理变换 (vmap(f)
返回一个计算f
在输入批次上的函数) 等等。这些函数变换可以任意相互组合。例如,组合
vmap(grad(f))
计算一个名为每样本梯度的量,而今天的 PyTorch 无法有效地计算。
为什么使用可组合函数变换?¶
今天,PyTorch 中有一些难以处理的用例
计算每样本梯度(或其他每样本量)
在一台机器上运行模型集合
在 MAML 的内部循环中有效地将任务批处理在一起
有效地计算雅可比矩阵和海森矩阵
有效地计算批处理的雅可比矩阵和海森矩阵
组合 vmap()
、grad()
和 vjp()
变换使我们能够表达上述内容,而无需为每个内容设计一个单独的子系统。这个可组合函数变换的想法来自 JAX 框架。