torch.autograd.Function.forward¶
- static Function.forward(*args, **kwargs)[来源]¶
定义自定义 autograd Function 的前向传播。
此函数需要由所有子类重写。定义前向传播有两种方式
用法 1(组合前向传播和 ctx)
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必须接受一个上下文 ctx 作为第一个参数,后跟任意数量的参数(张量或其他类型)。
有关更多详细信息,请参阅 组合或分离 forward() 和 setup_context()
用法 2(分离前向传播和 ctx)
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
前向传播不再接受 ctx 参数。
相反,您还必须重写
torch.autograd.Function.setup_context()
静态方法来处理设置ctx
对象。output
是前向传播的输出,inputs
是前向传播的输入的元组。有关更多详细信息,请参阅 扩展 torch.autograd
上下文可用于存储任意数据,这些数据随后可在反向传播期间检索。张量不应直接存储在 ctx 上(尽管目前为了向后兼容性,这并非强制执行)。相反,张量应使用
ctx.save_for_backward()
保存(如果它们打算在backward
中使用(等效地,vjp
)),或者使用ctx.save_for_forward()
保存(如果它们打算在jvp
中使用)。- 返回类型