torch.utils.checkpoint¶
注意
检查点通过在反向传播期间为每个检查点细分重新运行前向传递细分来实现。这可能会导致持久状态(如 RNG 状态)比没有检查点时更加高级。默认情况下,检查点包含逻辑来处理 RNG 状态,以便使用 RNG(例如通过 dropout)的检查点传递与非检查点传递相比具有确定性输出。根据检查点操作的运行时,存储和恢复 RNG 状态的逻辑可能会产生适度的性能影响。如果不需要与非检查点传递相比的确定性输出,请提供 preserve_rng_state=False
给 checkpoint
或 checkpoint_sequential
以在每次检查点期间省略存储和恢复 RNG 状态。
隐藏逻辑保存和恢复 CPU 和其他设备类型的 RNG 状态(通过 _infer_device_type
从排除 CPU 张量的张量参数中推断设备类型)到 run_fn
。如果有多个设备,则仅保存单个设备类型的设备状态,而忽略剩余设备。因此,如果任何检查点函数涉及随机性,这可能会导致不正确的梯度。(请注意,如果检测到的设备中包含 CUDA 设备,则会优先考虑;否则,将选择遇到的第一个设备。)如果没有 CPU 张量,则将保存和恢复默认设备类型状态(默认值为 cuda,可以通过 DefaultDeviceType
设置为其他设备)。但是,该逻辑无法预测用户是否会在 run_fn
本身中将张量移动到新设备。因此,如果您在 run_fn
中将张量移动到新设备(“新”表示不属于 [当前设备 + 张量参数的设备] 集合),则永远无法保证与非检查点传递相比确定性输出。
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)[source]¶
检查点模型或模型的一部分。
激活检查点是一种用计算换取内存的技术。在检查点区域中,前向计算省略了为反向保存反向所需的张量,并在反向传递期间重新计算它们,而不是让它们保持活动状态,直到在反向期间用于梯度计算。激活检查点可以应用于模型的任何部分。
目前有两种可用的检查点实现,由
use_reentrant
参数决定。建议您使用use_reentrant=False
。请参阅下面的注释以了解它们的差异。警告
如果反向传递期间的
function
调用与前向传递不同,例如,由于全局变量,则检查点版本可能不等效,可能导致引发错误或导致错误的梯度。警告
应明确传递
use_reentrant
参数。在版本 2.4 中,如果未传递use_reentrant
,我们将引发异常。如果您使用use_reentrant=True
变体,请参阅下面的注释以了解重要的注意事项和潜在的限制。注意
检查点的可重入变体 (
use_reentrant=True
) 和检查点的不可重入变体 (use_reentrant=False
) 在以下方面有所不同非可重入检查点在所有所需的中间激活重新计算后立即停止重新计算。此功能默认启用,但可以使用
set_checkpoint_early_stop()
禁用。可重入检查点在反向传递期间始终重新计算function
的全部内容。可重入变体在正向传递期间不记录自动微分图,因为它在
torch.no_grad()
下运行正向传递。非可重入版本确实记录了自动微分图,允许在检查点区域内对图执行反向传播。可重入检查点仅支持
torch.autograd.backward()
API 进行反向传递,不带 inputs 参数,而非可重入版本支持所有执行反向传递的方法。对于可重入变体,至少一个输入和输出必须具有
requires_grad=True
。如果不满足此条件,则模型的检查点部分将没有梯度。非可重入版本没有此要求。可重入版本不将嵌套结构中的张量(例如,自定义对象、列表、字典等)视为参与自动微分,而非可重入版本则参与。
可重入检查点不支持来自计算图的分离张量的检查点区域,而非可重入版本则支持。对于可重入变体,如果检查点段包含使用
detach()
或torch.no_grad()
分离的张量,则反向传递将引发错误。这是因为checkpoint
使所有输出都需要梯度,当模型中定义张量没有梯度时,这会导致问题。为避免这种情况,请在checkpoint
函数外部分离张量。
- 参数
function – 描述在模型或模型部分的正向传递中运行的内容。它还应该知道如何处理作为元组传递的输入。例如,在 LSTM 中,如果用户传递
(activation, hidden)
,则function
应正确地将第一个输入用作activation
,将第二个输入用作hidden
preserve_rng_state (bool,可选) – 在每个检查点期间省略隐藏和恢复 RNG 状态。请注意,在 torch.compile 下,此标志不起作用,我们始终保留 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要可重入自动梯度的激活检查点变体。此参数应明确传递。在版本 2.4 中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,checkpoint
将使用不需要可重入自动梯度的实现。这允许checkpoint
支持附加功能,例如按预期与torch.autograd.grad
配合使用,以及支持输入到检查点函数中的关键字参数。context_fn (Callable,可选) – 返回两个上下文管理器的元组的可调用对象。函数及其重新计算将分别在第一个和第二个上下文管理器下运行。仅当
use_reentrant=False
时才支持此参数。determinism_check (str,可选) – 指定要执行的确定性检查的字符串。默认情况下,它设置为
"default"
,它将重新计算的张量的形状、数据类型和设备与已保存张量的形状、数据类型和设备进行比较。要关闭此检查,请指定"none"
。目前,这些是仅支持的两个值。如果您希望看到更多确定性检查,请打开一个问题。仅当use_reentrant=False
时才支持此参数,如果use_reentrant=True
,则始终禁用确定性检查。debug (bool, 可选) – 如果
True
,错误消息还将包括在原始前向计算和重新计算期间运行的运算符的跟踪。仅当use_reentrant=False
时才支持此参数。args – 包含
function
输入的元组
- 返回
在
*args
上运行function
的输出
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]¶
检查点顺序模型以节省内存。
顺序模型按顺序(顺序地)执行模块/函数的列表。因此,我们可以将这样的模型划分为不同的段并对每个段进行检查点。除了最后一个段之外的所有段都将不存储中间激活。每个检查点段的输入将被保存,以便在反向传递中重新运行该段。
警告
应明确传递
use_reentrant
参数。在版本 2.4 中,如果未传递use_reentrant
,我们将引发异常。如果您使用use_reentrant=True` variant, please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is recommended that you use ``use_reentrant=False
。- 参数
functions – 要顺序运行的
torch.nn.Sequential
或模块或函数列表(组成模型)。segments – 在模型中创建的块数
input – 输入到
functions
的张量preserve_rng_state (bool, 可选) – 在每次检查点期间省略隐藏和恢复 RNG 状态。默认值:
True
use_reentrant (bool) – 指定是否使用需要可重入自动梯度的激活检查点变体。此参数应明确传递。在版本 2.4 中,如果未传递
use_reentrant
,我们将引发异常。如果use_reentrant=False
,checkpoint
将使用不需要可重入自动梯度的实现。这允许checkpoint
支持附加功能,例如按预期与torch.autograd.grad
配合使用,以及支持输入到检查点函数中的关键字参数。
- 返回
按顺序在
*inputs
上运行functions
的输出
示例
>>> model = nn.Sequential(...) >>> input_var = checkpoint_sequential(model, chunks, input_var)
- torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]¶
上下文管理器,用于设置检查点在运行时是否应打印其他调试信息。有关更多信息,请参阅
checkpoint()
的debug
标志。请注意,设置后,此上下文管理器将覆盖传递给检查点的debug
的值。要推迟到本地设置,请将None
传递给此上下文。- 参数
enabled (bool) – 检查点是否应打印调试信息。默认值为“None”。