torch.func.grad_and_value¶ torch.func.grad_and_value(func, argnums=0, has_aux=False)¶ 返回一个函数,用于计算梯度和原始(或前向)计算的元组。 参数 func (Callable) – 一个 Python 函数,它接受一个或多个参数。必须返回一个单元素张量。如果指定的 has_aux 等于 True,则函数可以返回一个单元素张量和其他辅助对象的元组: (output, aux)。 argnums (int 或 Tuple[int]) – 指定要计算梯度的参数。 argnums 可以是单个整数或整数元组。默认值:0。 has_aux (bool) – 标志,指示 func 返回一个张量和其他辅助对象: (output, aux)。默认值:False。 返回值 计算其输入的梯度和前向计算的元组的函数。默认情况下,函数的输出是关于第一个参数的梯度张量和原始计算的元组。如果指定的 has_aux 等于 True,则返回梯度元组和输出辅助对象的正向计算元组。如果 argnums 是一个整数元组,则返回关于每个 argnums 值的输出梯度的元组和前向计算的元组。 返回类型 Callable 请参阅 grad() 以获取示例