随着模型在深度、批次大小和序列长度等方面的扩展,激活内存对整体内存使用量的贡献越来越大。为了解决这个问题,PyTorch 提供了激活检查点实用程序,它通过在需要时重新计算张量来减少保存的张量数量,从而在内存使用和额外计算之间进行权衡。
在这篇文章中,我们将介绍激活内存的基础知识、现有激活检查点技术背后的高级思想,并介绍一些旨在提高灵活性和提供更多开箱即用优化/自动化的新颖技术。
当我们审视这些技术时,我们将比较这些方法如何适应速度与内存的权衡图,并希望能提供一些关于如何为您的用例选择正确策略的见解。
(如果您想直接了解新 API,请跳到下面的“选择性激活检查点”和“内存预算 API”部分。)

激活内存基础知识
默认情况下,在即时模式下(而不是使用 torch.compile
),PyTorch 的自动微分会保留中间激活以进行反向计算。例如,如果在前向传播期间对张量 x
调用 sin
,自动微分必须记住 x
才能在反向传播期间计算 cos(x)
。

如果这个张量 x
在前向传播开始时被保存,它会在前向和反向阶段都保留在内存中。它只能在用于计算梯度后才能被清除,这发生在反向传播结束时(由于执行顺序相反)。
因此,随着您在前向传播中进行越来越多的操作,您会积累越来越多的激活,从而导致越来越多的激活内存,直到它(通常)在反向传播开始时达到峰值(此时激活可以开始被清除)。

上图中,橙色框表示操作,黑色箭头表示它们的张量输入和输出。向右穿过的黑色箭头表示自动微分为反向传播保存的张量。
一种有用的可视化组织即时模式下的默认保存行为以及我们将要介绍的技术的方法是根据它们如何在速度与内存之间进行权衡。

在这个图中,理想的位置是左上角,您拥有“高”速度和低内存使用量。
我们首先将默认保存行为放在右上角(原因我们将在介绍其他技术的更多点时更详细地解释)。
激活检查点 (AC)
激活检查点 (AC) 是 PyTorch 中一种流行的减少内存使用量的技术。
在前向传播期间,在 AC 区域内执行的任何操作都不会保存张量以进行反向传播。(只保存函数的输入。)在反向传播期间,通过第二次运行函数来重新实例化梯度计算所需的中间激活。

在图(右)中,黑框显示了应用激活检查点的位置。与默认的即时方法(左)相比,此设置导致保存的张量更少(1 对 3)。
在模型的正确部分应用 AC 可以降低峰值内存,因为当内存使用量通常达到峰值时(在反向传播开始时),中间激活不再实例化在内存中。
在速度与内存权衡图上,AC 绘制在左下角。相对于即时模式,它减少了为反向传播保存的内存量,但由于重新计算而增加了计算成本。

请注意,AC 的速度-内存权衡可以根据选择前向传播的哪些部分进行检查点以及定义要使用的检查点区域的数量进行调整。但是,实施这些更改可能需要修改模型的结构,并且根据代码的组织方式可能会很麻烦。为了本图的目的,我们假设只检查点一个区域;在此假设下,AC 在权衡图上显示为一个点。
另请注意,“内存”在此处不是指峰值内存使用量;相反,它表示为固定区域保存了多少内存以进行反向传播。
torch.compile 和最小割分区器
另一个需要注意的著名方法是 torch.compile(PyTorch 2.0 中引入)。与激活检查点类似,torch.compile
也可以在底层执行一定程度的重新计算。具体来说,它将前向和反向计算跟踪到一个单一的联合图,然后由一个“最小割”分区器进行处理。此分区器使用最小割/最大流算法来分割图,从而最大限度地减少需要为反向传播保存的张量数量。
乍一看,这听起来很像我们希望实现激活内存减少的目标。然而,现实情况更为微妙。默认情况下,分区器的主要目标是减少运行时。因此,它只重新计算某些类型的操作——主要是更简单、可融合且非计算密集型的操作(如逐点操作)。
将“compile”放在速度与内存权衡图上……

它位于非 AC 即时点的左上方,因为我们期望 torch.compile
能够提高速度和内存。
另一方面,相对于激活检查点,torch.compile 在重新计算方面更为保守,将其放置在速度与内存图上更靠近左上方的位置。
选择性激活检查点 [新!]
虽然普通检查点会重新计算所选区域中的每个操作,但选择性激活检查点 (SAC) 是在激活检查点之上的一种额外设置,您可以应用它来更精细地控制要重新计算的操作。
如果您有一些更昂贵的操作(如矩阵乘法)您更喜欢避免重新计算,但仍然普遍希望重新计算更便宜的操作(如逐点操作),这会很有用。

普通 AC(左)会保存单个张量,然后重新计算整个 AC 区域,而 SAC(右)允许您有选择地保存区域中的特定操作(红色标记),这样您就可以避免重新计算它们。
要指定要选择性保存的内容,您可以指定一个 policy_fn。为了说明您可以进行的其他权衡,我们提供了两个简单的策略函数。
策略 1:不重新计算矩阵乘法
aten = torch.ops.aten
compute_intensive_ops = [
aten.mm,
aten.bmm,
aten.addmm,
]
def policy_fn(ctx, op, *args, **kwargs):
if op in compute_intensive_ops:
return CheckpointPolicy.MUST_SAVE
else:
return CheckpointPolicy.PREFER_RECOMPUTE

策略 2:更积极地保存任何计算密集型操作
# torch/_functorch/partitioners.py
aten = torch.ops.aten
compute_intensive_ops = [
aten.mm,
aten.convolution,
aten.convolution_backward,
aten.bmm,
aten.addmm,
aten._scaled_dot_product_flash_attention,
aten._scaled_dot_product_efficient_attention,
aten._flash_attention_forward,
aten._efficient_attention_forward,
aten.upsample_bilinear2d,
aten._scaled_mm
]
def policy_fn(ctx, op, *args, **kwargs):
if op in compute_intensive_ops:
return CheckpointPolicy.MUST_SAVE
else:
return CheckpointPolicy.PREFER_RECOMPUTE

在速度与内存图上,SAC 根据您选择的策略,绘制为从接近 AC 到接近 Eager 的一系列点。

试试看! (在 2.5 版本中作为原型功能提供;请参阅文档了解更多信息 + 可复制粘贴的示例)
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
# Create a policy function that returns a CheckpointPolicy
def policy_fn(ctx, op, *args, **kwargs):
if op in ops_to_save:
return CheckpointPolicy.MUST_SAVE
else:
return CheckpointPolicy.PREFER_RECOMPUTE
# Use the context_fn= arg of the existing checkpoint API
out = checkpoint(
fn, *args,
use_reentrant=False,
# Fill in SAC context_fn's policy_fn with functools.partial
context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
(仅编译)内存预算 API [新!]
如前所述,任何给定的 SAC 策略都可以表示为速度-内存权衡图上的一个点。然而,并非所有策略都是平等的。“最优”策略是那些落在帕累托曲线上的策略,例如,对于所有产生相同内存开销的策略,此策略是使所需计算量最小的策略。
对于使用 torch.compile 的用户,我们提供了内存预算 API,它会根据用户指定的 0 到 1 之间的“内存预算”,在您编译的区域上自动应用帕累托最优策略的 SAC,其中预算为 0 时表现为普通 AC,预算为 1 时表现为默认 torch.compile。

以下是变压器模型的一些真实结果

我们观察到仅通过重新计算逐点操作即可减少 50% 的内存,并且随着您重新计算越来越多的矩阵乘法,内存会稳步下降。注意力是最昂贵的,所以您倾向于最后重新计算它们。
试试看! (在 2.4 版本中作为实验功能提供;请参阅此注释块了解更多信息)
torch._dynamo.config.activation_memory_budget = 0.5
out = torch.compile(fn)(inp)
结论

总而言之,PyTorch 中的激活检查点技术提供了多种平衡内存和计算需求的方法,从简单的基于区域的检查点到更具选择性和自动化的方法。通过选择最适合您的模型结构和资源限制的选项,您可以在计算方面进行可接受的权衡,从而显着节省内存。
致谢
我们要感谢 Meta 的 xformers 团队,包括 Francisco Massa,感谢他们对选择性激活检查点原始版本所做的工作。