functorch.hessian¶
-
functorch.
hessian
(func, argnums=0)[源代码]¶ 使用前向-反向策略计算
func
相对于索引argnum
处的参数的海森矩阵。前向-反向策略(组合
jacfwd(jacrev(func))
)是良好性能的良好默认选择。可以通过其他组合jacfwd()
和jacrev()
来计算海森矩阵,例如jacfwd(jacfwd(func))
或jacrev(jacrev(func))
。- 参数
- 返回值
返回一个函数,该函数接受与
func
相同的输入,并返回func
相对于argnums
处的参数的海森矩阵。
注意
您可能会看到此 API 出错并显示“未为运算符 X 实现前向模式自动微分”。如果是这样,请提交错误报告,我们将优先处理。另一种方法是使用
jacrev(jacrev(func))
,它具有更好的运算符覆盖率。一个使用 R^N -> R^1 函数的基本用法会得到一个 N x N 海森矩阵
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))
警告
我们已将 functorch 集成到 PyTorch 中。作为集成的最后一步,functorch.hessian 从 PyTorch 2.0 开始已弃用,并将从未来版本的 PyTorch >= 2.3 中删除。请改用 torch.func.hessian;有关更多详细信息,请参阅 PyTorch 2.0 版本说明或 torch.func 迁移指南 https://pytorch.ac.cn/docs/master/func.migrating.html