torch.autograd.Function.backward¶
- static Function.backward(ctx, *grad_outputs)¶
使用反向模式自动微分定义一个用于微分运算的公式。
所有子类都必须重写此函数。(定义此函数等效于定义
vjp
函数。)它必须接受一个上下文
ctx
作为第一个参数,然后是与forward()
返回的输出一样多的输出(对于前向函数的非张量输出将传入 None),它应该返回与forward()
输入一样多的张量。每个参数都是相对于给定输出的梯度,每个返回值应该是相对于相应输入的梯度。如果输入不是张量或不是需要梯度的张量,则可以为该输入传递 None 作为梯度。上下文可用于检索在正向传递期间保存的张量。它还具有一个属性
ctx.needs_input_grad
,它是一个布尔值元组,表示每个输入是否需要梯度。例如,backward()
将具有ctx.needs_input_grad[0] = True
,如果forward()
的第一个输入需要计算相对于输出的梯度。- 返回类型