博客

PyTorch 中现有及全新的激活检查点技术

作者 2025年3月5日2025年5月3日暂无评论

随着模型在深度、批次大小(batch size)和序列长度等维度上的扩展,激活内存(activation memory)在整体内存占用中的占比日益显著。为了解决这一问题,PyTorch 提供了激活检查点(activation checkpointing)工具,通过在需要时重新计算张量来减少保存的张量数量,从而以额外的计算开销换取内存空间的节省。

在本文中,我们将介绍激活内存的基础知识,探讨现有激活检查点技术背后的高层设计思想,并介绍一些旨在提高灵活性和提供更多即插即用型优化/自动化能力的新技术。

在审视这些技术时,我们将比较它们在“速度与内存”权衡图中的位置,并希望能为如何为您的用例选择合适的策略提供一些见解。

(如果您希望直接了解新的 API,请直接跳转到下方的“选择性激活检查点”和“内存预算 API”部分。)

flow diagram

激活内存基础

默认情况下,在 Eager 模式(而非使用 torch.compile)下,PyTorch 的自动求导(autograd)会保留中间激活以用于反向传播计算。例如,如果您在正向传播过程中对张量 x 调用 sin 函数,自动求导机制必须记住 x,以便在反向传播过程中计算 cos(x)

flow diagram

如果该张量 x 在正向传播开始时被保存,它将在整个正向和反向传播阶段都保留在内存中。它只有在用于计算梯度后才能被释放,这通常发生在反向传播的末尾(由于执行顺序是逆序的)。

因此,随着您在正向传播中执行越来越多的操作,您会积累越来越多的激活值,导致激活内存不断增长,直到它(通常)在反向传播开始时达到峰值(此时激活值才能开始被释放)。

flow diagram

在上图中,橙色方框代表操作,黑色箭头代表其张量输入和输出。交叉向右的黑色箭头代表自动求导机制为了反向传播而保存的张量。

一种直观地整理 Eager 模式下的这种默认保存行为以及我们将要介绍的技术的方法,是基于它们如何在速度与内存之间进行权衡。

flow diagram

该图表中最理想的位置是左上角,即既拥有“高”速度又保持低内存占用。

我们首先将默认保存行为放在右上角(原因将在我们介绍其他技术时详细解释)。


激活检查点 (AC)

激活检查点 (AC) 是一种在 PyTorch 中减少内存占用的常用技术。

在正向传播过程中,AC 区域内执行的任何操作都不会为反向传播保存张量(仅保存该函数的输入)。在反向传播过程中,通过第二次运行该函数来重新生成梯度计算所需的中间激活值。

flow diagram

在(右侧)图中,黑色方框显示了激活检查点的应用位置。与默认的 Eager 方法(左侧)相比,这种设置减少了需要保存的张量数量(从 3 个减少到 1 个)。

在模型右侧部分应用 AC 可以起到降低峰值内存的作用,因为在内存使用通常达到峰值时(反向传播开始时),中间激活值不再驻留在内存中。

在速度与内存权衡图中,AC 被绘制在左下角。相对于 Eager 模式,它减少了为反向传播保存的内存量,但由于需要重新计算,带来了额外的计算成本。

flow diagram

请注意,AC 的速度与内存权衡可以通过选择正向传播中需要设置检查点的部分以及定义使用的检查点区域数量来进行调整。然而,实施这些更改可能需要修改您的模型结构,并且根据代码组织方式的不同,可能会比较繁琐。为了本图表的目的,我们假设仅对一个区域设置检查点;在此假设下,AC 在权衡图上表现为一个单点。

此外还要注意,这里的“内存”并不指峰值内存使用量;相反,它表示对于固定区域为反向传播保存了多少内存。


torch.compile 与最小割分区器 (min-cut partitioner)

另一个值得注意的方法是 torch.compile(在 PyTorch 2.0 中引入)。与激活检查点一样,torch.compile 也可以在底层执行一定程度的重新计算。具体而言,它将正向和反向计算追踪到一个联合图中,然后由“最小割(min-cut)”分区器进行处理。该分区器使用最小割/最大流算法对图进行分割,从而最大限度地减少为反向传播需要保存的张量数量。

乍一看,这听起来很像我们想要的激活内存缩减方案。然而,事实更加微妙。默认情况下,分区器的主要目标是减少运行时间。因此,它只重新计算特定类型的操作——主要是简单的、可融合的且非计算密集型的操作(例如逐点运算/pointwise ops)。

将“compile”放在速度与内存权衡图上……

flow diagram

它位于 Eager 非 AC 点的左上方,因为我们预期 torch.compile 在速度和内存方面都有所改进。

另一方面,相对于激活检查点,torch.compile 在重新计算的内容上更为保守,使其在速度与内存图上更靠近左上角。


选择性激活检查点 (SAC) [新功能!]

虽然普通的检查点会重新计算选定区域中的每一个操作,但选择性激活检查点 (SAC) 是在激活检查点之上的附加设置,您可以利用它对哪些操作需要重新计算进行更精细的控制。

如果您有一些计算成本较高的操作(如矩阵乘法 matmuls)希望避免重新计算,但同时仍希望重新计算较廉价的操作(如逐点运算),这将非常有用。

flow diagram

在普通 AC(左侧)会保存单个张量并重新计算整个 AC 区域的情况下,使用 SAC(右侧)您可以有选择地保存区域中特定的操作(标记为红色),从而避免重新计算它们。

要指定选择性保存的内容,您可以定义一个 policy_fn。为了说明通过这种方式可以实现的额外权衡,我们提供了两个简单的策略函数。

策略 1:不重新计算矩阵乘法

aten = torch.ops.aten
compute_intensive_ops = [  
        aten.mm,
        aten.bmm,
        aten.addmm,
] 
def policy_fn(ctx, op, *args, **kwargs):
    if op in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE
flow diagram

策略 2:更激进地保存任何计算密集型操作

# torch/_functorch/partitioners.py
aten = torch.ops.aten
compute_intensive_ops = [  
   aten.mm,
   aten.convolution,
   aten.convolution_backward,
   aten.bmm,
   aten.addmm,
   aten._scaled_dot_product_flash_attention,
   aten._scaled_dot_product_efficient_attention,
   aten._flash_attention_forward,
   aten._efficient_attention_forward,
   aten.upsample_bilinear2d,
   aten._scaled_mm
] 
def policy_fn(ctx, op, *args, **kwargs):
    if op in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE
flow diagram

在速度与内存权衡图上,SAC 被绘制为一系列点,根据您选择的策略,其位置从靠近 AC 到靠近 Eager 之间波动。

flow diagram

快来试试吧!(在 2.5 版本中作为原型功能提供;请参阅文档获取更多信息和可直接复制的示例)

from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Create a policy function that returns a CheckpointPolicy
def policy_fn(ctx, op, *args, **kwargs):
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE

# Use the context_fn= arg of the existing checkpoint API
out = checkpoint(
    fn, *args,
    use_reentrant=False,
    # Fill in SAC context_fn's policy_fn with functools.partial
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)


(仅限 compile)内存预算 API [新功能!]

如前所述,任何给定的 SAC 策略都可以表示为速度-内存权衡图上的一个点。然而,并非所有策略都是平等的。“最优”策略是那些落在帕累托前沿(pareto curve)上的策略,例如,对于所有导致相同内存开销的策略,该策略是能够最大限度减少所需计算量的策略。

对于使用 torch.compile 的用户,我们提供了内存预算 API,它会根据用户指定的 0 到 1 之间的“内存预算”,自动在您的编译区域上应用具有帕累托最优策略的 SAC。其中预算为 0 时表现得像普通 AC,预算为 1 时表现得像默认的 torch.compile。

flow diagram

以下是在 Transformer 模型上的一些真实测试结果:

flow diagram

我们观察到,通过仅重新计算逐点运算,内存减少了 50%,随着您重新计算越来越多的矩阵乘法,内存节省效果会稳步下降。Attention(注意力机制)是最昂贵的操作,因此您通常倾向于最后才重新计算它们。

快来试试吧!(在 2.4 版本中作为实验性功能提供;请参阅此注释块了解更多信息)

torch._dynamo.config.activation_memory_budget = 0.5

out = torch.compile(fn)(inp)

结论

flow diagram

总之,PyTorch 中的激活检查点技术提供了多种平衡内存和计算需求的方法,从简单的基于区域的检查点到更具选择性和自动化的方法。通过选择最符合您模型结构和资源约束的方案,您可以在可接受的计算权衡下实现显著的内存节省。

致谢

我们要感谢 Meta 的 xformers 团队,包括 Francisco Massa,感谢他们开发了最初版本的选择性激活检查点。