快捷方式

FSDP 笔记

FSDP 预取细微之处

对于与 forward 计算重叠的 forward all-gather,有两种可能的机制

  1. 隐式前向预取(始终启用)

  2. 显式前向预取 (forward_prefetch=True)

隐式 forward 预取指的是依赖于从单独的 CUDA 流发出 all-gather,以便将 all-gather 与之前发出的 forward 计算(从 CPU 角度来看)重叠。例如,如果我们有 layer 0 all-gather -> layer 0 forward 计算 -> layer 1 all-gather -> …,那么 layer 1 all-gather 可以与 layer 0 forward 计算重叠,即使 CPU 线程稍后发出它。(第一个 all-gather 将无法与任何内容重叠。)

显式 forward 预取指的是更改 CPU 线程的发出顺序:例如 layer 0 all-gather -> layer 1 all-gather -> layer 0 forward 计算 -> …。在 eager 模式下,通常无法知道哪个层是下一层(例如示例中的 layer 1),当仍在 layer 0 上执行时。因此,显式 forward 预取仅应用于执行顺序从迭代到迭代固定的模型(我们有时称之为“静态图”)。不满足此约束的模型示例是 FLAVA)。

显式 forward 预取仅节省了发出层 forward 计算内核所花费的时间,代价是必须在当前 all-gather 的输出张量仍在使用的同时分配下一个 all-gather 的输出张量。通过在当前 forward 计算内核之前发出下一个 all-gather,下一个 all-gather 可以在 GPU 上更快地启动。对于大多数 LLM 工作负载,情况并非如此,因此没有启用 forward_prefetch=True 的动机。

相比之下,对于 backward,我们必须使用显式 backward 预取,否则通信和计算将完全没有重叠。原因是我们在 all-gather 和 reduce-scatter 中都使用单个 NCCL 进程组(部分原因是早期 NCCL 版本中,在同一设备上通过相同的 ranks 并发使用多个进程组是不安全的)。单个 NCCL 进程组意味着单个内部 NCCL 流,reduce-scatter 和 all-gather 在该流上串行运行。因此,除非我们显式地将 CPU 发出顺序重新排序为下一个 all-gather -> 当前 reduce-scatter,否则当前 reduce-scatter 将会阻塞下一个 all-gather,从而阻塞下一个 backward 计算,阻止当前 reduce-scatter 重叠。

通信负载大小

在 FSDP 中,通信包括

  1. forward 中的参数上的 all-gather

  2. backward 中的参数上的 all-gather

  3. backward 中的梯度上的 reduce-scatter

如果使用了激活检查点 (checkpoint()),则没有额外的通信,因为参数在 backward 期间无论如何都会被预取。

在 FSDP 设计中,每个 rank 的通信负载按如下方式确定:每次调用 FullyShardedDataParallel 都会创建一个通信组,该组包含 module.parameters() 中的参数,但不包括已分配给嵌套 FullyShardedDataParallel 实例的任何参数。例如,对于 Llama,如果您将 FullyShardedDataParallel 应用于每个 transformer 块以及根模块,那么每个 transformer 块都会有一个通信组,最后还会有一个包含初始 embedding 和最终线性层的通信组。每个通信组对应于单个 all-gather 调用和单个 reduce-scatter 调用。通过这种方式,您如何应用 FullyShardedDataParallel 决定了通信大小。一般来说,将 FSDP 应用于每个 transformer 块是 LLM 的一个很好的启发式方法,并且在当前设计下很难做得更好。

让我们考虑一个示例,其中我们有一个基于 Transformer 的模型在 8 个 GPU 上分片,其中分片仅发生在 transformer 块级别,并且每个 transformer 块包含 1.6B 参数,参数为 fp32(每个参数 4 字节)。这意味着一旦分片,每个 transformer 块在每个 rank 上将包含 0.2B 参数。

  • forward 传递将在 all-gather 中通信 0.2*4 = 0.8GB 的块

  • backward 传递将通信 2 次,每次 0.8GB(1 次 all-gather 和 1 次 reduce-scatter)

换句话说,将有 3 次通信,每次通信的负载为 0.8GB。如果模型由 10 个 transformer 块组成,则总共将有 30 次通信,总共 30*0.8=24GB

形式化表示,每个 rank 每次通信的负载大小为 total_transformer_block_params_in_B*dtype_bytes/num_gpus (GBs)。

请注意,在本示例中,我们没有包括 embedding 所需的额外通信,这也应该考虑在内。并且计算将取决于输入和输出 embedding 是否绑定。如果它们未绑定,则将有 2 倍以上的通信。

FSDP 缓冲区大小

首先,让我们介绍为通信分配的缓冲区

forward 当前需要 2 倍 all-gather 缓冲区大小。原因如下

FSDP 预取细微之处 中解释的,在显式 forward 预取 (forward_prefetch=True`) 情况下,对于 layer 0 all-gather -> layer 0 forward 计算 -> layer 1 all-gather ,需要 2 all-gather 大小的缓冲区,因为一个缓冲区用于当前的 ``forward,而另一个缓冲区用于执行预取。

虽然隐式 forward 预取 (forward_prefetch=False,默认) 情况下,相同的序列理论上应该只需要 1 个缓冲区,但实际上仍然是 2 倍 all-gather 大小的缓冲区。原因是,在 flat-parameter FSDP 设计中,我们不从 all-gather 缓冲区中复制出来。用于计算的参数直接查看 all-gather 缓冲区(事实上,“flat parameter”的主要好处正是这个原因)。在这种情况下,当 ‘layer 1 all-gather’ 与 ‘layer 0 forward 计算’ 重叠时,‘layer 0 forward 计算’ 正在使用查看 ‘layer 0 all-gather’ 缓冲区的参数。

那么,一个自然的问题是,您何时想要 forward_prefetch=False?对于静态图模型(如大多数 LLM),有一个主要的技术原因。更确切地说,实际上,我们为某些 CPU 密集型内部模型快速添加了这个选项,并且没有在单元测试中测试每个代码路径,因此我们对此不太自信。forward_prefetching=False 可能更容易推理,因为我们不必检查记录的前向顺序作为可能的“失败模式”;模块的 all-gather 始终可以在其分析器跟踪中的自身 record_function 标签下找到。

backward 当前至少需要 2 倍 all-gather 缓冲区大小,并且可能更多一点。原因如下

当前的 FSDP 设计使用 recordStream 来管理在一个流中生成并在另一个流中使用的分配,这可能导致比预期更多的内存使用。更多多少可能是“不确定的”,因为它取决于 GPU 内核计时相对于 CPU 的时间。limit_all_gathers=True 参数是对这种情况的缓解措施 - 更多详细信息请参考此讨论 FSDP & CUDACachingAllocator

现有 FSDP 与 autograd 的工作方式

  • 现有 FSDP all-gather flat_param,它是 autograd 叶节点。

  • 它调用 torch.split 以获取 flat_param 的 1D 视图,对应于其组成的原始参数。

  • 它在每个 1D split 上调用 torch.view 以查看回 ND。

  • 这意味着在 backward 中,我们最终得到 ViewBackward (ND -> 1D) 和 SplitWithSizesBackward(它是 concat)。特别是,每个单独的梯度都作为单独的分配计算,并且发生显式 concat 以构造 reduce-scatter 输入缓冲区。这意味着实际上在峰值内存点为 reduce-scatter 提供了 2 倍缓冲区大小。

总而言之,对于 backward,大约是 reduce-scatter 的 2 倍缓冲区大小,加上任何 recordStream 效应。

其次,让我们讨论额外的缓冲区

一旦从所有 ranks 收集到分片参数,它们就需要一个额外的缓冲区 total_transformer_block_params_in_B*dtype_bytes 用于完整参数 - 因此继续之前的示例,如果每个 transformer 块是 1.6B 参数,并且参数是 fp32,那么它将是 1.6*4=6.4GB 缓冲区。

并且需要 2 个这样的缓冲区,因为有一个当前正在使用,另一个正在预取。

总结一下,我们有

  1. 2 倍的通信缓冲区,大小为 total_transformer_block_params_in_B*dtype_bytes/num_gpus

  2. 2 倍的未分片 transformer 块参数缓冲区 ``total_transformer_block_params_in_B*dtype_bytes

或者如果您一直在关注示例

  1. 2*1.6*4/8=1.6GB

  2. 2**1.6*4=12.8GB

总共 14.4GB

现在让我们简要讨论一下 embedding 发生了什么,因为我们已将这些从计算中排除

鉴于我们讨论的规则,您包含在以“通信缓冲区大小按如下方式确定”开头的注释中,我们可以按如下方式分析

  • 假设我们将 FSDP 应用于根模块(例如 Transformer 类)。假设我们进一步将 FSDP 应用于每个 transformer 块(例如 TransformerBlock 类)。

  • 最常见的是,embedding 和最终线性投影是根 Transformer 类的直接子项。

  • 按照我们的规则,这意味着 embedding 和最终线性投影被分配给根 Transformer 的 flat parameter。

  • 我们有 _另一个_ 特殊规则,即根模块在前向传播后不会释放其参数,因为无论如何它们都会在反向传播中立即被 all-gather。

  • 综上所述,这意味着根模块的 flat parameter(包括 embedding 和最终投影)在开始前向传播时都会被 all-gather,并保留在 GPU 内存中直到反向传播结束。

  • 如果 embedding 和最终线性层未绑定权重,那么我们 _可以_ 进一步将 FSDP 应用于 embedding 和最终线性层。对于权重绑定的参数,我们要求它们成为同一 flat parameter 的一部分(否则会重复计数)。这将允许 embedding 在其在前向传播中使用后被释放,并且仅在反向传播结束时才被 all-gather。

  • 希望这能更好地理解 - 每个 FSDP 模块都会分配其 module.parameters 中的参数,但已分配给另一个嵌套 FSDP 模块的参数除外,并且 FSDP 模块的 forward 定义了其参数的“活动”间隔。因此,嵌套的 nn.Module 结构会影响 all-gather/free 调度,从而影响内存/吞吐量性能。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源