FSDP 说明¶
FSDP 预取细微差别¶
为了将 forward
全局收集与 forward
计算重叠,有两种可能的机制
隐式前向预取(始终启用)
显式前向预取 (
forward_prefetch=True
)
隐式 forward
预取是指依赖于从单独的 CUDA 流发出全局收集,以便将全局收集与之前(从 CPU 的角度来看)发出的 forward
计算重叠。例如,如果我们有第 0 层全局收集 -> 第 0 层 forward
计算 -> 第 1 层全局收集 -> …,则第 1 层全局收集可以与第 0 层 forward
计算重叠,即使 CPU 线程是在之后发出的。(第一次全局收集将无法与任何内容重叠。)
显式 forward
预取是指更改 CPU 线程的发射顺序:例如,第 0 层全局收集 -> 第 1 层全局收集 -> 第 0 层 forward
计算 -> …。在急切模式下,通常无法在执行第 0 层时知道下一层是哪一层(例如,示例中的第 1 层)。因此,显式 forward
预取只能用于执行顺序在每次迭代中都固定的模型(我们有时称之为“静态图”)。不满足此约束的模型的一个示例是 FLAVA)。
显式 forward
预取只会节省发出层的 forward
计算内核所需的时间,但代价是必须在当前输出张量仍在使用时分配下一个全局收集的输出张量。通过在当前 forward
计算内核之前发出下一个全局收集,下一个全局收集可以在 GPU 上更快地启动。对于大多数 LLM 工作负载,情况并非如此,因此没有理由启用 forward_prefetch=True
。
相反,对于 backward
,我们必须使用显式 backward
预取,否则通信和计算将完全没有重叠。原因是我们对全局收集和分散归约都使用了一个 NCCL 进程组(部分原因是在早期的 NCCL 版本中,在同一设备上对相同的秩同时使用多个进程组是不安全的)。单个 NCCL 进程组意味着单个内部 NCCL 流,分散归约和全局收集在该流上串行运行。因此,除非我们明确地将 CPU 发射顺序重新排序为下一个全局收集 -> 当前分散归约,否则当前分散归约将阻塞下一个全局收集,从而阻塞下一个 backward
计算,从而阻止当前分散归约发生重叠。
通信有效负载大小¶
在 FSDP 中,通信是
在
forward
中对参数进行全局收集在
backward
中对参数进行全局收集在
backward
中对梯度进行分散归约
如果使用激活检查点 (checkpoint()
),则不会产生额外的通信,因为在 backward
期间参数已经被预取了。
在 FSDP 设计中,每个秩的通信有效负载大小确定如下:每次调用 FullyShardedDataParallel
都会创建一个通信组,该组包含 module.parameters()
中的参数,但已分配给嵌套的 FullyShardedDataParallel
实例的除外。例如,对于 Llama,如果对每个 Transformer 块和根模块都应用 FullyShardedDataParallel
,则每个 Transformer 块都有一个通信组,最后还有一个通信组包含初始嵌入和最终线性层。每个通信组对应于一个全局收集调用和一个分散归约调用。这样,如何应用 FullyShardedDataParallel
决定了通信大小。一般来说,将 FSDP 应用于每个 Transformer 块对于 LLM 来说是一个很好的启发式方法,并且在当前设计下很难做得更好。
让我们考虑一个示例,其中我们有一个基于 Transformer 的模型,该模型在 8 个 GPU 上进行分片,其中分片仅在 Transformer 块级别进行,每个 Transformer 块包含 16 亿个参数,并且参数采用 fp32 格式(每个参数 4 个字节)。这意味着一旦分片,每个 Transformer 块在每个秩上将包含 2 亿个参数。
forward
传递将以0.2*4 = 0.8GB
的块进行全局收集通信backward
传递将进行 2 次通信,每次0.8GB
(1 次 all-gather 和 1 次 reduce-scatter)。
换句话说,将进行 3 次通信,每次的有效负载为 0.8GB
。如果模型由 10 个 Transformer 块组成,则总共将进行 30 次通信,总计 30*0.8=24GB
。
形式化地,每个秩每次通信的有效负载大小为 total_transformer_block_params_in_B*dtype_bytes/num_gpus
(GB)。
请注意,在本例中,我们没有包括嵌入所需的额外通信,这也应该考虑在内。计算将取决于输入和输出嵌入是否绑定。如果它们没有绑定,则通信次数将增加 2 倍。
FSDP 缓冲区大小¶
首先,让我们介绍一下为通信分配的缓冲区。
forward
当前需要 2 倍于 all-gather 的缓冲区大小。原因如下:
如 FSDP 预取细微差别 中所述,在显式 forward
预取(forward_prefetch=True`) 情况下,第 0 层的 all-gather -> 第 0 层 forward 计算 -> 第 1 层的 all-gather 序列中,需要 2 个 all-gather 大小的 缓冲区,因为 一个 缓冲区 用于 当前 ``forward
,而另一个缓冲区用于预取。
虽然在相同的序列中,隐式 forward
预取(forward_prefetch=False
,默认)理论上只需要 1 个缓冲区,但实际上仍然是 2 倍于 all-gather 的缓冲区大小。原因是在扁平参数 FSDP 设计中,我们不会从 all-gather 缓冲区中复制数据。用于计算的参数直接查看 all-gather 缓冲区(实际上,“扁平参数”的主要优势正是这个原因)。在这种情况下,虽然“第 1 层 all-gather”与“第 0 层 forward 计算”重叠,但“第 0 层 forward 计算”使用的是查看“第 0 层 all-gather”缓冲区的参数。
那么,什么时候需要 forward_prefetch=False
呢?对于静态图模型(如大多数 LLM),有一个重要的技术原因。实际上,我们为一些受 CPU 限制的内部模型快速添加了此选项,并且没有在单元测试中使用它测试所有代码路径,因此我们对它的信心不足。forward_prefetching=False
可能更容易理解,因为我们不必将记录的 forward 顺序作为可能的“故障模式”进行检查;模块的 all-gather 始终可以在其分析器跟踪中其自身的 record_function
标签下找到。
backward
当前至少需要 2 倍于 all-gather 的缓冲区大小,并且可能还需要更多。原因如下:
当前的 FSDP 设计使用 recordStream
来管理在一个流中生成并在另一个流中消耗的分配,这可能导致内存使用量超出预期。内存使用量增加多少可能是“不确定的”,因为它取决于相对于 CPU 的 GPU 内核计时。limit_all_gathers=True
参数对此进行了缓解 - 有关更多详细信息,请参阅 FSDP 和 CUDACachingAllocator 中的讨论。
现有 FSDP 如何与自动梯度一起工作。
现有的 FSDP 对
flat_param
进行 all-gather 操作,它是自动梯度的叶子节点。它调用
torch.split
来获取对应于其构成原始参数的flat_param
的一维视图。它对每个一维分割调用
torch.view
,以将其视图返回到 ND。这意味着在
backward
中,我们最终得到ViewBackward
(ND -> 1D)和SplitWithSizesBackward
(它是一个连接操作)。特别是,每个单独的梯度都作为一个单独的分配进行计算,并且会发生一个显式的连接操作来构建 reduce-scatter 输入缓冲区。这意味着在峰值内存点,reduce-scatter 实际上需要 2 倍的缓冲区大小。
总之,对于 backward
,它大约需要 2 倍于 reduce-scatter 的缓冲区大小,再加上任何 recordStream
的影响。
其次,让我们讨论一下其他缓冲区。
从所有秩收集分片参数后,它们需要一个额外的 total_transformer_block_params_in_B*dtype_bytes 的缓冲区来存储完整参数 - 因此,继续前面的例子,如果每个 Transformer 块有 1.6B 个参数,并且参数是 fp32,那么它将是一个 1.6*4=6.4GB 的缓冲区。
并且需要 2 个这样的缓冲区,因为一个正在使用,另一个正在预取。
总而言之,我们有:
2 倍于
total_transformer_block_params_in_B*dtype_bytes/num_gpus
的通信缓冲区。2 倍于
``total_transformer_block_params_in_B*dtype_bytes
的未分片 Transformer 块参数缓冲区。
或者,如果您一直在关注这个例子:
2*1.6*4/8=1.6GB
2**1.6*4=12.8GB
总计 14.4GB
。
现在让我们简要讨论一下嵌入会发生什么,因为我们没有将它们计入计算中。
根据我们讨论过的规则(您在以“通信缓冲区大小的确定如下”开头的注释中包含了该规则),我们可以进行如下分析:
假设我们将 FSDP 应用于根模块(例如
Transformer
类)。假设我们进一步将 FSDP 应用于每个 Transformer 块(例如TransformerBlock
类)。最常见的情况是,嵌入和最终线性投影是根
Transformer
类的直接子级。根据我们的规则,这意味着嵌入和最终线性投影将分配给根
Transformer
的扁平参数。我们还有_另一个_特殊规则,即根节点在 forward 后不会释放其参数,因为它们无论如何都会在 backward 中立即进行 all-gather 操作。
综合考虑,这意味着包括嵌入和最终线性投影在内的根节点的扁平参数将在开始 forward 时进行 all-gather 操作,并一直保留在 GPU 内存中,直到 backward 结束。
如果嵌入和最终线性未进行权重绑定,那么我们_可以_进一步将 FSDP 应用于嵌入和最终线性。对于权重绑定参数,我们要求它们是同一个扁平参数的一部分(否则它将被重复计算)。这将允许嵌入在其在 forward 中使用后被释放,并且仅在 backward 结束时进行 all-gather 操作。
希望这能 memberikan pemahaman yang lebih baik - 每个 FSDP 模块都会在其
module.parameters
中分配参数,但已经分配给另一个嵌套 FSDP 模块的参数除外,并且 FSDP 模块的forward
定义了其参数的“活动”间隔。因此,嵌套的nn.Module
结构会影响 all-gather/free 调度,从而影响内存/吞吐量性能。