快捷方式

torch.autograd.Function.forward

static Function.forward(*args, **kwargs)

定义自定义 autograd Function 的前向传播。

此函数需要由所有子类覆盖。有两种方法可以定义前向传播

用法 1(组合前向传播和 ctx)

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

用法 2(分开的 forward 和 ctx)

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • forward 不再接受 ctx 参数。

  • 相反,您还必须覆盖torch.autograd.Function.setup_context() 静态方法来处理设置ctx 对象。output 是前向传播的输出,inputs 是前向传播输入的元组。

  • 有关更多详细信息,请参阅扩展 torch.autograd

上下文可用于存储任意数据,这些数据随后可以在反向传播过程中检索。不应将张量直接存储在ctx 上(尽管出于向后兼容性的原因,目前没有强制执行此操作)。相反,应使用ctx.save_for_backward() 保存张量,如果打算在backward(等效地,vjp)中使用,或者使用ctx.save_for_forward() 保存张量,如果打算在jvp 中使用。

返回类型

任意

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源