快捷方式

functorch

functorch 是一个 类似 JAX 的 可组合函数变换,用于 PyTorch。

警告

我们已将 functorch 集成到 PyTorch 中。作为集成的最后一步,从 PyTorch 2.0 开始,functorch API 已弃用。请改用 torch.func API,并查看 迁移指南文档 以获取更多详细信息。

什么是可组合函数变换?

  • “函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同数量的新函数。

  • functorch 具有自动微分变换 (grad(f) 返回一个计算 f 梯度的函数)、向量化/批处理变换 (vmap(f) 返回一个计算 f 在输入批次上的函数) 等等。

  • 这些函数变换可以任意相互组合。例如,组合 vmap(grad(f)) 计算一个名为每样本梯度的量,而今天的 PyTorch 无法有效地计算。

为什么使用可组合函数变换?

今天,PyTorch 中有一些难以处理的用例

  • 计算每样本梯度(或其他每样本量)

  • 在一台机器上运行模型集合

  • 在 MAML 的内部循环中有效地将任务批处理在一起

  • 有效地计算雅可比矩阵和海森矩阵

  • 有效地计算批处理的雅可比矩阵和海森矩阵

组合 vmap()grad()vjp() 变换使我们能够表达上述内容,而无需为每个内容设计一个单独的子系统。这个可组合函数变换的想法来自 JAX 框架

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获取问题解答

查看资源