快捷方式

torch.autograd.Function.forward

static Function.forward(*args, **kwargs)[来源]

定义自定义 autograd Function 的前向传播。

此函数需要由所有子类重写。定义前向传播有两种方式

用法 1(组合前向传播和 ctx)

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

用法 2(分离前向传播和 ctx)

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • 前向传播不再接受 ctx 参数。

  • 相反,您还必须重写 torch.autograd.Function.setup_context() 静态方法来处理设置 ctx 对象。output 是前向传播的输出,inputs 是前向传播的输入的元组。

  • 有关更多详细信息,请参阅 扩展 torch.autograd

上下文可用于存储任意数据,这些数据随后可在反向传播期间检索。张量不应直接存储在 ctx 上(尽管目前为了向后兼容性,这并非强制执行)。相反,张量应使用 ctx.save_for_backward() 保存(如果它们打算在 backward 中使用(等效地,vjp)),或者使用 ctx.save_for_forward() 保存(如果它们打算在 jvp 中使用)。

返回类型

Any

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源