torch.func.hessian¶
- torch.func.hessian(func, argnums=0)[源代码]¶
通过前向-反向策略计算
func
相对于索引argnum
处的参数的 Hessian 矩阵。前向-反向策略(组合
jacfwd(jacrev(func))
)是获得良好性能的良好默认选择。可以通过jacfwd()
和jacrev()
的其他组合来计算 Hessian 矩阵,例如jacfwd(jacfwd(func))
或jacrev(jacrev(func))
。- 参数
func (函数) – 一个 Python 函数,接受一个或多个参数,其中一个参数必须是张量,并返回一个或多个张量
argnums (int 或 Tuple[int]) – 可选,整数或整数元组,说明要相对于哪个参数获取 Hessian 矩阵。默认值:0。
- 返回
返回一个函数,该函数接受与
func
相同的输入,并返回func
相对于argnums
处参数的 Hessian 矩阵。
注意
您可能会看到此 API 报错“算子 X 未实现前向模式 AD”。如果遇到这种情况,请提交错误报告,我们将优先处理。另一种选择是使用
jacrev(jacrev(func))
,它具有更好的算子覆盖率。R^N -> R^1 函数的基本用法给出了 N x N Hessian 矩阵
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))