torch.func¶
torch.func,以前称为“functorch”,是 PyTorch 中类似 JAX 的可组合函数变换。
注意
该库目前处于 测试阶段(beta)。这意味着其功能通常可以工作(除非另有说明),并且我们(PyTorch 团队)致力于推进该库。然而,API 可能会根据用户反馈进行更改,而且我们尚未完全覆盖所有 PyTorch 操作。
如果您对 API 或希望涵盖的用例有任何建议,请在 GitHub 上提交 issue 或联系我们。我们很乐意了解您如何使用该库。
什么是可组合函数变换?¶
“函数变换”是一种高阶函数,它接受一个数值函数,并返回一个计算不同量的新函数。
torch.func
包含自动微分变换(grad(f)
返回计算f
梯度的函数)、向量化/批处理变换(vmap(f)
返回在输入批量上计算f
的函数)等。这些函数变换可以任意组合。例如,组合
vmap(grad(f))
可以计算称为“每样本梯度”的量,而当前的 PyTorch 尚无法高效计算此量。