FSDP 备注¶
FSDP 预取细微差别¶
对于将forward
全局聚合与forward
计算重叠,有两种可能的机制
隐式前向预取(始终启用)
显式前向预取(
forward_prefetch=True
)
隐式forward
预取是指依赖于从单独的 CUDA 流发出全局聚合,以允许将全局聚合与之前发出的forward
计算(从 CPU 的角度来看)重叠。例如,如果我们有层 0 全局聚合 -> 层 0 forward
计算 -> 层 1 全局聚合 -> …,那么即使 CPU 线程稍后发出了层 1 全局聚合,它也可以与层 0 forward
计算重叠。(第一个全局聚合将无法与任何内容重叠。)
显式forward
预取是指更改 CPU 线程的发行顺序:例如,层 0 全局聚合 -> 层 1 全局聚合 -> 层 0 forward
计算 -> …. 在急切模式下,通常无法知道在执行层 0 时下一层是哪一层(例如,示例中的层 1)。因此,显式forward
预取仅应用于执行顺序每次迭代都固定的模型(我们有时称之为“静态图”)。一个不满足此约束的模型示例是FLAVA。
显式forward
预取仅节省了发出层forward
计算内核所需的时间,但代价是必须在当前层仍在使用时分配下一个全局聚合的输出张量。通过在当前forward
计算内核之前发出下一个全局聚合,下一个全局聚合可以更快地在 GPU 上开始。对于大多数 LLM 工作负载,情况并非如此,因此没有动力启用forward_prefetch=True
。
相反,对于backward
,我们必须使用显式backward
预取,否则通信和计算将不会重叠。原因是我们对全局聚合和归约散射都使用单个 NCCL 进程组(部分原因是在早期 NCCL 版本中,在同一设备上通过同一等级同时使用多个 NCCL 进程组是不安全的)。单个 NCCL 进程组意味着单个内部 NCCL 流,归约散射和全局聚合在此流上串行运行。因此,除非我们明确地将 CPU 发行顺序重新排序为下一个全局聚合 -> 当前归约散射,否则当前归约散射将阻塞下一个全局聚合,从而阻塞下一个backward
计算,从而阻止当前归约散射重叠。
通信负载大小¶
在 FSDP 中,通信是
在
forward
中对参数进行全局聚合在
backward
中对参数进行全局聚合在
backward
中对梯度进行归约散射
如果使用了激活检查点(checkpoint()
),则没有额外的通信,因为参数在backward
期间无论如何都会被预取。
在 FSDP 设计中,每个等级的通信负载大小如下确定:每次调用FullyShardedDataParallel
都会创建一个通信组,该组包含module.parameters()
中的参数,但任何已分配给嵌套FullyShardedDataParallel
实例的参数除外。例如,对于 Llama,如果您将FullyShardedDataParallel
应用于每个 Transformer 块以及根模块,那么每个 Transformer 块都有一个通信组,最后有一个包含初始嵌入和最终线性层的通信组。每个通信组对应于一个全局聚合调用和一个归约散射调用。这样,您如何应用FullyShardedDataParallel
决定了通信大小。通常,将 FSDP 应用于每个 Transformer 块是 LLMs 的一个好的启发式方法,并且鉴于当前设计,很难做得比这更好。
让我们考虑一个示例,其中我们有一个基于 Transformer 的模型,在 8 个 GPU 上进行分片,其中分片仅发生在 Transformer 块级别,每个 Transformer 块包含 16 亿个参数,参数为 fp32(每个 4 个字节)。这意味着分片后,每个 Transformer 块在每个等级上将包含 2 亿个参数。
前向传递 (forward pass) 将以
0.2*4 = 0.8GB
的块大小进行全聚合 (all-gather) 通信。反向传递 (backward pass) 将进行 2 次
0.8GB
的通信(1 次全聚合和 1 次归约散射 (reduce-scatter))。
换句话说,将有 3 次通信,每次通信的有效载荷为 0.8GB
。如果模型包含 10 个 Transformer 块,则总共有 30 次通信,总共 30*0.8=24GB
。
为了形式化,每个通信中每个进程的有效载荷大小为 total_transformer_block_params_in_B*dtype_bytes/num_gpus
(GB)。
请注意,在此示例中,我们没有包含嵌入所需的额外通信,也应将其考虑在内。并且计算结果将取决于输入和输出嵌入是否绑定。如果它们未绑定,则通信次数将增加 2 倍。
FSDP 缓冲区大小¶
首先,让我们介绍为通信分配的缓冲区。
forward
目前需要 2 倍于全聚合缓冲区的大小。以下是原因:
如 FSDP 预取细微差别 中所述,在显式 forward
预取 (forward_prefetch=True`) case of layer 0 all-gather -> layer 0 forward compute -> layer 1 all-gather there is a need for 2 all-gather-sized buffers, because one buffer is used in the current ``forward
的情况下,需要 2 个全聚合大小的缓冲区,因为一个缓冲区用于当前的 forward
,而另一个缓冲区用于进行预取。
虽然在隐式 forward
预取 (forward_prefetch=False
,默认) 的情况下,理论上只需要 1 个缓冲区,但实际上仍然需要 2 个全聚合大小的缓冲区。原因是在扁平参数 FSDP 设计中,我们不从全聚合缓冲区复制数据。用于计算的参数直接查看全聚合缓冲区(实际上,“扁平参数”的主要优势正是这个原因)。在这种情况下,虽然“层 1 全聚合”与“层 0 前向计算”重叠,但“层 0 前向计算”正在使用查看“层 0 全聚合”缓冲区中的参数。
那么,一个自然的问题是,何时需要 forward_prefetch=False
?对于静态图模型(如大多数大型语言模型),有一个主要的技术原因。更确切地说,我们快速添加了此选项以用于某些 CPU 受限的内部模型,并且尚未在单元测试中使用它测试每条代码路径,因此我们对其信心不足。 forward_prefetching=False
可能稍微更容易理解,因为我们不必将记录的前向顺序检查为可能的“故障模式”;模块的全聚合始终可以在其分析器跟踪中在其自己的 record_function
标签下找到。
backward
目前至少需要 2 倍于全聚合缓冲区的大小,并且可能更多。以下是原因:
当前的 FSDP 设计使用 recordStream
来管理在一个流中生成并在另一个流中使用的分配,这可能导致比预期更多的内存使用量。更多内存使用量在某种程度上是“不确定的”,因为它取决于 GPU 内核时间相对于 CPU 的时间。 limit_all_gathers=True
参数是对此的一种缓解措施 - 有关更多详细信息,请参阅 FSDP & CUDACachingAllocator 中的讨论。
现有 FSDP 与自动微分 (autograd) 的工作方式
现有的 FSDP 对
flat_param
进行全聚合,它是自动微分叶子节点。它调用
torch.split
以获取flat_param
中对应于其组成原始参数的 1D 视图。它对每个 1D 分片调用
torch.view
以将其视图恢复到 ND 形状。这意味着在
backward
中,我们最终得到了ViewBackward
(ND -> 1D) 和SplitWithSizesBackward
(它是连接操作)。特别是,每个单独的梯度都作为单独的分配来计算,并且发生显式连接操作来构造归约散射输入缓冲区。这意味着在该峰值内存点实际上为归约散射分配了 2 倍的缓冲区大小。
总之,对于 backward
,大约需要 2 倍于归约散射的缓冲区大小,加上任何 recordStream
影响。
其次,让我们讨论其他缓冲区。
一旦从所有进程收集了分片参数,它们就需要一个额外的 total_transformer_block_params_in_B*dtype_bytes 大小的缓冲区来存储完整参数 - 因此继续前面的示例,如果每个 Transformer 块有 1.6B 个参数并且参数为 fp32,则缓冲区大小为 1.6*4=6.4GB。
并且需要 2 个这样的缓冲区,因为当前正在使用一个,另一个正在预取。
总而言之,我们有:
2 倍于
total_transformer_block_params_in_B*dtype_bytes/num_gpus
的通信缓冲区。2 倍于未分片 Transformer 块参数缓冲区
``total_transformer_block_params_in_B*dtype_bytes
。
或者如果你一直在关注示例:
2*1.6*4/8=1.6GB
2*1.6*4=12.8GB
总共 14.4GB
。
现在让我们简要讨论一下嵌入的情况,因为我们已将其从计算中排除。
根据我们讨论的规则(你包含在以“通信缓冲区大小按如下方式确定”开头的注释中),我们可以进行如下分析:
假设我们将 FSDP 应用于根模块(例如
Transformer
类)。假设我们进一步将 FSDP 应用于每个 Transformer 块(例如TransformerBlock
类)。最常见的是,嵌入和最终线性投影是根
Transformer
类的直接子节点。根据我们的规则,这意味着嵌入和最终线性投影被分配给根
Transformer
的扁平参数。我们还有另一个特殊规则,即根模块在前向传递后不会释放其参数,因为它们将在反向传递中立即被全聚合。
综合起来,这意味着包括嵌入和最终投影在内的根模块的扁平参数将在前向传递开始时被全聚合,并保留在 GPU 内存中,直到反向传递结束。
如果嵌入和最终线性未进行权重绑定,那么我们_可以_进一步将 FSDP 应用于嵌入和最终线性。对于权重绑定的参数,我们要求它们是同一个扁平参数的一部分(否则会被重复计算)。这将允许嵌入在其在 forward 中使用后被释放,并且仅在反向传递结束时被全聚合。
希望这能更好地理解 – 每个 FSDP 模块在其
module.parameters
中获取参数,除了已分配给另一个嵌套 FSDP 模块的参数,并且 FSDP 模块的forward
定义了其参数的“活动”区间。因此,嵌套的nn.Module
结构会影响全聚合/释放计划,进而影响内存/吞吐量性能。