快捷方式

torch.func

torch.func,以前称为“functorch”,是 JAX-like 用于 PyTorch 的可组合函数转换。

注意

此库目前处于 beta 阶段。这意味着这些功能通常有效(除非另有说明),我们(PyTorch 团队)致力于将此库向前发展。但是,API 可能会根据用户反馈而更改,并且我们尚未完全覆盖 PyTorch 操作。

如果您对 API 或您希望涵盖的用例有任何建议,请打开 GitHub 问题或与我们联系。我们很乐意了解您如何使用此库。

什么是可组合函数转换?

  • “函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同数量的新函数。

  • torch.func 具有自动微分变换(grad(f) 返回一个计算 f 梯度的函数),向量化/批处理变换(vmap(f) 返回一个在输入批次上计算 f 的函数),以及其他。

  • 这些函数变换可以任意组合。例如,组合 vmap(grad(f)) 计算一个称为每样本梯度的量,而现有的 PyTorch 无法有效地计算它。

为什么可组合函数变换?

有一些用例在今天的 PyTorch 中很难实现

  • 计算每样本梯度(或其他每样本量)

  • 在一台机器上运行模型集合

  • 在 MAML 的内循环中有效地将任务批处理在一起

  • 有效地计算雅可比矩阵和海森矩阵

  • 有效地计算批处理雅可比矩阵和海森矩阵

组合 vmap()grad()vjp() 变换使我们能够表达上述内容,而无需为每个内容设计一个单独的子系统。这种可组合函数变换的想法来自 JAX 框架

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源