torch.utils.checkpoint¶
注意
Checkpointing 通过在反向传播期间针对每个检查点段重新运行正向传播段来实现。这可能导致持久状态(例如 RNG 状态)比不使用 checkpointing 时更靠前。默认情况下,checkpointing 包含调整 RNG 状态的逻辑,以便使用 RNG(例如通过 dropout)的检查点传播与非检查点传播相比具有确定性输出。保存和恢复 RNG 状态的逻辑可能会根据检查点操作的运行时长而产生中等程度的性能损失。如果不需要与非检查点传播相比的确定性输出,请将 preserve_rng_state=False
提供给 checkpoint
或 checkpoint_sequential
,以在每个检查点期间省略保存和恢复 RNG 状态。
保存逻辑会将 CPU 和另一种设备类型(通过 _infer_device_type
从 Tensor 参数中推断出非 CPU Tensor 的设备类型)的 RNG 状态保存并恢复到 run_fn
。如果存在多种设备,则只会保存单一设备类型的设备状态,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能导致梯度不正确。(请注意,如果检测到的设备中包含 CUDA 设备,则 CUDA 将被优先考虑;否则,将选择遇到的第一个设备。)如果没有 CPU Tensor,则将保存和恢复默认设备类型状态(默认值为 cuda,可以通过 DefaultDeviceType
设置为其他设备)。然而,该逻辑无法预测用户是否会在 run_fn
本身内部将 Tensors 移动到新设备。因此,如果您在 run_fn
中将 Tensors 移动到新设备(“新”设备指不属于 [当前设备 + Tensor 参数设备] 集合的设备),则无法保证与非检查点传播相比的确定性输出。
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[源代码][源代码]¶
对模型或模型的一部分进行 Checkpoint。
激活 checkpointing 是一种用计算换取内存的技术。它不像在反向传播中一直保留反向计算所需的张量直到它们被用于梯度计算,而是在检查点区域的正向计算中省略保存反向传播所需的张量,并在反向传播过程中重新计算它们。激活 checkpointing 可以应用于模型的任何部分。
目前有两种可用的 checkpointing 实现,由
use_reentrant
参数决定。建议您使用use_reentrant=False
。有关它们差异的讨论,请参阅下面的注意。警告
如果
function
在反向传播期间的调用与正向传播不同,例如由于全局变量,则检查点版本可能不一致,可能导致引发错误或导致悄无声息的错误梯度。警告
use_reentrant
参数应显式传递。在版本 2.4 中,如果未传递use_reentrant
,我们将引发异常。如果您使用use_reentrant=True
变体,请参阅下面的注意,了解重要注意事项和潜在限制。注意
checkpoint 的可重入变体 (
use_reentrant=True
) 和非可重入变体 (use_reentrant=False
) 在以下方面有所不同:非可重入 checkpoint 在所有需要的中间激活被重新计算后立即停止重新计算。此功能默认启用,但可以使用
set_checkpoint_early_stop()
禁用。可重入 checkpoint 在反向传播期间总是完整地重新计算function
。可重入变体在正向传播期间不记录 autograd 图,因为它在
torch.no_grad()
下运行。非可重入版本记录 autograd 图,允许在检查点区域内对图执行反向传播。可重入 checkpoint 仅支持不带 inputs 参数的
torch.autograd.backward()
API 进行反向传播,而非可重入版本支持所有进行反向传播的方式。可重入变体要求至少一个输入和输出具有
requires_grad=True
。如果未能满足此条件,模型的检查点部分将没有梯度。非可重入版本没有此要求。可重入版本不将嵌套结构(例如自定义对象、列表、字典等)中的 tensors 视为参与 autograd,而非可重入版本则视为参与。
可重入 checkpoint 不支持计算图中包含已分离 tensors 的检查点区域,而非可重入版本支持。对于可重入变体,如果检查点段包含使用
detach()
或torch.no_grad()
分离的 tensors,则反向传播将引发错误。这是因为checkpoint
使所有输出都需要梯度,当模型中某个 tensor 定义为无梯度时,这会导致问题。为避免此问题,请在checkpoint
函数外部分离 tensors。
- 参数
function – 描述在模型或模型一部分的正向传播中运行的内容。它还应该知道如何处理以元组形式传递的输入。例如,在 LSTM 中,如果用户传递
(activation, hidden)
,function
应正确地将第一个输入用作activation
,将第二个输入用作hidden
。preserve_rng_state (bool, 可选) – 在每个 checkpoint 期间省略保存和恢复 RNG 状态。请注意,在 torch.compile 下,此标志不生效,并且我们始终保留 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要可重入 autograd 的激活 checkpoint 变体。此参数应显式传递。在版本 2.5 中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,checkpoint
将使用不需要可重入 autograd 的实现。这使得checkpoint
能够支持附加功能,例如与torch.autograd.grad
一起正常工作,以及支持输入到检查点函数中的关键字参数。context_fn (Callable, 可选) – 一个可调用对象,返回包含两个上下文管理器的元组。该函数及其重新计算将分别在第一个和第二个上下文管理器下运行。此参数仅在
use_reentrant=False
时受支持。determinism_check (str, 可选) – 指定要执行的确定性检查的字符串。默认设置为
"default"
,它将重新计算的 tensors 的形状、dtype 和设备与保存的 tensors 进行比较。要关闭此检查,请指定"none"
。目前只支持这两个值。如果您希望看到更多确定性检查,请提出 Issue。此参数仅在use_reentrant=False
时受支持;如果use_reentrant=True
,则确定性检查始终禁用。debug (bool, 可选) – 如果为
True
,错误消息还将包含原始正向计算以及重新计算期间运行的操作符的跟踪信息。此参数仅在use_reentrant=False
时受支持。args – 包含
function
输入的元组
- 返回值
在
*args
上运行function
的输出
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[源代码][源代码]¶
对顺序模型进行 Checkpoint 以节省内存。
顺序模型按顺序(依次)执行模块/函数的列表。因此,我们可以将模型划分为多个段并对每个段进行 checkpoint。除最后一段外的所有段都不会存储中间激活。每个检查点段的输入将被保存,以便在反向传播中重新运行该段。
警告
use_reentrant
参数应显式传递。在版本 2.4 中,如果未传递use_reentrant
,我们将引发异常。如果您使用use_reentrant=True` 变体,请参阅 :func:`~torch.utils.checkpoint.checkpoint` 了解此变体的重要注意事项和限制。建议您使用 ``use_reentrant=False
。- 参数
functions – 一个
torch.nn.Sequential
或要顺序运行的模块或函数列表(构成模型)。segments – 在模型中创建的段数
input – 输入到
functions
的 Tensor。preserve_rng_state (bool, 可选) – 在每个 checkpoint 期间省略保存和恢复 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要可重入 autograd 的激活 checkpoint 变体。此参数应显式传递。在版本 2.5 中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,checkpoint
将使用不需要可重入 autograd 的实现。这使得checkpoint
能够支持附加功能,例如与torch.autograd.grad
一起正常工作,以及支持输入到检查点函数中的关键字参数。
- 返回值
顺序运行
functions
在*inputs
上的输出
示例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)
- torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[源代码][源代码]¶
上下文管理器,用于设置 checkpoint 运行时是否应打印额外的调试信息。有关更多信息,请参阅
checkpoint()
的debug
标志。请注意,设置后,此上下文管理器将覆盖传递给 checkpoint 的debug
值。要遵从本地设置,请将None
传递给此上下文。- 参数
enabled (bool) – checkpoint 是否应打印调试信息。默认值为 ‘None’。
- class torch.utils.checkpoint.CheckpointPolicy(value)[源代码][源代码]¶
用于指定反向传播期间 checkpointing 策略的枚举类。
支持以下策略
{MUST,PREFER}_SAVE
:操作的输出将在正向传播期间保存,并在反向传播期间不再重新计算。{MUST,PREFER}_RECOMPUTE
:操作的输出将在正向传播期间不保存,并在反向传播期间重新计算。
使用
MUST_*
而非PREFER_*
表示该策略不应被 torch.compile 等其他子系统覆盖。注意
始终返回
PREFER_RECOMPUTE
的策略函数等同于传统的 checkpointing。对于每个操作都返回
PREFER_SAVE
的策略函数与不使用 checkpointing 是不等价的。使用这样的策略会保存额外的 tensors,而不仅仅是梯度计算实际需要的 tensors。
- class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[源代码][源代码]¶
在选择性 checkpointing 期间传递给策略函数的上下文。
此类用于在选择性 checkpointing 期间将相关元数据传递给策略函数。元数据包括当前策略函数的调用是否在重新计算期间。
示例
>>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> print(ctx.is_recompute) >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )
- torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[源代码][源代码]¶
帮助避免在激活 checkpointing 期间重新计算某些操作。
将其与 torch.utils.checkpoint.checkpoint 一起使用,以控制哪些操作在反向传播期间被重新计算。
- 参数
policy_fn_or_list (Callable 或 List) –
如果提供了策略函数,它应该接受一个
SelectiveCheckpointContext
、OpOverload
、操作的 args 和 kwargs,并返回一个CheckpointPolicy
枚举值,指示是否应重新计算该操作的执行。如果提供了操作列表,则等同于一个策略,该策略对于指定的操作返回 CheckpointPolicy.MUST_SAVE,对于所有其他操作返回 CheckpointPolicy.PREFER_RECOMPUTE。
allow_cache_entry_mutation (bool, 可选) – 默认情况下,如果选择性激活 checkpointing 缓存的任何 tensors 被修改,则会引发错误,以确保正确性。如果设置为 True,则禁用此检查。
- 返回值
包含两个上下文管理器的元组。
示例
>>> import functools >>> >>> x = torch.rand(10, 10, requires_grad=True) >>> y = torch.rand(10, 10, requires_grad=True) >>> >>> ops_to_save = [ >>> torch.ops.aten.mm.default, >>> ] >>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> if op in ops_to_save: >>> return CheckpointPolicy.MUST_SAVE >>> else: >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> # or equivalently >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> >>> def fn(x, y): >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )