作者:PyTorch 团队

随着模型在深度、批量大小、序列长度等方面的扩展,激活内存对整体内存使用量的贡献越来越大。为了解决这个问题,PyTorch 提供了激活检查点(activation checkpointing)工具,通过在需要时重新计算来减少保存张量的数量,从而用额外的计算换取内存使用量的减少。

在这篇文章中,我们将介绍激活内存的基本概念,现有激活检查点技术背后的高级思想,并引入一些旨在提高灵活性并提供更多开箱即用优化/自动化功能的新技术。

在研究这些技术时,我们将比较这些方法如何融入速度与内存权衡图,并希望能为如何根据您的用例选择合适的策略提供一些见解。

(如果您更喜欢直接跳到新的 API,请跳到下面的“选择性激活检查点”和“内存预算 API”部分。)

flow diagram


激活内存基础

默认情况下,在即时模式(而非使用 torch.compile)下,PyTorch 的 autograd 会保留中间激活用于反向计算。例如,如果您在前向传播期间对张量 x 调用 sin,autograd 必须记住 x 以便在反向传播期间计算 cos(x)

flow diagram

如果这个张量 x 在前向传播开始时被保存,它将贯穿前向和反向两个阶段始终保存在内存中。它只有在用于计算梯度后才能被清除,这发生在反向传播结束时(由于执行顺序相反)。

因此,当前向传播进行并执行越来越多的操作时,您会积累越来越多的激活,导致越来越多的激活内存,直到(通常)在反向传播开始时达到峰值(此时激活可以开始被清除)。

flow diagram

上图中的橙色框代表操作,黑色箭头代表它们的张量输入和输出。右侧的黑色箭头代表 autograd 为反向传播保存的张量。

一种有用的方式来可视化组织即时模式下的这种默认保存行为以及我们将要介绍的技术,是基于它们如何在速度和内存之间进行权衡。

flow diagram

理想的位置是左上角,在那里您具有“高”速度但内存使用量也很低。

我们首先将默认保存行为放在右上角(原因将在介绍其他技术的更多点时详细解释)。


激活检查点 (AC)

激活检查点 (AC) 是 PyTorch 中一种流行的减少内存使用量的技术。

在前向传播期间,在 AC 区域内执行的任何操作都不会为反向传播保存张量。(仅保存函数的输入。)在反向传播期间,通过第二次运行函数来重新生成梯度计算所需的中间激活。

flow diagram

在图(右)中,黑框显示应用激活检查点的位置。与默认的即时模式方法(左)相比,这种设置导致保存的张量更少(1 个对比 3 个)。

对模型的正确部分应用 AC 可以降低峰值内存,因为在内存使用量通常达到峰值时(反向传播开始时),中间激活不再在内存中具象化。

在速度与内存权衡图上,AC 绘制在左下角。相对于即时模式,它减少了为反向传播保存的内存量,但由于重新计算而增加了计算成本。

flow diagram

请注意,AC 的速度-内存权衡 /可以/ 通过选择要检查点的前向传播部分和定义使用多少个检查点区域来调整。然而,实现这些更改可能需要修改模型的结构,并且根据您的代码组织方式可能会很麻烦。为了本图的目的,我们假设只检查点一个区域;在此假设下,AC 在权衡图上显示为一个点。

另请注意,“内存”在此处不指峰值内存使用量;它表示为固定区域保存了多少内存用于反向传播。


torch.compile 和最小割分区器

另一个需要注意的重要方法是 torch.compile(在 PyTorch 2.0 中引入)。与激活检查点类似,torch.compile 也可以在底层执行一定程度的重新计算。具体来说,它将前向和反向计算跟踪到一个联合图中,然后由“最小割”分区器进行处理。该分区器使用最小割/最大流算法来分割图,从而最大限度地减少需要为反向传播保存的张量数量。

乍一看,这可能听起来很像我们想要的激活内存减少。然而,现实情况更为微妙。默认情况下,分区器的主要目标是减少运行时。因此,它只重新计算某些类型的操作——主要是更简单、可融合且计算密集度不高的操作(如逐点操作)。

将“compile”放在速度与内存权衡图上...

flow diagram

它位于非 AC 即时模式点的左上方,因为我们期望 torch.compile 在速度和内存方面都有改进。

另一方面,相对于激活检查点,torch.compile 在重新计算方面更为保守,使其在速度与内存图上更靠近左上方。


选择性激活检查点 [新增!]

通常的检查点会重新计算选定区域内的所有操作,而选择性激活检查点 (SAC) 是在激活检查点之上应用的一项额外设置,您可以对其进行更细粒度的控制,以决定重新计算哪些操作。

如果您有一些更昂贵的操作(如矩阵乘法),您希望避免重新计算,但通常仍希望重新计算较便宜的操作(如逐点操作),则这会很有用。

flow diagram

普通的 AC(左)会保存一个张量然后重新计算整个 AC 区域,而使用 SAC(右),您可以选择性地保存区域中的特定操作(标记为红色),这样您就可以避免重新计算它们。

要指定要选择性保存的内容,您可以指定一个 policy_fn。为了说明您可以通过此进行的其他权衡,我们提供了两个简单的策略函数。

策略 1:不重新计算矩阵乘法

aten = torch.ops.aten
compute_intensive_ops = [  
        aten.mm,
        aten.bmm,
        aten.addmm,
] 
def policy_fn(ctx, op, *args, **kwargs):
    if op in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

flow diagram

策略 2:更积极地保存任何计算密集型操作

# torch/_functorch/partitioners.py
aten = torch.ops.aten
compute_intensive_ops = [  
   aten.mm,
   aten.convolution,
   aten.convolution_backward,
   aten.bmm,
   aten.addmm,
   aten._scaled_dot_product_flash_attention,
   aten._scaled_dot_product_efficient_attention,
   aten._flash_attention_forward,
   aten._efficient_attention_forward,
   aten.upsample_bilinear2d,
   aten._scaled_mm
] 
def policy_fn(ctx, op, *args, **kwargs):
    if op in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

flow diagram

在速度与内存图上,SAC 根据您选择的策略,绘制为从接近 AC 到接近即时模式的一系列点。

flow diagram

试试吧!(在 2.5 中作为原型功能提供;请参阅文档了解更多信息 + 可直接复制粘贴的示例)

from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Create a policy function that returns a CheckpointPolicy
def policy_fn(ctx, op, *args, **kwargs):
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

# Use the context_fn= arg of the existing checkpoint API
out = checkpoint(
    fn, *args,
    use_reentrant=False,
    # Fill in SAC context_fn's policy_fn with functools.partial
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)


(仅限 compile)内存预算 API [新增!]

如前所述,任何给定的 SAC 策略都可以表示为速度-内存权衡图上的一个点。然而,并非所有策略都是相同的。“最优”策略是那些落在帕累托曲线上的策略,例如,对于所有产生相同内存开销的策略,此策略是使所需计算量最小化的策略。

对于使用 torch.compile 的用户,我们提供了一个内存预算 API,它根据用户指定的介于 0 和 1 之间的“内存预算”,自动在编译区域上应用具有帕累托最优策略的 SAC,其中预算为 0 的行为类似于普通的 AC,预算为 1 的行为类似于默认的 torch.compile。

flow diagram

以下是一些在 transformer 模型上的实际结果

flow diagram

我们观察到,通过仅重新计算逐点操作,内存减少了 50%,并且随着您重新计算越来越多的矩阵乘法,内存稳定下降。注意力机制是最昂贵的,所以您通常希望最后才重新计算它们。

试试吧!(在 2.4 中作为实验性功能提供;请参阅此注释块了解更多信息)

torch._dynamo.config.activation_memory_budget = 0.5

out = torch.compile(fn)(inp)

结论

flow diagram

总而言之,PyTorch 中的激活检查点技术提供了多种平衡内存和计算需求的方式,从简单的基于区域的检查点到更具选择性和自动化的方法。通过选择最适合您的模型结构和资源限制的选项,您可以在可接受的计算权衡下实现显著的内存节省。

致谢

我们要感谢 Meta 的 xformers 团队,包括 Francisco Massa,他们参与了选择性激活检查点的原始版本。