快捷方式

torch.func

torch.func,以前称为“functorch”,是 PyTorch 中类似 JAX 的可组合函数变换。

注意

该库目前处于 测试阶段(beta)。这意味着其功能通常可以工作(除非另有说明),并且我们(PyTorch 团队)致力于推进该库。然而,API 可能会根据用户反馈进行更改,而且我们尚未完全覆盖所有 PyTorch 操作。

如果您对 API 或希望涵盖的用例有任何建议,请在 GitHub 上提交 issue 或联系我们。我们很乐意了解您如何使用该库。

什么是可组合函数变换?

  • “函数变换”是一种高阶函数,它接受一个数值函数,并返回一个计算不同量的新函数。

  • torch.func 包含自动微分变换(grad(f) 返回计算 f 梯度的函数)、向量化/批处理变换(vmap(f) 返回在输入批量上计算 f 的函数)等。

  • 这些函数变换可以任意组合。例如,组合 vmap(grad(f)) 可以计算称为“每样本梯度”的量,而当前的 PyTorch 尚无法高效计算此量。

为什么选择可组合函数变换?

目前在 PyTorch 中实现许多用例比较棘手

  • 计算每样本梯度(或其他每样本量)

  • 在单台机器上运行模型集成

  • 在 MAML 内循环中高效地将任务批量处理

  • 高效计算 Jacobian 和 Hessian

  • 高效计算批量 Jacobian 和 Hessian

组合 vmap()grad()vjp() 变换,使我们无需为每个用例设计单独的子系统即可实现上述功能。这种可组合函数变换的思想源自 JAX 框架

文档

查阅全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源