torch.autograd.Function.forward¶
- static Function.forward(*args, **kwargs)¶
定义自定义 autograd Function 的前向传播。
此函数需要由所有子类覆盖。有两种方法可以定义前向传播
用法 1(组合前向传播和 ctx)
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必须接受上下文 ctx 作为第一个参数,后面可以跟任意数量的参数(张量或其他类型)。
有关更多详细信息,请参阅组合或分开的 forward() 和 setup_context()
用法 2(分开的 forward 和 ctx)
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
forward 不再接受 ctx 参数。
相反,您还必须覆盖
torch.autograd.Function.setup_context()
静态方法来处理设置ctx
对象。output
是前向传播的输出,inputs
是前向传播输入的元组。有关更多详细信息,请参阅扩展 torch.autograd
上下文可用于存储任意数据,这些数据随后可以在反向传播过程中检索。不应将张量直接存储在ctx 上(尽管出于向后兼容性的原因,目前没有强制执行此操作)。相反,应使用
ctx.save_for_backward()
保存张量,如果打算在backward
(等效地,vjp
)中使用,或者使用ctx.save_for_forward()
保存张量,如果打算在jvp
中使用。- 返回类型