torch.func¶
torch.func,以前称为“functorch”,是用于 PyTorch 的类似 JAX 的可组合函数转换。
注意
此库当前处于 测试阶段。这意味着这些功能通常有效(除非有其他文档说明),并且我们(PyTorch 团队)致力于推广此库。但是,API 可能会根据用户反馈而更改,并且我们并未完全涵盖 PyTorch 操作。
如果您对 API 或希望涵盖的用例有任何建议,请打开 GitHub 问题或联系我们。我们很乐意了解您如何使用此库。
什么是可组合函数转换?¶
“函数转换”是一个高阶函数,它接受一个数值函数并返回一个计算不同值的函数。
torch.func
具有自动微分转换(grad(f)
返回一个计算f
梯度的函数)、向量化/批处理转换(vmap(f)
返回一个在批次输入上计算f
的函数)以及其他转换。这些函数转换可以任意组合在一起。例如,组合
vmap(grad(f))
计算一个称为每样本梯度的量,而现有的 PyTorch 无法高效地计算该量。