torch.utils.checkpoint¶
注意
检查点是通过在反向传播期间为每个检查点段重新运行前向传递段来实现的。这可能导致像 RNG 状态这样的持久状态比没有检查点时更先进。默认情况下,检查点包含逻辑来处理 RNG 状态,以便与非检查点段相比,使用 RNG(例如通过 dropout)的检查点段具有确定性的输出。存储和恢复 RNG 状态的逻辑可能会导致适度的性能下降,具体取决于检查点操作的运行时间。如果不需要与非检查点段相比具有确定性的输出,则向 checkpoint
或 checkpoint_sequential
提供 preserve_rng_state=False
以省略在每个检查点期间存储和恢复 RNG 状态。
存储逻辑将 CPU 和另一种设备类型(通过 _infer_device_type
从 Tensor 参数中排除 CPU 张量来推断设备类型)的 RNG 状态保存并恢复到 run_fn
。如果有多个设备,则仅保存单个设备类型的设备状态,其余设备将被忽略。因此,如果任何检查点函数涉及随机性,这可能会导致不正确的梯度。(请注意,如果 CUDA 设备位于检测到的设备中,则它将被优先考虑;否则,将选择遇到的第一个设备。)如果没有 CPU 张量,则将保存和恢复默认设备类型状态(默认值为 cuda,可以通过 DefaultDeviceType
设置为其他设备)。但是,该逻辑无法预测用户是否会在 run_fn
本身中将张量移动到新设备。“新”表示不属于 [当前设备 + 张量参数的设备] 集合。因此,如果您在 run_fn
中将张量移动到新设备,则与非检查点段相比,永远无法保证确定性的输出。
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source]¶
检查点模型或模型的一部分。
激活检查点是一种以计算换取内存的技术。它不是将反向传播所需的张量一直保留到在反向传播期间用于梯度计算时,而是检查点区域中的前向计算会省略保存用于反向传播的张量,并在反向传播期间重新计算它们。激活检查点可以应用于模型的任何部分。
目前有两种可用的检查点实现,由
use_reentrant
参数确定。建议您使用use_reentrant=False
。请参阅下面的说明,了解它们之间差异的讨论。警告
如果在反向传播期间对
function
的调用与前向传播不同,例如由于全局变量,则检查点版本可能不相同,这可能导致引发错误或导致梯度错误且没有提示。警告
应显式传递
use_reentrant
参数。在 2.4 版中,如果未传递use_reentrant
,我们将引发异常。如果您使用的是use_reentrant=True
变体,请参阅下面的说明,了解重要注意事项和潜在限制。注意
检查点的可重入变体(
use_reentrant=True
)和检查点的非可重入变体(use_reentrant=False
)在以下方面有所不同非可重入检查点在所有需要的中间激活都重新计算后立即停止重新计算。此功能默认启用,但可以使用
set_checkpoint_early_stop()
禁用。可重入检查点始终在反向传播期间完整地重新计算function
。可重入变体在执行前向传递时不记录 autograd 图,因为它在
torch.no_grad()
下运行前向传递。非可重入版本确实记录了 autograd 图,允许用户对检查点区域内的图执行反向传播。可重入检查点仅支持用于反向传播的
torch.autograd.backward()
API,且不带其inputs参数,而非可重入版本支持所有执行反向传播的方式。对于可重入变体,至少一个输入和输出必须具有
requires_grad=True
。如果此条件不满足,则模型的检查点部分将没有梯度。非可重入版本没有此要求。可重入版本不认为嵌套结构(例如,自定义对象、列表、字典等)中的张量参与自动梯度计算,而非可重入版本则认为参与。
可重入检查点不支持包含从计算图中分离的张量的检查点区域,而非可重入版本则支持。对于可重入变体,如果检查点段包含使用
detach()
或torch.no_grad()
分离的张量,则反向传播将引发错误。这是因为checkpoint
使所有输出都要求梯度,并且当张量在模型中定义为没有梯度时,这会导致问题。为避免这种情况,请在checkpoint
函数外部分离张量。
- 参数
function – 描述在模型或模型一部分的前向传播中要运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在 LSTM 中,如果用户传递
(activation, hidden)
,则function
应正确地将第一个输入用作activation
,将第二个输入用作hidden
preserve_rng_state (bool, optional) – 在每次检查点期间省略存储和恢复 RNG 状态。请注意,在 torch.compile 下,此标志无效,我们始终保留 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要可重入自动梯度的激活检查点变体。应显式传递此参数。在 2.5 版中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,则checkpoint
将使用不需要可重入自动梯度的实现。这允许checkpoint
支持其他功能,例如与torch.autograd.grad
按预期工作以及对输入到检查点函数的关键字参数的支持。context_fn (Callable, optional) – 一个可调用对象,返回两个上下文管理器的元组。函数及其重新计算将分别在第一个和第二个上下文管理器下运行。仅当
use_reentrant=False
时才支持此参数。determinism_check (str, optional) – 指定要执行的确定性检查的字符串。默认情况下,它设置为
"default"
,它将重新计算的张量的形状、数据类型和设备与保存的张量进行比较。要关闭此检查,请指定"none"
。目前,这两个是唯一支持的值。如果您想查看更多确定性检查,请提交问题。仅当use_reentrant=False
时才支持此参数,如果use_reentrant=True
,则始终禁用确定性检查。debug (bool, optional) – 如果为
True
,则错误消息还将包含原始前向计算以及重新计算期间运行的操作的跟踪。仅当use_reentrant=False
时才支持此参数。args – 包含
function
输入的元组
- 返回值
在
function
上运行*args
的输出
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]¶
检查点一个顺序模型以节省内存。
顺序模型按顺序(依次)执行模块/函数列表。因此,我们可以将此类模型划分为各个段,并对每个段进行检查点。除了最后一个段之外,所有段都不会存储中间激活。将保存每个检查点段的输入,以便在反向传播中重新运行该段。
警告
应显式传递
use_reentrant
参数。在 2.4 版中,如果未传递use_reentrant
,我们将引发异常。如果您正在使用use_reentrant=True` 变体,请参阅 :func:`~torch.utils.checkpoint.checkpoint` 以了解此变体的 重要考虑事项和限制。 建议您使用 ``use_reentrant=False``。
- 参数
functions – 一个
torch.nn.Sequential
或模块或函数列表(构成模型)以顺序运行。segments – 在模型中创建的块数
input – 输入到
functions
的张量preserve_rng_state (bool, optional) – 在每次检查点期间省略存储和恢复 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要可重入自动梯度的激活检查点变体。应显式传递此参数。在 2.5 版中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,则checkpoint
将使用不需要可重入自动梯度的实现。这允许checkpoint
支持其他功能,例如与torch.autograd.grad
按预期工作以及对输入到检查点函数的关键字参数的支持。
- 返回值
在
functions
上依次运行*inputs
的输出
示例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)
- torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]¶
上下文管理器,用于设置检查点在运行时是否应打印其他调试信息。有关更多信息,请参阅
checkpoint()
的debug
标志。请注意,设置后,此上下文管理器会覆盖传递给检查点的debug
的值。要推迟到本地设置,请将None
传递给此上下文。- 参数
enabled (bool) – 检查点是否应打印调试信息。默认为“None”。
- class torch.utils.checkpoint.CheckpointPolicy(value)[source]¶
用于指定反向传播期间检查点策略的枚举。
支持以下策略
{MUST,PREFER}_SAVE
:操作的输出将在前向传递期间保存,并且不会在反向传递期间重新计算{MUST,PREFER}_RECOMPUTE
:操作的输出不会在前向传递期间保存,并且将在反向传递期间重新计算
使用
MUST_*
而不是PREFER_*
来指示策略不应被其他子系统(如torch.compile)覆盖。注意
始终返回
PREFER_RECOMPUTE
的策略函数等效于普通检查点。每个操作都返回
PREFER_SAVE
的策略函数不等效于不使用检查点。使用此类策略将保存其他张量,而不仅仅是实际需要用于梯度计算的张量。
- class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source]¶
在选择性检查点期间传递给策略函数的上下文。
此类用于在选择性检查点期间将相关元数据传递给策略函数。元数据包括策略函数的当前调用是在重新计算期间还是否。
示例
>>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> print(ctx.is_recompute) >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )
- torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source]¶
帮助避免在激活检查点期间重新计算某些操作。
将此与torch.utils.checkpoint.checkpoint一起使用以控制在反向传播期间哪些操作被重新计算。
- 参数
policy_fn_or_list (Callable 或 List) –
如果提供策略函数,它应该接受一个
SelectiveCheckpointContext
、OpOverload
、操作的参数和关键字参数,并返回一个CheckpointPolicy
枚举值,指示操作的执行是否应该被重新计算。如果提供操作列表,则等效于对指定操作返回CheckpointPolicy.MUST_SAVE,对所有其他操作返回CheckpointPolicy.PREFER_RECOMPUTE。
allow_cache_entry_mutation (bool, optional) – 默认情况下,如果选择性激活检查点缓存的任何张量被修改,则会引发错误以确保正确性。如果设置为True,则禁用此检查。
- 返回值
包含两个上下文管理器的元组。
示例
>>> import functools >>> >>> x = torch.rand(10, 10, requires_grad=True) >>> y = torch.rand(10, 10, requires_grad=True) >>> >>> ops_to_save = [ >>> torch.ops.aten.mm.default, >>> ] >>> >>> def policy_fn(ctx, op, *args, **kwargs): >>> if op in ops_to_save: >>> return CheckpointPolicy.MUST_SAVE >>> else: >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) >>> >>> # or equivalently >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> >>> def fn(x, y): >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> >>> out = torch.utils.checkpoint.checkpoint( >>> fn, x, y, >>> use_reentrant=False, >>> context_fn=context_fn, >>> )