快捷方式

functorch.grad_and_value

functorch.grad_and_value(func, argnums=0, has_aux=False)[source]

返回一个函数,用于计算梯度和原始(或前向)计算的元组。

参数
  • func (Callable) – 一个 Python 函数,它接受一个或多个参数。必须返回一个单元素张量。如果指定 has_aux 等于 True,则函数可以返回一个单元素张量和其他辅助对象的元组:(output, aux)

  • argnums (intTuple[int]) – 指定要计算其梯度的参数。argnums 可以是单个整数或整数元组。默认值:0。

  • has_aux (bool) – 一个标志,指示 func 返回一个张量和其他辅助对象:(output, aux)。默认值:False。

返回值

用于计算其输入相对于其输入的梯度和前向计算的元组的函数。默认情况下,函数的输出是相对于第一个参数的梯度张量和原始计算的元组。如果指定 has_aux 等于 True,则返回梯度元组和带有输出辅助对象的正向计算元组。如果 argnums 是一个整数元组,则返回一个元组,该元组包含相对于每个 argnums 值的输出梯度的元组和前向计算。

请参阅 grad() 以获取示例

警告

我们已将 functorch 集成到 PyTorch 中。作为集成的最后一步,从 PyTorch 2.0 开始,functorch.grad_and_value 已弃用,并在 PyTorch >= 2.3 的未来版本中删除。请改用 torch.func.grad_and_value;有关更多详细信息,请参阅 PyTorch 2.0 发行说明和/或 torch.func 迁移指南 https://pytorch.ac.cn/docs/master/func.migrating.html

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源