快捷方式

torch.func

torch.func,以前称为 “functorch”,是 PyTorch 的类似 JAX 的可组合函数变换。

注意

该库目前处于 beta 阶段。这意味着这些功能通常可以工作(除非另有文档说明),并且我们(PyTorch 团队)致力于推进该库的发展。但是,API 可能会根据用户反馈进行更改,并且我们没有完全覆盖 PyTorch 操作。

如果您对 API 或您希望涵盖的用例有建议,请打开 GitHub issue 或与我们联系。我们很乐意了解您如何使用该库。

什么是可组合函数变换?

  • “函数变换” 是一个高阶函数,它接受一个数值函数并返回一个新的函数,该函数计算不同的量。

  • torch.func 具有自动微分变换 (grad(f) 返回一个计算 f 梯度的函数)、向量化/批处理变换 (vmap(f) 返回一个计算批量输入上 f 的函数) 以及其他变换。

  • 这些函数变换可以任意组合。例如,组合 vmap(grad(f)) 计算一个名为“per-sample-gradients” 的量,这是目前标准的 PyTorch 无法有效计算的。

为什么使用可组合函数变换?

在 PyTorch 中,有很多用例目前很难实现

  • 计算 per-sample-gradients(或其他 per-sample 量)

  • 在单台机器上运行模型集成

  • 有效地批量处理 MAML 内循环中的任务

  • 高效地计算雅可比矩阵和 Hessian 矩阵

  • 高效地计算批量雅可比矩阵和 Hessian 矩阵

组合 vmap(), grad(), 和 vjp() 变换使我们能够表达上述内容,而无需为每个用例设计单独的子系统。这种可组合函数变换的思想来自 JAX 框架

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源