torch.autograd.Function.forward¶
- static Function.forward(*args, **kwargs)[源]¶
定义自定义 autograd Function 的前向传播。
所有子类都必须重写此函数。有两种定义前向传播的方法
用法 1(合并前向传播和 ctx)
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必须接受一个上下文 ctx 作为第一个参数,后面可以跟任意数量的参数(张量或其他类型)。
更多详细信息请参阅 合并或分离 forward() 和 setup_context()
用法 2(分离前向传播和 ctx)
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
forward 不再接受 ctx 参数。
相反,您还必须重写
torch.autograd.Function.setup_context()
静态方法来处理ctx
对象的设置。output
是前向传播的输出,inputs
是前向传播输入的元组(Tuple)。更多详细信息请参阅 扩展 torch.autograd
上下文可用于存储可在反向传播过程中检索的任意数据。张量不应直接存储在 ctx 上(尽管出于向后兼容性目前并未强制执行)。相反,如果张量打算用于
backward
(等同于vjp
),则应使用ctx.save_for_backward()
保存;如果张量打算用于jvp
,则应使用ctx.save_for_forward()
保存。- 返回类型