快捷方式

torch.autograd.Function.forward

static Function.forward(*args, **kwargs)[源]

定义自定义 autograd Function 的前向传播。

所有子类都必须重写此函数。有两种定义前向传播的方法

用法 1(合并前向传播和 ctx)

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

用法 2(分离前向传播和 ctx)

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • forward 不再接受 ctx 参数。

  • 相反,您还必须重写 torch.autograd.Function.setup_context() 静态方法来处理 ctx 对象的设置。 output 是前向传播的输出,inputs 是前向传播输入的元组(Tuple)。

  • 更多详细信息请参阅 扩展 torch.autograd

上下文可用于存储可在反向传播过程中检索的任意数据。张量不应直接存储在 ctx 上(尽管出于向后兼容性目前并未强制执行)。相反,如果张量打算用于 backward(等同于 vjp),则应使用 ctx.save_for_backward() 保存;如果张量打算用于 jvp,则应使用 ctx.save_for_forward() 保存。

返回类型

Any

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源