快捷方式

FSDP 注释

FSDP 预取细微差别

对于重叠 forward 全部收集与 forward 计算,有两种可能的机制

  1. 隐式前向预取(始终启用)

  2. 显式前向预取(forward_prefetch=True

隐式 forward 预取是指依赖于从单独的 CUDA 流发出全部收集,以允许全部收集与之前发出的 forward 计算重叠(从 CPU 角度来看)。例如,如果我们有层 0 全部收集 -> 层 0 forward 计算 -> 层 1 全部收集 -> …,那么层 1 全部收集可以与层 0 forward 计算重叠,即使 CPU 线程在之后发出它。(第 1 次全部收集将无法与任何内容重叠。)

显式 forward 预取是指更改 CPU 线程的发行顺序:例如,第 0 层全聚合 -> 第 1 层全聚合 -> 第 0 层 forward 计算 -> …. 在 eager 模式下,在第 0 层上执行时,通常无法知道下一层是哪一层(例如,示例中的第 1 层)。因此,显式 forward 预取仅应用于执行顺序在每次迭代中都固定的模型(我们有时称之为“静态图”)。不满足此约束的模型示例为 FLAVA)。

显式 forward 预取仅节省了在层 forward 计算内核上发出的时间,代价是下一个全聚合的输出张量必须在当前张量仍在使用时分配。通过在当前 forward 计算内核之前发出下一个全聚合,下一个全聚合可以在 GPU 上更早启动。对于大多数 LLM 工作负载,情况并非如此,因此没有启用 forward_prefetch=True 的动机。

相比之下,对于 backward,我们必须使用显式 backward 预取,否则通信和计算将重叠 0。原因是我们对全聚合和归约散射都使用一个 NCCL 进程组(部分原因是,在较早的 NCCL 版本中,在同一设备上对同一等级同时使用多个 NCCL 进程组是不安全的)。单个 NCCL 进程组意味着单个内部 NCCL 流,在该流上归约散射和全聚合串行运行。因此,除非我们显式地重新排序 CPU 发行顺序为下一个全聚合 -> 当前归约散射,否则当前归约散射将阻塞下一个全聚合,从而导致下一个 backward 计算,从而阻止当前归约散射重叠。

通信有效负载大小

在 FSDP 中,通信是

  1. forward 中对参数进行全聚合

  2. backward 中对参数进行全聚合

  3. backward 中对梯度进行归约散射

如果使用激活检查点 (checkpoint()),则没有额外的通信,因为无论如何,参数都会在 backward 期间被预取。

在 FSDP 设计中,每个等级的通信有效负载确定如下:每次调用 FullyShardedDataParallel 都会创建一个通信组,其中包含 module.parameters() 中的参数,但任何已分配给嵌套 FullyShardedDataParallel 实例的参数除外。例如,对于 Llama,如果你将 FullyShardedDataParallel 应用于每个 transformer 块以及根模块,那么每个 transformer 块有一个通信组,最后有一个通信组包含初始嵌入和最终线性。每个通信组对应一个全聚合调用和一个归约散射调用。通过这种方式,你如何应用 FullyShardedDataParallel 决定了通信大小。通常,将 FSDP 应用于每个 transformer 块对于 LLM 来说是一个很好的启发式方法,并且很难在当前设计的基础上做得更好。

让我们考虑一个示例,其中我们有一个基于 Transformer 的模型,该模型在 8 个 GPU 上分片,分片仅在 transformer 块级别发生,每个 transformer 块包含 1.6B 参数,并且参数为 fp32(每个 4 字节)。这意味着一旦分片,每个 transformer 块将在每个等级上包含 0.2B 参数。

  • 在全聚合中,forward 传递将以 0.2*4 = 0.8GB 的块进行通信

  • The backward pass 会 2 次通信,每次 0.8GB(1x 全聚合和 1x 减少散射)

换句话说,将有 3 次通信,每次有效负载为 0.8GB。如果模型由 10 个 transformer 块组成,则总共有 30 次通信,总计 30*0.8=24GB

要将每个通信的有效负载大小正式化,每个等级为 total_transformer_block_params_in_B*dtype_bytes/num_gpus(GB)。

请注意,在此示例中,我们没有包括嵌入所需的额外通信,也应考虑这些通信。并且数学将取决于输入和输出嵌入是否绑定。如果它们没有绑定,将会有 2 倍的通信。

FSDP 缓冲区大小

首先,让我们介绍为通信分配的缓冲区

forward 当前需要 2 倍全聚合缓冲区大小。原因如下

正如 FSDP 预取细微差别 中所解释的,在显式 forward 预取 (forward_prefetch=True`) case of layer 0 all-gather -> layer 0 forward compute -> layer 1 all-gather there is a need for 2 all-gather-sized buffers, because one buffer is used in the current ``forward 而另一个用于进行预取。

虽然隐式 forward 预取 (forward_prefetch=False,默认值) 相同序列的案例理论上只需要 1 个缓冲区,但实际上仍然是 2 倍的 all-gather 大小缓冲区。原因在于,在扁平参数 FSDP 设计中,我们不会从 all-gather 缓冲区中复制出来。用于计算的参数直接查看 all-gather 缓冲区(事实上,“扁平参数”的主要好处正是这个原因)。在这种情况下,当“layer 1 all-gather”与“layer 0 forward compute”重叠时,“layer 0 forward compute”使用的是查看“layer 0 all-gather”缓冲区中的参数。

那么一个自然的问题是,你什么时候需要 forward_prefetch=False?对于静态图模型(如大多数 LLM),有一个主要的技術原因。更确切地说,我们在实践中为一些受 CPU 限制的内部模型快速添加了此选项,并且尚未在单元测试中使用它测试每个代码路径,因此我们对其不太有信心。 forward_prefetching=False 可能更容易推理,因为我们不必检查记录的前向顺序作为可能的“故障模式”;模块的 all-gather 始终可以在其分析器跟踪中的 record_function 标签下找到。

backward 目前至少需要 2 倍的 all-gather 缓冲区大小,并且可能还需要更多。原因如下

当前的 FSDP 设计使用 recordStream 来管理在一个流中产生的分配,而在另一个流中使用,这可能导致比预期更多的内存使用。更多可能是“非确定性的”,因为它取决于 GPU 内核相对于 CPU 的时序。 limit_all_gathers=True 参数是对此的缓解 - 有关更多详细信息,请参阅此讨论 FSDP & CUDACachingAllocator

现有 FSDP 与自动微分配合的方式

  • 现有的 FSDP 对 flat_param(自动微分叶)进行全聚合。

  • 它调用 torch.split 以获取 flat_param 中的 1D 视图,该视图对应于其原始组成参数。

  • 它对每个 1D 拆分调用 torch.view 以查看返回 ND。

  • 这意味着在 backward 中,我们最终得到 ViewBackward (ND -> 1D) 和 SplitWithSizesBackward(这是一个串联)。具体来说,每个单独的梯度都作为单独的分配计算,并且显式串联发生以构建 reduce-scatter 输入缓冲区。这意味着在该峰值内存点处,reduce-scatter 的缓冲区大小实际上为 2 倍。

总之,对于 backward,它大约是 reduce-scatter 的 2 倍缓冲区大小加上任何 recordStream 效果。

其次,让我们讨论额外的缓冲区

从所有秩收集分片参数后,它们需要一个额外的缓冲区 total_transformer_block_params_in_B*dtype_bytes 来存储完整参数 - 因此,如果继续前面的示例,如果每个 transformer 块是 1.6B 参数且参数为 fp32,那么它将是 1.6*4=6.4GB 缓冲区。

并且需要 2 个这样的缓冲区,因为当前正在使用一个,另一个正在预取。

总结一下,我们有

  1. 2 倍通信缓冲区 total_transformer_block_params_in_B*dtype_bytes/num_gpus

  2. 2 倍非分片 transformer 块参数缓冲区 ``total_transformer_block_params_in_B*dtype_bytes

或者如果你一直在关注示例

  1. 2*1.6*4/8=1.6GB

  2. 2**1.6*4=12.8GB

以及 14.4GB 的总和。

现在让我们简要讨论一下嵌入的情况,因为我们已经将它们从计算中排除在外

鉴于我们讨论的规则,您在从“通信缓冲区大小确定如下”开始的注释中包含了该规则,我们可以进行如下分析

  • 假设我们将 FSDP 应用于根模块(例如 Transformer 类)。假设我们进一步将 FSDP 应用于每个 transformer 块(例如 TransformerBlock 类)。

  • 最常见的是,嵌入和最终线性投影是根 Transformer 类的直接子类。

  • 根据我们的规则,这意味着嵌入和最终线性投影被分配给根 Transformer 的扁平参数。

  • 我们有 _另一个_ 特殊规则,即根在正向传播后不会释放其参数,因为它们无论如何都会立即在反向传播中全部收集。

  • 综合起来,这意味着包括嵌入和最终投影在内的根的扁平参数在正向传播开始时全部收集,并保留在 GPU 内存中,直到反向传播结束。

  • 如果嵌入和最终线性不是权重绑定的,那么我们 _可以_ 进一步将 FSDP 应用于嵌入和最终线性。对于权重绑定的参数,我们要求它们成为同一扁平参数的一部分(否则它将被双重计算)。这样就可以在正向传播中使用嵌入后释放嵌入,并且仅在反向传播结束时全部收集。

  • 希望这能提供更好的理解——每个 FSDP 模块在其 module.parameters 中分配参数,但已经分配给另一个嵌套 FSDP 模块的参数除外,并且 FSDP 模块的 forward 定义了其参数的“活动”间隔。因此,嵌套 nn.Module 结构会影响全部收集/释放计划,从而影响内存/吞吐量性能。

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源