torch.utils.checkpoint¶
注意
检查点是通过在反向传播期间为每个检查点分段重新运行前向传播分段来实现的。这可能会导致持久状态(如 RNG 状态)比不使用检查点时更超前。默认情况下,检查点包含用于调整 RNG 状态的逻辑,以便与非检查点传递相比,使用 RNG(例如通过 dropout)的检查点传递具有确定性输出。与非检查点传递相比,存储和恢复 RNG 状态的逻辑可能会导致适度的性能损失,具体取决于检查点操作的运行时。如果不需要与非检查点传递相比的确定性输出,请提供 preserve_rng_state=False
给 checkpoint
或 checkpoint_sequential
以省略在每次检查点期间存储和恢复 RNG 状态。
存储逻辑保存并恢复 CPU 和另一种设备类型的 RNG 状态(通过 _infer_device_type
从排除 CPU 张量的 Tensor 参数推断设备类型)到 run_fn
。如果存在多个设备,则设备状态将仅为单一设备类型的设备保存,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,则可能导致不正确的梯度。(请注意,如果 CUDA 设备在检测到的设备中,它将被优先考虑;否则,将选择遇到的第一个设备。)如果没有 CPU 张量,则将保存和恢复默认设备类型状态(默认值为 cuda,可以通过 DefaultDeviceType
设置为其他设备)。但是,该逻辑无法预测用户是否会在 run_fn
本身中将张量移动到新设备。因此,如果您在 run_fn
中将张量移动到新设备(“新”指的是不属于 [当前设备 + Tensor 参数设备] 集合的设备),则永远无法保证与非检查点传递相比的确定性输出。
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source][source]¶
检查点模型或模型的一部分。
激活检查点是一种用计算换内存的技术。检查点区域中的前向计算省略保存用于反向传播的张量,而是在反向传播过程中重新计算它们,而不是保留反向传播所需的张量,直到在反向传播期间的梯度计算中使用它们。激活检查点可以应用于模型的任何部分。
目前有两种检查点实现可用,由
use_reentrant
参数确定。建议您使用use_reentrant=False
。请参阅下面的注释,了解它们差异的讨论。警告
如果反向传播期间的
function
调用与前向传播不同,例如,由于全局变量,则检查点版本可能不等效,可能会导致引发错误或导致静默地不正确的梯度。警告
应显式传递
use_reentrant
参数。在 2.4 版本中,如果未传递use_reentrant
,我们将引发异常。如果您正在使用use_reentrant=True
变体,请参阅下面的注释,了解重要的注意事项和潜在的局限性。注意
重入检查点变体 (
use_reentrant=True
) 和非重入检查点变体 (use_reentrant=False
) 在以下方面有所不同非重入检查点在所有需要的中间激活都被重新计算后立即停止重新计算。此功能默认启用,但可以使用
set_checkpoint_early_stop()
禁用。重入检查点始终在反向传播期间完整地重新计算function
。重入变体在前向传播期间不记录 autograd 图,因为它在前向传播时在
torch.no_grad()
下运行。非重入版本确实记录 autograd 图,允许在检查点区域内对图执行反向传播。重入检查点仅支持
torch.autograd.backward()
API 用于不带其 inputs 参数的反向传播,而非重入版本支持执行反向传播的所有方式。对于重入变体,至少一个输入和输出必须具有
requires_grad=True
。如果不满足此条件,则模型的检查点部分将没有梯度。非重入版本没有此要求。重入版本不考虑嵌套结构(例如,自定义对象、列表、字典等)中的张量参与 autograd,而非重入版本会考虑。
重入检查点不支持具有从计算图中分离的张量的检查点区域,而非重入版本支持。对于重入变体,如果检查点分段包含使用
detach()
或torch.no_grad()
分离的张量,则反向传播将引发错误。这是因为checkpoint
使所有输出都需要梯度,当张量被定义为在模型中没有梯度时,这会导致问题。为了避免这种情况,请在checkpoint
函数外部分离张量。
- 参数
function – 描述模型或模型的一部分在前向传播中运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在 LSTM 中,如果用户传递
(activation, hidden)
,function
应该正确地将第一个输入用作activation
,将第二个输入用作hidden
preserve_rng_state (bool, 可选) – 省略在每次检查点期间存储和恢复 RNG 状态。请注意,在 torch.compile 下,此标志不起作用,我们始终保留 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要重入 autograd 的激活检查点变体。应显式传递此参数。在 2.5 版本中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,checkpoint
将使用不需要重入 autograd 的实现。这允许checkpoint
支持其他功能,例如与torch.autograd.grad
一起按预期工作,并支持关键字参数输入到检查点函数中。context_fn (Callable, 可选) – 一个可调用对象,返回两个上下文管理器的元组。函数及其重新计算将在第一个和第二个上下文管理器下分别运行。仅当
use_reentrant=False
时才支持此参数。determinism_check (str, 可选) – 指定要执行的确定性检查的字符串。默认情况下,它设置为
"default"
,它将重新计算的张量的形状、dtypes 和设备与保存的张量进行比较。要关闭此检查,请指定"none"
。目前,这些是仅有的两个受支持的值。如果您希望看到更多确定性检查,请打开一个 issue。仅当use_reentrant=False
时才支持此参数,如果use_reentrant=True
,则始终禁用确定性检查。debug (bool, 可选) – 如果
True
,则错误消息还将包括原始前向计算以及重新计算期间运行的算子的跟踪。仅当use_reentrant=False
时才支持此参数。args – 包含
function
输入的元组
- 返回
在
*args
上运行function
的输出
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source][source]¶
检查点序列模型以节省内存。
序列模型按顺序(顺序地)执行模块/函数的列表。因此,我们可以将这样的模型划分为不同的分段,并检查点每个分段。除最后一个分段外的所有分段都不会存储中间激活。每个检查点分段的输入将被保存,以便在反向传播中重新运行该分段。
警告
应显式传递
use_reentrant
参数。在 2.4 版本中,如果未传递use_reentrant
,我们将引发异常。如果您正在使用use_reentrant=True` variant, please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False
。- 参数
functions – 要顺序运行的
torch.nn.Sequential
或模块或函数的列表(构成模型)。segments – 模型中要创建的块数
input – 输入到
functions
的张量preserve_rng_state (bool, 可选) – 省略在每次检查点期间存储和恢复 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要重入 autograd 的激活检查点变体。应显式传递此参数。在 2.5 版本中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,checkpoint
将使用不需要重入 autograd 的实现。这允许checkpoint
支持其他功能,例如与torch.autograd.grad
一起按预期工作,并支持关键字参数输入到检查点函数中。
- 返回
在
*inputs
上顺序运行functions
的输出
示例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)
- torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source][source]¶
上下文管理器,用于设置检查点在运行时是否应打印额外的调试信息。有关更多信息,请参阅
checkpoint()
的debug
标志。请注意,设置后,此上下文管理器将覆盖传递给检查点的debug
值。要推迟到本地设置,请将None
传递给此上下文。- 参数
enabled (bool) – 检查点是否应打印调试信息。默认值为 ‘None’。
- class torch.utils.checkpoint.CheckpointPolicy(value)[source][source]¶
用于指定反向传播期间检查点策略的枚举。
支持以下策略
{MUST,PREFER}_SAVE
: 操作的输出将在前向传播期间保存,并且不会在反向传播期间重新计算{MUST,PREFER}_RECOMPUTE
: 操作的输出将不会在前向传播期间保存,并且将在反向传播期间重新计算
使用
MUST_*
而不是PREFER_*
来指示策略不应被其他子系统(如 torch.compile)覆盖。注意
始终返回
PREFER_RECOMPUTE
的策略函数等同于 vanilla 检查点。返回
PREFER_SAVE
的策略函数对于每个操作都与不使用检查点不等效。使用此类策略将保存额外的张量,而不仅仅限于实际梯度计算所需的张量。
- class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source][source]¶
在选择性检查点期间传递给策略函数的上下文。
此类用于在选择性检查点期间将相关元数据传递给策略函数。元数据包括策略函数的当前调用是否在重新计算期间。
示例
>>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> print(ctx.is_recompute) >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )
- torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source][source]¶
帮助避免在激活检查点期间重新计算某些操作。
将此与 torch.utils.checkpoint.checkpoint 一起使用,以控制在反向传播期间重新计算哪些操作。
- 参数
policy_fn_or_list (Callable or List) –
如果提供策略函数,它应接受
SelectiveCheckpointContext
、OpOverload
、操作的 args 和 kwargs,并返回一个CheckpointPolicy
枚举值,指示是否应重新计算操作的执行。如果提供操作列表,则等同于策略为指定操作返回 CheckpointPolicy.MUST_SAVE,为所有其他操作返回 CheckpointPolicy.PREFER_RECOMPUTE。
allow_cache_entry_mutation (bool, 可选) – 默认情况下,如果选择性激活检查点缓存的任何张量被修改,则会引发错误,以确保正确性。如果设置为 True,则禁用此检查。
- 返回
两个上下文管理器的元组。
示例
>>> import functools >>> >>> x = torch.rand(10, 10, requires_grad=True) >>> y = torch.rand(10, 10, requires_grad=True) >>> >>> ops_to_save = [ >>> torch.ops.aten.mm.default, >>> ] >>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> if op in ops_to_save: >>> return CheckpointPolicy.MUST_SAVE >>> else: >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> # or equivalently >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> >>> def fn(x, y): >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )