• 文档 >
  • torch.utils.checkpoint
快捷方式

torch.utils.checkpoint

注意

检查点通过在反向传播期间为每个检查点段重新运行一次前向传递段来实现。这会导致像 RNG 状态这样的持久状态比没有检查点时的状态更先进。默认情况下,检查点包括逻辑来处理 RNG 状态,以便使用 RNG(例如通过 dropout)的检查点传递与非检查点传递相比具有确定性输出。存储和恢复 RNG 状态的逻辑可能会带来适度的性能损失,具体取决于检查点操作的运行时间。如果不需要与非检查点传递相比的确定性输出,请向 checkpointcheckpoint_sequential 提供 preserve_rng_state=False 以省略在每个检查点期间存储和恢复 RNG 状态。

存储逻辑将 CPU 和另一个设备类型的 RNG 状态(通过 _infer_device_type 从 Tensor 参数(不包括 CPU 张量)推断设备类型)保存并恢复到 run_fn。如果有多个设备,则仅保存单个设备类型设备的状态,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能会导致不正确的梯度。(请注意,如果 CUDA 设备位于检测到的设备中,它将被优先考虑;否则,将选择遇到的第一个设备。)如果没有 CPU 张量,则将保存并恢复默认设备类型状态(默认值为 cuda,可以通过 DefaultDeviceType 设置为其他设备)。但是,该逻辑无法预测用户是否会将 Tensor 移动到 run_fn 本身内的新的设备(“新”表示不属于 [当前设备 + Tensor 参数的设备] 集合)。因此,如果您在 run_fn 内将 Tensor 移动到新的设备,则无法保证与非检查点传递相比的确定性输出。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source]

检查点模型或模型的一部分。

激活检查点是一种将计算换取内存的技术。在反向传递期间,激活检查点区域的前向计算不会保存用于反向传递的张量,而是通过重新计算这些张量来进行反向传递,从而节省了内存。激活检查点可以应用于模型的任何部分。

目前,有两种检查点实现可用,由 use_reentrant 参数决定。建议您使用 use_reentrant=False。请参阅下面的说明以讨论它们的差异。

警告

如果反向传递期间的 function 调用与正向传递不同,例如由于全局变量,则检查点版本可能不完全相同,这可能会导致引发错误或导致静默地出现错误梯度。

警告

应显式传递 use_reentrant 参数。在 2.4 版本中,如果未传递 use_reentrant,我们将引发异常。如果您正在使用 use_reentrant=True 变体,请参阅下面的说明以了解重要注意事项和潜在限制。

注意

检查点的可重入变体 (use_reentrant=True) 和检查点的不可重入变体 (use_reentrant=False) 在以下方面有所不同

  • 不可重入检查点在所有需要的中间激活都已重新计算后,立即停止重新计算。此功能默认启用,但可以使用 set_checkpoint_early_stop() 禁用。可重入检查点始终在反向传递期间完全重新计算 function

  • 可重入变体在正向传递期间不会记录自动微分图,因为它在 torch.no_grad() 下运行正向传递。不可重入版本会记录自动微分图,允许在检查点区域内对图进行反向传播。

  • 可重入检查点仅支持 torch.autograd.backward() API 进行反向传播,但不支持其 inputs 参数,而不可重入版本支持所有反向传播方法。

  • 对于可重入变体,至少一个输入和输出必须具有 requires_grad=True。如果此条件不满足,则模型的检查点部分将没有梯度。不可重入版本没有此要求。

  • 可重入版本不会将嵌套结构中的张量(例如,自定义对象、列表、字典等)视为参与自动微分,而不可重入版本会。

  • 可重入检查点不支持包含从计算图中分离的张量的检查点区域,而不可重入版本支持。对于可重入变体,如果检查点段包含使用 detach()torch.no_grad() 分离的张量,则反向传播将引发错误。这是因为 checkpoint 使所有输出都需要梯度,当张量被定义为在模型中没有梯度时会导致问题。为了避免这种情况,请在 checkpoint 函数之外分离张量。

参数
  • function – 描述在模型或模型一部分的正向传递中运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在 LSTM 中,如果用户传递 (activation, hidden)function 应该正确地使用第一个输入作为 activation,第二个输入作为 hidden

  • preserve_rng_state (bool, optional) – 在每个检查点期间省略保存和恢复 RNG 状态。请注意,在 torch.compile 下,此标志不起作用,我们始终保留 RNG 状态。默认值:True

  • use_reentrant (bool) – 指定是否使用需要可重入自动微分的激活检查点变体。此参数应显式传递。在 2.4 版本中,如果未传递 use_reentrant,我们将引发异常。如果 use_reentrant=Falsecheckpoint 将使用不需要可重入自动微分的实现。这允许 checkpoint 支持其他功能,例如与 torch.autograd.grad 一起按预期工作,并支持输入到检查点函数的关键字参数。

  • context_fn (Callable, optional) – 一个可调用对象,返回两个上下文管理器的元组。该函数及其重新计算将在第一个和第二个上下文管理器下分别运行。此参数仅在 use_reentrant=False 时支持。

  • determinism_check (str, optional) – 指定要执行的确定性检查的字符串。默认情况下它设置为 "default",它将重新计算的张量的形状、数据类型和设备与保存的张量进行比较。要关闭此检查,请指定 "none"。目前,只有这两个支持的值。如果您想查看更多确定性检查,请打开一个问题。此参数仅在 use_reentrant=False 时支持,如果 use_reentrant=True,则确定性检查始终被禁用。

  • debug (bool, optional) – 如果 True,错误消息还将包括原始正向计算和重新计算期间运行的运算符的跟踪。此参数仅在 use_reentrant=False 时支持。

  • args – 包含 function 的输入的元组

返回值

function 上运行 *args 的输出

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]

检查点一个顺序模型以节省内存。

顺序模型按顺序执行模块/函数列表。因此,我们可以将这样的模型分成多个段,并对每个段进行检查点。除了最后一个段之外,所有段都不会存储中间激活。将保存每个检查点段的输入,以便在反向传递中重新运行该段。

警告

应显式传递 use_reentrant 参数。在 2.4 版本中,如果未传递 use_reentrant,我们将引发异常。如果您正在使用 use_reentrant=True` variant, please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False.

参数
  • functions – 一个 torch.nn.Sequential 或模块或函数的列表(构成模型),按顺序运行。

  • segments – 在模型中创建的块数

  • input – 输入到 functions 的张量

  • preserve_rng_state (bool, optional) – 在每个检查点期间省略保存和恢复 RNG 状态。默认值:True

  • use_reentrant (bool) – 指定是否使用需要可重入自动微分的激活检查点变体。此参数应显式传递。在 2.4 版本中,如果未传递 use_reentrant,我们将引发异常。如果 use_reentrant=Falsecheckpoint 将使用不需要可重入自动微分的实现。这允许 checkpoint 支持其他功能,例如与 torch.autograd.grad 一起按预期工作,并支持输入到检查点函数的关键字参数。

返回值

functions 上按顺序运行 *inputs 的输出

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]

上下文管理器,设置检查点在运行时是否应打印额外的调试信息。有关更多信息,请参见 checkpoint()debug 标志。请注意,当设置时,此上下文管理器会覆盖传递给检查点的 debug 的值。要推迟到本地设置,请将 None 传递给此上下文。

参数

enabled (bool) – 检查点是否应打印调试信息。默认为 'None'。

文档

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources