快捷方式

FSDP 笔记

FSDP 预取细节

为了将 forward 的 all-gather 与 forward 计算重叠,有两种可能的机制

  1. 隐式前向预取(始终启用)

  2. 显式前向预取(forward_prefetch=True

隐式 forward 预取是指依赖于从单独的 CUDA 流发出 all-gather,以允许 all-gather 与之前发出的(从 CPU 视角看)forward 计算重叠。例如,如果我们有第 0 层 all-gather -> 第 0 层 forward 计算 -> 第 1 层 all-gather -> …,那么即使 CPU 线程在其后发出,第 1 层 all-gather 也可以与第 0 层 forward 计算重叠。(第一个 all-gather 无法与任何其他操作重叠。)

显式 forward 预取是指改变 CPU 线程的发出顺序:例如,第 0 层 all-gather -> 第 1 层 all-gather -> 第 0 层 forward 计算 -> …。在 eager 模式下,通常无法在执行第 0 层时知道下一层(例如示例中的第 1 层)是哪一层。因此,显式 forward 预取只能用于执行顺序在迭代之间固定的模型(有时称为“静态图”)。不满足此约束的模型示例包括 FLAVA)。

显式 forward 预取仅节省发出层级 forward 计算核函数所需的时间,代价是当前 all-gather 的输出张量仍在使用时,必须分配下一个 all-gather 的输出张量。通过在当前 forward 计算核函数之前发出下一个 all-gather,下一个 all-gather 可以在 GPU 上更快地开始。对于大多数 LLM 工作负载而言,情况并非如此,因此没有理由启用 forward_prefetch=True

相比之下,对于 backward,必须使用显式 backward 预取,否则通信和计算将完全没有重叠。原因在于我们使用单个 NCCL 进程组进行 all-gather 和 reduce-scatter(部分原因是早期 NCCL 版本在同一设备上使用相同的 ranks 并发使用多个进程组是不安全的)。单个 NCCL 进程组意味着 reduce-scatter 和 all-gather 在单个内部 NCCL 流上串行运行。因此,除非我们明确重新排序 CPU 发出顺序为下一个 all-gather -> 当前 reduce-scatter,否则当前 reduce-scatter 将阻塞下一个 all-gather,从而阻塞下一个 backward 计算,阻止当前 reduce-scatter 的重叠。

通信负载大小

在 FSDP 中,通信包括

  1. forward 中的参数 all-gather

  2. backward 中的参数 all-gather

  3. backward 中的梯度 reduce-scatter

如果使用激活检查点 (checkpoint()),则没有额外的通信,因为参数在 backward 期间 anyway 会被预取。

在 FSDP 设计中,每个 rank 的通信负载确定如下:每次调用 FullyShardedDataParallel 都会创建一个通信组,该组由 module.parameters() 中的参数组成,但已分配给嵌套 FullyShardedDataParallel 实例的参数除外。例如,对于 Llama,如果您对每个 Transformer 块以及根模块都应用 FullyShardedDataParallel,那么每个 Transformer 块都有一个通信组,最后根模块有一个包含初始嵌入和最终线性层的通信组。每个通信组对应于一次 all-gather 调用和一次 reduce-scatter 调用。因此,您如何应用 FullyShardedDataParallel 决定了通信大小。总的来说,对每个 Transformer 块应用 FSDP 是 LLM 的一个好的启发式方法,并且鉴于当前的设计,很难做得更好。

考虑一个例子,我们有一个基于 Transformer 的模型在 8 个 GPU 上分片,分片仅在 Transformer 块级别进行,每个 Transformer 块包含 1.6B 参数,参数为 fp32(每个 4 字节)。这意味着分片后,每个 Transformer 块在每个 rank 上包含 0.2B 参数。

  • forward 通信将以 0.2*4 = 0.8GB 为块进行 all-gather

  • backward 通信将进行 2 次,每次 0.8GB(1 次 all-gather 和 1 次 reduce-scatter)

换句话说,总共有 3 次通信,每次负载 0.8GB。如果模型包含 10 个 Transformer 块,则总共有 30 次通信,总计 30*0.8=24GB

正式地说,每次通信每个 rank 的负载大小为 total_transformer_block_params_in_B*dtype_bytes/num_gpus (GBs)。

请注意,在此示例中,我们未包含嵌入所需的额外通信,这也应予以考虑。并且计算方式取决于输入和输出嵌入是否绑定。如果未绑定,则通信次数将是 2 倍。

FSDP 缓冲区大小

首先,让我们来看为通信分配的缓冲区

forward 目前需要 2 倍 all-gather 缓冲区大小。原因如下:

FSDP 预取细节 中所解释的,在显式 forward 预取(forward_prefetch=True)的情况下,即第 0 层 all-gather -> 第 0 层 forward 计算 -> 第 1 层 all-gather,需要 2 个 all-gather 大小的缓冲区,因为一个缓冲区用于当前的 forward,而另一个用于进行预取。

尽管隐式 forward 预取(forward_prefetch=False,默认)在理论上只需要 1 个缓冲区,但实际上仍需要 2 倍 all-gather 大小的缓冲区。原因在于,在扁平参数 FSDP 设计中,我们不会从 all-gather 缓冲区中复制出来。用于计算的参数直接作为 all-gather 缓冲区的视图(实际上,“扁平参数”的主要好处正是这个原因)。在这种情况下,尽管“第 1 层 all-gather”与“第 0 层 forward 计算”重叠,但“第 0 层 forward 计算”正在使用作为“第 0 层 all-gather”缓冲区视图的参数。

那么一个很自然的问题是,何时会需要 forward_prefetch=False?对于静态图模型(如大多数 LLM),有一个主要的技术原因。更实际地说,我们快速为一些 CPU 绑定的内部模型添加了此选项,并且尚未在单元测试中测试其所有代码路径,因此我们对其信心不足。forward_prefetching=False 可能更容易理解,因为我们不必检查记录的前向顺序作为可能的“故障模式”;模块的 all-gather 始终可以在其自己的 record_function 标签下在其 profiler 跟踪中找到。

backward 目前至少需要 2 倍 all-gather 缓冲区大小,并且可能更多。原因如下:

当前的 FSDP 设计使用 recordStream 来管理在一个流中生成并在另一个流中使用的分配,这可能导致比预期更多的内存使用。增加多少可能“不确定”,因为它取决于 GPU 核函数计时相对于 CPU 的情况。limit_all_gathers=True 参数是对此的缓解措施 - 有关更多详细信息,请参阅此讨论 FSDP & CUDACachingAllocator

现有 FSDP 与自动求导的协作方式

  • 现有 FSDP 对 flat_param 执行 all-gather 操作,flat_param 是自动求导的叶节点。

  • 它调用 torch.split 以获取 flat_param 中与其组成的原始参数对应的 1D 视图。

  • 它对每个 1D 分割调用 torch.view 以将其视图恢复为 ND。

  • 这意味着在 backward 中,我们最终得到 ViewBackward(ND -> 1D)和 SplitWithSizesBackward(它是一个 concat)。特别是,每个单独的梯度都计算为一个单独的分配,并且会发生显式 concat 来构建 reduce-scatter 输入缓冲区。这意味着在该内存峰值点,reduce-scatter 的缓冲区大小实际上是 2 倍。

总而言之,对于 backward,缓冲区大小大约是 reduce-scatter 的 2 倍,再加上任何 recordStream 的影响。

其次,让我们讨论额外的缓冲区

一旦从所有 rank 收集了分片参数,它们需要一个额外的缓冲区来存储完整参数,大小为 total_transformer_block_params_in_B*dtype_bytes - 所以继续前面的例子,如果每个 Transformer 块是 1.6B 参数,参数是 fp32,那么缓冲区大小将是 1.6*4=6.4GB

需要两个这样的缓冲区,因为一个当前正在使用,另一个正在被预取。

总结一下,我们有

  1. 2 倍于 total_transformer_block_params_in_B*dtype_bytes/num_gpus 的通信缓冲区

  2. 2 倍于未分片 Transformer 块参数缓冲区 ``total_transformer_block_params_in_B*dtype_bytes

或者按照前面的例子

  1. 2*1.6*4/8=1.6GB

  2. 2*1.6*4=12.8GB

总计 14.4GB

现在让我们简要讨论一下嵌入层会发生什么,因为我们之前的计算中忽略了它们

根据我们讨论过的规则,即在笔记中以“通信缓冲区大小确定如下”开头的部分,我们可以进行如下分析

  • 假设我们将 FSDP 应用于根模块(例如 Transformer 类)。假设我们进一步将 FSDP 应用于每个 Transformer 块(例如 TransformerBlock 类)。

  • 通常,嵌入层和最终线性投影是根 Transformer 类的直接子模块。

  • 根据我们的规则,这意味着嵌入层和最终线性投影被分配给根 Transformer 的扁平参数。

  • 我们有_另一个_特殊规则,即根模块在前向传播后不会释放其参数,因为无论如何它们都会在反向传播中立即进行 all-gather。

  • 综合来看,这意味着包含嵌入层和最终投影层的根模块扁平参数在开始前向传播时进行 all-gather,并保留在 GPU 内存中直到反向传播结束。

  • 如果嵌入层和最终线性层不绑定权重,那么我们_可以_进一步将 FSDP 应用于嵌入层和最终线性层。对于绑定权重的参数,我们要求它们属于同一个扁平参数(否则会重复计数)。这将允许嵌入层在其前向使用后被释放,并仅在反向传播结束时进行 all-gather。

  • 希望这能提供更好的理解——每个 FSDP 模块都会被分配其 module.parameters 中的参数,除非这些参数已被分配给另一个嵌套的 FSDP 模块,并且 FSDP 模块的 forward 定义了其参数的“活跃”区间。因此,嵌套的 nn.Module 结构会影响 all-gather/free 调度,从而影响内存/吞吐量性能。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源