快捷方式

torch.func.grad_and_value

torch.func.grad_and_value(func, argnums=0, has_aux=False)[source]

返回一个函数,用于计算梯度和原始(或前向)计算的元组。

参数
  • func (Callable) – 一个接受一个或多个参数的 Python 函数。必须返回一个单元素 Tensor。如果指定了 has_aux 等于 True,函数可以返回一个包含单元素 Tensor 和其他辅助对象的元组: (output, aux)

  • argnums (intTuple[int]) – 指定要计算梯度的参数。argnums 可以是一个整数或整数元组。默认值:0。

  • has_aux (bool) – 一个标志,指示 func 返回一个 tensor 和其他辅助对象: (output, aux)。默认值:False。

返回

返回一个函数,用于计算相对于输入的梯度和前向计算的元组。默认情况下,该函数的输出是相对于第一个参数的梯度 tensor(s) 和原始计算的元组。如果指定了 has_aux 等于 True,则返回梯度元组和包含输出辅助对象的前向计算元组。如果 argnums 是一个整数元组,则返回一个元组,该元组包含相对于每个 argnums 值的输出梯度元组以及前向计算。

返回类型

Callable

参见 grad() 获取示例

文档

查阅 PyTorch 全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源