torch.func¶
torch.func,以前称为“functorch”,是 JAX-like 用于 PyTorch 的可组合函数转换。
注意
此库目前处于 beta 阶段。这意味着这些功能通常有效(除非另有说明),我们(PyTorch 团队)致力于将此库向前发展。但是,API 可能会根据用户反馈而更改,并且我们尚未完全覆盖 PyTorch 操作。
如果您对 API 或您希望涵盖的用例有任何建议,请打开 GitHub 问题或与我们联系。我们很乐意了解您如何使用此库。
什么是可组合函数转换?¶
“函数变换”是一种高阶函数,它接受一个数值函数并返回一个计算不同数量的新函数。
torch.func
具有自动微分变换(grad(f)
返回一个计算f
梯度的函数),向量化/批处理变换(vmap(f)
返回一个在输入批次上计算f
的函数),以及其他。这些函数变换可以任意组合。例如,组合
vmap(grad(f))
计算一个称为每样本梯度的量,而现有的 PyTorch 无法有效地计算它。