torch.func.grad_and_value¶
- torch.func.grad_and_value(func, argnums=0, has_aux=False)[source]¶
返回一个函数,用于计算梯度和原始(或前向)计算的元组。
- 参数
- 返回
返回一个函数,用于计算相对于输入的梯度和前向计算的元组。默认情况下,该函数的输出是相对于第一个参数的梯度 tensor(s) 和原始计算的元组。如果指定了
has_aux
等于True
,则返回梯度元组和包含输出辅助对象的前向计算元组。如果argnums
是一个整数元组,则返回一个元组,该元组包含相对于每个argnums
值的输出梯度元组以及前向计算。- 返回类型
参见
grad()
获取示例