PyTorch 分布式概述¶
作者: Will Constable
注意
在 github 中查看和编辑本教程。
这是 torch.distributed
包的概述页面。此页面的目标是将文档分类到不同的主题中,并简要描述每个主题。如果这是你第一次使用 PyTorch 构建分布式训练应用程序,建议使用此文档导航到最适合你的用例的技术。
简介¶
PyTorch 分布式库包含一组并行模块、通信层以及用于启动和调试大型训练作业的基础设施。
分片原语¶
DTensor
和 DeviceMesh
是用于根据 N 维进程组中分片或复制的张量构建并行性的原语。
DTensor 表示一个分片和/或复制的张量,并自动进行通信以根据操作需要重新分片张量。
DeviceMesh 将加速器设备通信器抽象为一个多维数组,用于管理底层的
ProcessGroup
实例,以进行多维并行中的集体通信。尝试我们的 Device Mesh 教程 以了解更多信息。
通信 API¶
- PyTorch 分布式通信层 (C10D) 提供了集体通信 API(例如,all_reduce
和 all_gather)以及点对点通信 API(例如,send 和 isend),这些 API 在所有并行实现中都在底层使用。 使用 PyTorch 编写分布式应用程序 展示了使用 c10d 通信 API 的示例。
应用并行性以扩展您的模型¶
数据并行是一种广泛采用的单程序多数据训练范式,其中模型在每个进程上进行复制,每个模型副本计算一组不同的输入数据样本的局部梯度,在每个优化器步骤之前,在数据并行通信器组内对梯度进行平均。
当模型无法容纳在 GPU 中时,需要使用模型并行技术(或分片数据并行),并且可以将它们组合在一起形成多维 (N-D) 并行技术。
在决定为您的模型选择哪些并行技术时,请使用以下常见指南
如果您的模型适合单个 GPU,但您希望使用多个 GPU 轻松扩展训练,请使用 DistributedDataParallel (DDP)。
如果您使用多个节点,请使用 torchrun 启动多个 PyTorch 进程。
另请参阅:分布式数据并行的入门指南
当您的模型无法容纳在一个 GPU 上时,请使用 FullyShardedDataParallel (FSDP)。
另请参阅:FSDP 入门指南
如果遇到 FSDP 的扩展限制,请使用 张量并行 (TP) 和/或 流水线并行 (PP)。
尝试我们的 张量并行教程
注意
数据并行训练也适用于 自动混合精度 (AMP)。