快捷方式

DDP 通信钩子

DDP 通信钩子是一个通用接口,通过覆盖 DistributedDataParallel 中的普通 allreduce,来控制如何在工作进程之间通信梯度。提供了一些内置通信钩子,用户可以轻松地应用这些钩子来优化通信。此外,钩子接口还可以支持用户定义的通信策略,以满足更高级的使用场景。

如何使用通信钩子?

要使用通信钩子,用户只需在训练循环之前让 DDP 模型注册钩子,如下所示。

torch.nn.parallel.DistributedDataParallel.register_comm_hook()

通信钩子作用于什么?

通信钩子提供了一种灵活的方式来进行梯度 allreduce。因此,它主要作用于每个副本上的梯度(在 allreduce 之前),这些梯度被分组以增加通信和计算之间的重叠。具体来说,torch.distributed.GradBucket 表示要 allreduce 的梯度张量组。

class torch.distributed.GradBucket

此类主要将扁平化的梯度张量(由 buffer() 返回)传递给 DDP 通信钩子。这个张量可以进一步分解为该组中的每个参数张量的列表(由 get_per_parameter_tensors() 返回),以便应用层级操作。

torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) int

警告

由于桶在第一次迭代后会被重建,因此不应该依赖训练开始时的索引。

返回值

存储几个连续层梯度的桶的索引。所有梯度都被分组。

torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) torch.Tensor
返回值

一个扁平化的 1D torch.Tensor 缓冲区,它可以进一步分解为该组中的每个参数张量的列表。

torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
返回值

一个 torch.Tensor 列表。列表中的每个张量对应于一个梯度。

torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) bool
返回值

此桶是否为迭代中要 allreduce 的最后一个桶。这也意味着此桶对应于正向传递中的前几层。

torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) None

用输入张量 buffer 替换桶中的张量。

torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
返回值

一个 torch.Tensor 的列表。列表中的每个张量都对应一个模型参数。

默认通信钩子

默认通信钩子是简单的无状态钩子,因此 register_comm_hook 中的输入状态是进程组或 None。输入 bucket 是一个 torch.distributed.GradBucket 对象。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[source]

使用 GradBucket 张量调用 allreduce

在梯度张量跨所有工作进程聚合后,其 then 回调将计算平均值并返回结果。

如果用户注册了此 DDP 通信钩子,则 DDP 结果应与未注册钩子的情况相同。因此,这不会改变 DDP 的行为,用户可以使用它作为参考或修改此钩子以记录有用的信息或任何其他目的,同时不会影响 DDP 行为。

示例:
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
返回值类型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[source]

通过将 GradBucket 转换为 torch.float16 并除以进程组大小来进行压缩。

此 DDP 通信钩子实现了一种简单的梯度压缩方法,它将 GradBucket 张量转换为半精度浮点格式 (torch.float16),然后除以进程组大小。它对这些 float16 梯度张量执行 allreduce 操作。一旦压缩后的梯度张量被 allreduce,链式回调 decompress 会将其转换回输入数据类型(例如 float32)。

示例:
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
返回值类型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[source]

警告:此 API 处于实验阶段,它需要 NCCL 版本高于 2.9.6。

此 DDP 通信钩子实现了一种简单的梯度压缩方法,它将 GradBucket 张量转换为半精度Brain 浮点格式 (torch.bfloat16),然后除以进程组大小。它对这些 bfloat16 梯度张量执行 allreduce 操作。一旦压缩后的梯度张量被 allreduce,链式回调 decompress 会将其转换回输入数据类型(例如 float32)。

示例:
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
返回值类型

Future[Tensor]

此外,提供了一个通信钩子包装器来支持 fp16_compress_hook()bf16_compress_hook() 作为包装器,它可以与其他通信钩子组合使用。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[source]

将输入张量转换为 torch.float16,将钩子的结果转换回输入数据类型。

此包装器将给定 DDP 通信钩子的输入梯度张量转换为半精度浮点格式 (torch.float16),并将给定钩子的结果张量转换回输入数据类型,例如 float32。因此,fp16_compress_hook 等效于 fp16_compress_wrapper(allreduce_hook)

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
返回值类型

Callable[[Any, GradBucket], Future[Tensor]]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[source]

警告:此 API 处于实验阶段,它需要 NCCL 版本高于 2.9.6。

此包装器将给定 DDP 通信钩子的输入梯度张量转换为半精度 Brain 浮点格式 <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format> `_ (``torch.bfloat16`),并将给定钩子的结果张量转换回输入数据类型,例如 float32

因此,bf16_compress_hook 等效于 bf16_compress_wrapper(allreduce_hook)

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
返回值类型

Callable[[Any, GradBucket], Future[Tensor]]

PowerSGD 通信钩子

PowerSGD (Vogels 等人,NeurIPS 2019) 是一种梯度压缩算法,它可以提供非常高的压缩率并加速带宽受限的分布式训练。该算法需要维护一些超参数和内部状态。因此,PowerSGD 通信钩子是一个有状态钩子,用户需要提供如下定义的状态对象。

PowerSGD 状态

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source]

存储所有梯度的算法超参数和内部状态以供训练使用。

特别是,matrix_approximation_rankstart_powerSGD_iter 是用户应该调整的主要超参数。为了性能,我们建议将二元超参数 use_error_feedbackwarm_start 打开。

  1. matrix_approximation_rank 控制压缩低秩张量的尺寸,这决定了压缩率。秩越低,压缩越强。

    1.1. 如果 matrix_approximation_rank 太低,完整模型质量需要更多训练步骤才能达到,或者永远无法达到,并导致精度损失。

    1.2. matrix_approximation_rank 的增加会显著增加压缩的计算成本,并且精度可能不会在某个 matrix_approximation_rank 阈值之外进一步提高。

为了调整 matrix_approximation_rank,我们建议从 1 开始,以 2 的倍数增加(如指数网格搜索,1、2、4、…),直到达到令人满意的精度。通常只使用一个小值 1-4。对于某些 NLP 任务(如原始论文的附录 D 中所示),该值已增加到 32。

  1. start_powerSGD_iter 将 PowerSGD 压缩推迟到第 start_powerSGD_iter 步,而普通 allreduce 则在第 start_powerSGD_iter 步之前运行。这种 **普通 allreduce + PowerSGD** 的混合方案可以有效地提高精度,即使使用相对较小的 matrix_approximation_rank。这是因为训练阶段的开始通常对不准确的梯度非常敏感,过早压缩梯度可能会使训练快速采取次优轨迹,从而导致对精度的不可恢复的影响。

为了调整 start_powerSGD_iter,我们建议从总训练步骤的 10% 开始,并逐渐增加它,直到达到令人满意的精度。如果训练中存在预热阶段,start_powerSGD_iter 通常不应该小于预热步骤的数量。

  1. min_compression_rate 是压缩层时所需的最小压缩率。由于压缩带来的计算开销,只有在带宽节省足够的情况下,才值得压缩张量,其中 (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols。如果无法满足指定的压缩率阈值,则张量将直接进行 allreduce,不会压缩。

压缩统计信息将在 PowerSGD 压缩开始后每 compression_stats_logging_frequency 次迭代记录一次。

  1. orthogonalization_epsilon 可以是一个非常小的值(例如,1e-8),添加到正交化步骤中每个归一化矩阵列中,以防止任何列全为 0 时出现除以零错误。如果这已经可以防止(例如,通过批次归一化),则建议使用 0 的 epsilon 以提高精度。

  2. batch_tensors_with_same_shape 控制是否以批处理操作压缩和解压缩形状相同的张量以实现更高的并行性。请注意,您还应该增加桶大小(即 DDP 构造函数中的 bucket_cap_mb 参数),以使更多形状相同的张量出现在同一个桶中,但这可能会减少计算和通信之间的重叠,并由于堆叠相同形状的张量而增加内存占用。如果压缩/解压缩计算是瓶颈,则将其设置为 True

警告

如果启用了错误反馈或预热,DDP 中允许的 start_powerSGD_iter 的最小值为 2。这是因为存在另一个内部优化,它在 DDP 的第 1 次迭代时重建桶,这可能与重建过程之前记忆的任何张量发生冲突。

PowerSGD 挂钩

警告

PowerSGD 通常需要与模型梯度相同大小的额外内存来启用错误反馈,这可以弥补有偏差的压缩通信并提高精度。

警告

PowerSGD 挂钩可能与 Apex 自动混合精度包 发生冲突。请改用 PyTorch 原生自动混合精度包

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[source]

实现 PowerSGD 算法。

此 DDP 通信挂钩实现了 论文 中描述的 PowerSGD 梯度压缩算法。一旦梯度张量在所有工作器上聚合,此挂钩将按如下方式应用压缩。

  1. 将输入扁平化的 1D 梯度张量视为每个参数张量的列表,并将所有张量分成两组。

    1.1 应该在 allreduce 之前压缩的张量,因为压缩可以节省足够的带宽。

    1.2 其余的张量将直接进行 allreduce,不会压缩,包括所有向量张量(用于偏差)。

  2. 处理未压缩的张量

    2.1. 为那些未压缩的张量分配连续内存,并将所有未压缩的张量作为一批进行 allreduce,不进行压缩;

    2.2. 将单个未压缩的张量从连续内存复制回输入张量。

  3. 处理应该由 PowerSGD 压缩压缩的张量

    3.1. 对于每个张量 M,创建两个低秩张量 P 和 Q 来分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;

    3.2. 计算每个 Ps 中的 P,它等于 MQ;

    3.3. 将 Ps 作为一批进行 allreduce;

    3.4. 对 Ps 中的每个 P 进行正交化;

    3.5. 计算每个 Qs 中的 Q,它近似等于 M^TP;

    3.6. 将 Qs 作为一批进行 allreduce;

    3.7. 计算所有压缩张量中的每个 M,它近似等于 PQ^T。

请注意,此通信挂钩对前 state.start_powerSGD_iter 次迭代强制执行普通 allreduce。这不仅使用户可以更好地控制加速和精度之间的权衡,而且还有助于抽象掉 DDP 内部优化的一些复杂性,以供未来的通信挂钩开发人员使用。

参数
  • state (PowerSGDState) – 用于配置压缩率并支持错误反馈、预热等的狀態信息。为了调整压缩配置,主要需要调整 matrix_approximation_rankstart_powerSGD_itermin_compression_rate

  • bucket (dist.GradBucket) – 用于存储批处理多个每个变量张量的 1D 扁平化梯度张量的桶。请注意,由于 DDP 通信挂钩只支持单进程单设备模式,因此此桶中只存储了一个张量。

返回值

通信的未来处理程序,它会就地更新梯度。

返回值类型

Future[Tensor]

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                          start_powerSGD_iter=10, min_compression_rate=0.5)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[source]

实现简化的 PowerSGD 算法。

此 DDP 通信挂钩实现了 论文 中描述的简化的 PowerSGD 梯度压缩算法。此变体不会逐层压缩梯度,而是压缩批处理所有梯度的扁平化输入张量。因此,它比 powerSGD_hook() **更快**,但通常会导致 **更低的精度**,除非 matrix_approximation_rank 为 1。

警告

这里增加 matrix_approximation_rank 并不一定能提高精度,因为在没有列/行对齐的情况下对每个参数张量进行批处理会破坏低秩结构。因此,用户应该始终首先考虑 powerSGD_hook(),只有在 matrix_approximation_rank 为 1 时可以达到令人满意的精度时,才考虑此变体。

一旦梯度张量在所有工作器上聚合,此挂钩将按如下方式应用压缩。

  1. 将输入扁平化的 1D 梯度张量视为具有 0 填充的方形张量 M;

  2. 创建两个低秩张量 P 和 Q 来分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;

  3. 计算 P,它等于 MQ;

  4. 对 P 进行 allreduce;

  5. 对 P 进行正交化;

  6. 计算 Q,它近似等于 M^TP;

  7. 对 Q 进行 allreduce;

  8. 计算 M,它近似等于 PQ^T。

  9. 将输入张量截断为原始长度。

请注意,此通信挂钩对前 state.start_powerSGD_iter 次迭代强制执行普通 allreduce。这不仅使用户可以更好地控制加速和精度之间的权衡,而且还有助于抽象掉 DDP 内部优化的一些复杂性,以供未来的通信挂钩开发人员使用。

参数
  • state (PowerSGDState) – 用于配置压缩率并支持错误反馈、预热等的狀態信息。为了调整压缩配置,主要需要调整 matrix_approximation_rankstart_powerSGD_iter

  • bucket (dist.GradBucket) – 用于存储批处理多个每个变量张量的 1D 扁平化梯度张量的桶。请注意,由于 DDP 通信挂钩只支持单进程单设备模式,因此此桶中只存储了一个张量。

返回值

通信的未来处理程序,它会就地更新梯度。

返回值类型

Future[Tensor]

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

调试通信挂钩

顾名思义,调试通信挂钩**仅**用于调试和性能优化目的。

警告

调试通信挂钩并不一定输出正确的结果。

torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[source]

返回一个包装输入的 future,因此它是一个不产生任何通信开销的 no-op。

此钩子用于所有减少优化的头部空间分析,而不是正常的梯度同步。例如,如果注册此钩子后,训练时间的加速率不到 10%,通常意味着所有减少不是这种情况下的性能瓶颈。如果难以检索 GPU 跟踪或跟踪分析很复杂(例如所有减少和计算之间的重叠或跨等级的非同步),这种检测方法可能特别有用。

示例:
>>> ddp_model.register_comm_hook(None, noop_hook)
返回值类型

Future[Tensor]

通信钩子的检查点

有状态的通信钩子可以作为模型检查点的一部分保存,以启用训练器重新启动。要使钩子可序列化,应定义 __setstate____getstate__

警告

__getstate__ 应从返回的字典中排除不可序列化的属性。

警告

__setstate__ 应正确初始化不可序列化的属性,这些属性已从提供的 state 中排除。

PowerSGDState 已实现 __setstate____getstate__,可作为参考。

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[source]
__getstate__()[source]

返回将被腌制并保存的 Dict[str, Any]

process_group 不可序列化,因此已从返回的状态中排除。

__setstate__(state)[source]

获取提供的 state 并将其设置为此 PowerSGDState 实例。

process_group 被设置为默认值。

以下是一个简单且端到端的示例,用于保存和重新加载 PowerSGD 状态和钩子。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(24,24)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(24,12)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run_demo(demo_fn, world_size):
    mp.spawn(
        demo_fn,
        args=(world_size,),
        nprocs=world_size,
        join=True)

def demo_serialization(rank, world_size):
    setup(rank, world_size)

    CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

    model = SimpleModel().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    powersgd_hook = powerSGD.powerSGD_hook
    powersgd_state = powerSGD.PowerSGDState(process_group=None)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    state = {
        'state_dict': ddp_model.state_dict(),
        'comm_hook': powersgd_hook,
        'comm_hook_state': powersgd_state}

    if rank == 0:
        torch.save(state, CHECKPOINT)

    dist.barrier()
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint = torch.load(CHECKPOINT, map_location=map_location)

    new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
    new_ddp_model.load_state_dict(checkpoint['state_dict'])
    powersgd_hook = checkpoint['comm_hook']
    powersgd_state = checkpoint['comm_hook_state']

    new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    if rank == 0:
        os.remove(CHECKPOINT)

    cleanup()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_serialization, world_size)

致谢

衷心感谢 PowerSGD 论文作者Thijs Vogels 对 PowerSGD 通信钩子的代码审查,以及 对比实验,证明了 PowerSGD 通信钩子的性能与原始 论文 中的实现相一致。

文档

访问 PyTorch 的全面的开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源