FSDP 笔记¶
FSDP 预取细微之处¶
对于与 forward
计算重叠的 forward
all-gather,有两种可能的机制
隐式前向预取(始终启用)
显式前向预取 (
forward_prefetch=True
)
隐式 forward
预取指的是依赖于从单独的 CUDA 流发出 all-gather,以便将 all-gather 与之前发出的 forward
计算(从 CPU 角度来看)重叠。例如,如果我们有 layer 0 all-gather -> layer 0 forward
计算 -> layer 1 all-gather -> …,那么 layer 1 all-gather 可以与 layer 0 forward
计算重叠,即使 CPU 线程稍后发出它。(第一个 all-gather 将无法与任何内容重叠。)
显式 forward
预取指的是更改 CPU 线程的发出顺序:例如 layer 0 all-gather -> layer 1 all-gather -> layer 0 forward
计算 -> …。在 eager 模式下,通常无法知道哪个层是下一层(例如示例中的 layer 1),当仍在 layer 0 上执行时。因此,显式 forward
预取仅应用于执行顺序从迭代到迭代固定的模型(我们有时称之为“静态图”)。不满足此约束的模型示例是 FLAVA)。
显式 forward
预取仅节省了发出层 forward
计算内核所花费的时间,代价是必须在当前 all-gather 的输出张量仍在使用的同时分配下一个 all-gather 的输出张量。通过在当前 forward
计算内核之前发出下一个 all-gather,下一个 all-gather 可以在 GPU 上更快地启动。对于大多数 LLM 工作负载,情况并非如此,因此没有启用 forward_prefetch=True
的动机。
相比之下,对于 backward
,我们必须使用显式 backward
预取,否则通信和计算将完全没有重叠。原因是我们在 all-gather 和 reduce-scatter 中都使用单个 NCCL 进程组(部分原因是早期 NCCL 版本中,在同一设备上通过相同的 ranks 并发使用多个进程组是不安全的)。单个 NCCL 进程组意味着单个内部 NCCL 流,reduce-scatter 和 all-gather 在该流上串行运行。因此,除非我们显式地将 CPU 发出顺序重新排序为下一个 all-gather -> 当前 reduce-scatter,否则当前 reduce-scatter 将会阻塞下一个 all-gather,从而阻塞下一个 backward
计算,阻止当前 reduce-scatter 重叠。
通信负载大小¶
在 FSDP 中,通信包括
forward
中的参数上的 all-gatherbackward
中的参数上的 all-gatherbackward
中的梯度上的 reduce-scatter
如果使用了激活检查点 (checkpoint()
),则没有额外的通信,因为参数在 backward
期间无论如何都会被预取。
在 FSDP 设计中,每个 rank 的通信负载按如下方式确定:每次调用 FullyShardedDataParallel
都会创建一个通信组,该组包含 module.parameters()
中的参数,但不包括已分配给嵌套 FullyShardedDataParallel
实例的任何参数。例如,对于 Llama,如果您将 FullyShardedDataParallel
应用于每个 transformer 块以及根模块,那么每个 transformer 块都会有一个通信组,最后还会有一个包含初始 embedding 和最终线性层的通信组。每个通信组对应于单个 all-gather 调用和单个 reduce-scatter 调用。通过这种方式,您如何应用 FullyShardedDataParallel
决定了通信大小。一般来说,将 FSDP 应用于每个 transformer 块是 LLM 的一个很好的启发式方法,并且在当前设计下很难做得更好。
让我们考虑一个示例,其中我们有一个基于 Transformer 的模型在 8 个 GPU 上分片,其中分片仅发生在 transformer 块级别,并且每个 transformer 块包含 1.6B 参数,参数为 fp32(每个参数 4 字节)。这意味着一旦分片,每个 transformer 块在每个 rank 上将包含 0.2B 参数。
forward
传递将在 all-gather 中通信0.2*4 = 0.8GB
的块backward
传递将通信 2 次,每次0.8GB
(1 次 all-gather 和 1 次 reduce-scatter)
换句话说,将有 3 次通信,每次通信的负载为 0.8GB
。如果模型由 10 个 transformer 块组成,则总共将有 30 次通信,总共 30*0.8=24GB
。
形式化表示,每个 rank 每次通信的负载大小为 total_transformer_block_params_in_B*dtype_bytes/num_gpus
(GBs)。
请注意,在本示例中,我们没有包括 embedding 所需的额外通信,这也应该考虑在内。并且计算将取决于输入和输出 embedding 是否绑定。如果它们未绑定,则将有 2 倍以上的通信。
FSDP 缓冲区大小¶
首先,让我们介绍为通信分配的缓冲区
forward
当前需要 2 倍 all-gather 缓冲区大小。原因如下
如 FSDP 预取细微之处 中解释的,在显式 forward
预取 (forward_prefetch=True`) 情况下,对于 layer 0 all-gather -> layer 0 forward 计算 -> layer 1 all-gather ,需要 2 个 all-gather 大小的缓冲区,因为一个缓冲区用于当前的 ``forward
,而另一个缓冲区用于执行预取。
虽然隐式 forward
预取 (forward_prefetch=False
,默认) 情况下,相同的序列理论上应该只需要 1 个缓冲区,但实际上仍然是 2 倍 all-gather 大小的缓冲区。原因是,在 flat-parameter FSDP 设计中,我们不从 all-gather 缓冲区中复制出来。用于计算的参数直接查看 all-gather 缓冲区(事实上,“flat parameter”的主要好处正是这个原因)。在这种情况下,当 ‘layer 1 all-gather’ 与 ‘layer 0 forward 计算’ 重叠时,‘layer 0 forward 计算’ 正在使用查看 ‘layer 0 all-gather’ 缓冲区的参数。
那么,一个自然的问题是,您何时想要 forward_prefetch=False
?对于静态图模型(如大多数 LLM),有一个主要的技术原因。更确切地说,实际上,我们为某些 CPU 密集型内部模型快速添加了这个选项,并且没有在单元测试中测试每个代码路径,因此我们对此不太自信。forward_prefetching=False
可能更容易推理,因为我们不必检查记录的前向顺序作为可能的“失败模式”;模块的 all-gather 始终可以在其分析器跟踪中的自身 record_function
标签下找到。
backward
当前至少需要 2 倍 all-gather 缓冲区大小,并且可能更多一点。原因如下
当前的 FSDP 设计使用 recordStream
来管理在一个流中生成并在另一个流中使用的分配,这可能导致比预期更多的内存使用。更多多少可能是“不确定的”,因为它取决于 GPU 内核计时相对于 CPU 的时间。limit_all_gathers=True
参数是对这种情况的缓解措施 - 更多详细信息请参考此讨论 FSDP & CUDACachingAllocator。
现有 FSDP 与 autograd 的工作方式
现有 FSDP all-gather
flat_param
,它是 autograd 叶节点。它调用
torch.split
以获取flat_param
的 1D 视图,对应于其组成的原始参数。它在每个 1D split 上调用
torch.view
以查看回 ND。这意味着在
backward
中,我们最终得到ViewBackward
(ND -> 1D) 和SplitWithSizesBackward
(它是 concat)。特别是,每个单独的梯度都作为单独的分配计算,并且发生显式 concat 以构造 reduce-scatter 输入缓冲区。这意味着实际上在峰值内存点为 reduce-scatter 提供了 2 倍缓冲区大小。
总而言之,对于 backward
,大约是 reduce-scatter 的 2 倍缓冲区大小,加上任何 recordStream
效应。
其次,让我们讨论额外的缓冲区
一旦从所有 ranks 收集到分片参数,它们就需要一个额外的缓冲区 total_transformer_block_params_in_B*dtype_bytes 用于完整参数 - 因此继续之前的示例,如果每个 transformer 块是 1.6B 参数,并且参数是 fp32,那么它将是 1.6*4=6.4GB 缓冲区。
并且需要 2 个这样的缓冲区,因为有一个当前正在使用,另一个正在预取。
总结一下,我们有
2 倍的通信缓冲区,大小为
total_transformer_block_params_in_B*dtype_bytes/num_gpus
2 倍的未分片 transformer 块参数缓冲区
``total_transformer_block_params_in_B*dtype_bytes
或者如果您一直在关注示例
2*1.6*4/8=1.6GB
2**1.6*4=12.8GB
总共 14.4GB
。
现在让我们简要讨论一下 embedding 发生了什么,因为我们已将这些从计算中排除
鉴于我们讨论的规则,您包含在以“通信缓冲区大小按如下方式确定”开头的注释中,我们可以按如下方式分析
假设我们将 FSDP 应用于根模块(例如
Transformer
类)。假设我们进一步将 FSDP 应用于每个 transformer 块(例如TransformerBlock
类)。最常见的是,embedding 和最终线性投影是根
Transformer
类的直接子项。按照我们的规则,这意味着 embedding 和最终线性投影被分配给根
Transformer
的 flat parameter。我们有 _另一个_ 特殊规则,即根模块在前向传播后不会释放其参数,因为无论如何它们都会在反向传播中立即被 all-gather。
综上所述,这意味着根模块的 flat parameter(包括 embedding 和最终投影)在开始前向传播时都会被 all-gather,并保留在 GPU 内存中直到反向传播结束。
如果 embedding 和最终线性层未绑定权重,那么我们 _可以_ 进一步将 FSDP 应用于 embedding 和最终线性层。对于权重绑定的参数,我们要求它们成为同一 flat parameter 的一部分(否则会重复计数)。这将允许 embedding 在其在前向传播中使用后被释放,并且仅在反向传播结束时才被 all-gather。
希望这能更好地理解 - 每个 FSDP 模块都会分配其
module.parameters
中的参数,但已分配给另一个嵌套 FSDP 模块的参数除外,并且 FSDP 模块的forward
定义了其参数的“活动”间隔。因此,嵌套的nn.Module
结构会影响 all-gather/free 调度,从而影响内存/吞吐量性能。