torch.func¶
torch.func,以前称为 “functorch”,是 PyTorch 的类似 JAX 的可组合函数变换。
注意
该库目前处于 beta 阶段。这意味着这些功能通常可以工作(除非另有文档说明),并且我们(PyTorch 团队)致力于推进该库的发展。但是,API 可能会根据用户反馈进行更改,并且我们没有完全覆盖 PyTorch 操作。
如果您对 API 或您希望涵盖的用例有建议,请打开 GitHub issue 或与我们联系。我们很乐意了解您如何使用该库。
什么是可组合函数变换?¶
“函数变换” 是一个高阶函数,它接受一个数值函数并返回一个新的函数,该函数计算不同的量。
torch.func
具有自动微分变换 (grad(f)
返回一个计算f
梯度的函数)、向量化/批处理变换 (vmap(f)
返回一个计算批量输入上f
的函数) 以及其他变换。这些函数变换可以任意组合。例如,组合
vmap(grad(f))
计算一个名为“per-sample-gradients” 的量,这是目前标准的 PyTorch 无法有效计算的。