快捷方式

functorch.hessian

functorch.hessian(func, argnums=0)[源代码]

使用前向-反向策略计算 func 相对于索引 argnum 处的参数的海森矩阵。

前向-反向策略(组合 jacfwd(jacrev(func)))是良好性能的良好默认选择。可以通过其他组合 jacfwd()jacrev() 来计算海森矩阵,例如 jacfwd(jacfwd(func))jacrev(jacrev(func))

参数
  • func (函数) – 一个 Python 函数,它接受一个或多个参数,其中一个必须是张量,并返回一个或多个张量

  • argnums (整数元组[整数]) – 可选的,整数或整数元组,表示要获取相对于哪些参数的海森矩阵。默认值:0。

返回值

返回一个函数,该函数接受与 func 相同的输入,并返回 func 相对于 argnums 处的参数的海森矩阵。

注意

您可能会看到此 API 出错并显示“未为运算符 X 实现前向模式自动微分”。如果是这样,请提交错误报告,我们将优先处理。另一种方法是使用 jacrev(jacrev(func)),它具有更好的运算符覆盖率。

一个使用 R^N -> R^1 函数的基本用法会得到一个 N x N 海森矩阵

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))

警告

我们已将 functorch 集成到 PyTorch 中。作为集成的最后一步,functorch.hessian 从 PyTorch 2.0 开始已弃用,并将从未来版本的 PyTorch >= 2.3 中删除。请改用 torch.func.hessian;有关更多详细信息,请参阅 PyTorch 2.0 版本说明或 torch.func 迁移指南 https://pytorch.ac.cn/docs/master/func.migrating.html

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源