• 文档 >
  • DDP Communication Hooks
快捷方式

DDP 通信 Hook

DDP 通信 hook 是一个通用接口,用于通过重写 DistributedDataParallel 中的 vanilla allreduce 来控制跨 worker 通信梯度的方式。提供了一些内置的通信 hook,用户可以轻松应用其中任何一个来优化通信。此外,该 hook 接口还支持用户自定义通信策略,以用于更高级的用例。

如何使用通信 Hook?

要使用通信 hook,用户只需让 DDP 模型在训练循环开始前注册该 hook,如下所示。

torch.nn.parallel.DistributedDataParallel.register_comm_hook()

通信 Hook 对什么进行操作?

通信 hook 提供了一种灵活的方式来 allreduce 梯度。因此,它主要在 allreduce 之前对每个副本上的梯度进行操作,这些梯度被分桶(bucketized)以增加通信和计算之间的重叠。特别地,torch.distributed.GradBucket 代表一个包含待 allreduce 梯度张量的桶。

class torch.distributed.GradBucket

此类主要将展平的梯度张量(由 buffer() 返回)传递给 DDP 通信 hook。此张量可以进一步分解为此桶内每个参数的张量列表(由 get_per_parameter_tensors() 返回),以应用层级操作。

torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) int

警告

由于桶在第一次迭代后会重建,因此不应依赖训练开始时的索引。

返回值

存储几个连续层梯度的桶的索引。所有梯度均已分桶。

torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) torch.Tensor
返回值

一个展平的 1D torch.Tensor 缓冲区,可以进一步分解为此桶内每个参数的张量列表。

torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
返回值

一个 torch.Tensor 列表。列表中的每个张量对应一个梯度。

torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) bool
返回值

此桶是否是迭代中最后一个进行 allreduce 的桶。这也意味着此桶对应于前向传播中的前几层。

torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) None

用输入的张量缓冲区替换桶中的张量。

torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]
返回值

一个 torch.Tensor 列表。列表中的每个张量对应一个模型参数。

默认通信 Hook

默认通信 hook 是简单的 无状态 hook,因此 register_comm_hook 中的输入 state 要么是一个 process group,要么是 None。输入 bucket 是一个 torch.distributed.GradBucket 对象。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[source][source]

使用 GradBucket 张量调用 allreduce

一旦梯度张量在所有 worker 上聚合,其 then 回调将计算均值并返回结果。

如果用户注册此 DDP 通信 hook,则 DDP 结果预计与未注册 hook 的情况相同。因此,这不会改变 DDP 的行为,用户可以将其用作参考或修改此 hook 来记录有用信息或用于任何其他目的,同时不影响 DDP 行为。

示例:
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
返回值类型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[source][source]

通过将 GradBucket 强制转换为 torch.float16 并除以 process group 大小来压缩。

此 DDP 通信 hook 实现了一种简单的梯度压缩方法,将 GradBucket 张量强制转换为半精度浮点格式(torch.float16),然后除以 process group 大小。它 allreduce 这些 float16 梯度张量。一旦压缩的梯度张量被 allreduce,链式回调 decompress 将其强制转换回输入数据类型(例如 float32)。

示例:
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
返回值类型

Future[Tensor]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[source][source]

警告:此 API 处于实验阶段,需要 NCCL 版本高于 2.9.6。

此 DDP 通信 hook 实现了一种简单的梯度压缩方法,将 GradBucket 张量强制转换为半精度 Brain 浮点格式torch.bfloat16),然后除以 process group 大小。它 allreduce 这些 bfloat16 梯度张量。一旦压缩的梯度张量被 allreduce,链式回调 decompress 将其强制转换回输入数据类型(例如 float32)。

示例:
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
返回值类型

Future[Tensor]

此外,还提供了一个通信 hook 包装器,以支持将 fp16_compress_hook()bf16_compress_hook() 用作包装器,可以与其他通信 hook 结合使用。

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[source][source]

将输入张量强制转换为 torch.float16,将 hook 的结果强制转换回输入数据类型。

此包装器将给定 DDP 通信 hook 的输入梯度张量强制转换为半精度浮点格式(torch.float16),并将给定 hook 的结果张量强制转换回输入数据类型,例如 float32。因此,fp16_compress_hook 等效于 fp16_compress_wrapper(allreduce_hook)

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
返回值类型

Callable[[Any, GradBucket], Future[Tensor]]

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[source][source]

警告:此 API 处于实验阶段,需要 NCCL 版本高于 2.9.6。

此包装器将给定 DDP 通信 hook 的输入梯度张量强制转换为半精度 `Brain 浮点格式 `_(torch.bfloat16),并将给定 hook 的结果张量强制转换回输入数据类型,例如 float32

因此,bf16_compress_hook 等效于 bf16_compress_wrapper(allreduce_hook)

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
返回值类型

Callable[[Any, GradBucket], Future[Tensor]]

PowerSGD 通信 Hook

PowerSGD (Vogels 等,NeurIPS 2019) 是一种梯度压缩算法,可以提供非常高的压缩率并加速带宽受限的分布式训练。该算法需要同时维护一些超参数和内部状态。因此,PowerSGD 通信 hook 是一个 有状态 的 hook,用户需要提供如下定义的状态对象。

PowerSGD 状态

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source][source]

在训练期间存储算法的超参数和所有梯度的内部状态。

特别地,matrix_approximation_rankstart_powerSGD_iter 是用户应调整的主要超参数。为了性能,我们建议保持二进制超参数 use_error_feedbackwarm_start 为 True。

  1. matrix_approximation_rank 控制压缩低秩张量的大小,这决定了压缩率。秩越低,压缩越强。

    1.1. 如果 matrix_approximation_rank 过低,模型质量将需要更多训练步骤才能达到,或者永远无法达到,并导致精度损失。

    1.2. 增加 matrix_approximation_rank 会显著增加压缩的计算成本,并且精度可能不会超过某个 matrix_approximation_rank 阈值进一步提高。

为了调优 matrix_approximation_rank,我们建议从 1 开始,并以 2 的倍数增加(例如指数网格搜索,1、2、4…),直到达到令人满意的精度。通常只使用一个较小的值,如 1-4。对于一些 NLP 任务(如原始论文附录 D 所示),该值已增加到 32。

  1. start_powerSGD_iter 会将 PowerSGD 压缩推迟到步骤 start_powerSGD_iter 之后,而 vanilla allreduce 会在步骤 start_powerSGD_iter 之前运行。这种 vanilla allreduce + PowerSGD 的混合方案可以有效提高精度,即使使用相对较小的 matrix_approximation_rank。这是因为训练阶段的开始通常对不准确的梯度非常敏感,过早压缩梯度可能会使训练迅速走向次优轨迹,从而对精度产生不可逆转的影响。

为了调优 start_powerSGD_iter,我们建议从总训练步骤的 10% 开始,并增加它直到达到令人满意的精度。如果在训练中有热身(warm-up)阶段,start_powerSGD_iter 通常应不小于热身步骤的数量。

  1. min_compression_rate 是压缩层时所需的最小压缩率。由于压缩带来的计算开销,只有当带宽能节省足够多时才值得压缩张量,即 (num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols。如果指定的压缩率阈值无法满足,张量将直接进行 allreduce 而不进行压缩。

PowerSGD 压缩开始后,每 compression_stats_logging_frequency 次迭代记录一次压缩统计信息。

  1. orthogonalization_epsilon 是在正交化步骤中添加到每个归一化矩阵列的非常小的值(例如 1e-8),以防止任一列全为 0 时出现除以零错误。如果这已可避免(例如通过批量归一化),建议使用 0 的 epsilon 以获得精度。

  2. batch_tensors_with_same_shape 控制是否在批量操作中压缩和解压缩形状相同的张量,以实现更高的并行性。请注意,您还应该增加桶的大小(即 DDP 构造函数中的 bucket_cap_mb 参数),以使更多形状相同的张量出现在同一个桶中,但这可能会减少计算和通信之间的重叠,并因堆叠相同形状的张量而增加内存占用。如果压缩/解压缩计算是瓶颈,请设置为 True

警告

如果启用错误反馈或热身,DDP 中允许的 start_powerSGD_iter 的最小值为 2。这是因为 DDP 中存在另一个内部优化,它在迭代 1 时重建桶,这可能会与重建过程之前记忆的任何张量冲突。

PowerSGD Hook

警告

PowerSGD 通常需要与模型梯度大小相同的额外内存来启用错误反馈,这可以补偿有偏的压缩通信并提高精度。

警告

PowerSGD 钩子可能与 Apex 自动混合精度包 冲突。请改用 PyTorch 原生自动混合精度包

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[source][source]

实现 PowerSGD 算法。

这个 DDP 通信钩子实现了 论文中描述的 PowerSGD 梯度压缩算法。一旦梯度张量在所有工作进程中聚合完毕,这个钩子会按如下方式应用压缩:

  1. 将输入的展平一维梯度张量视为一个包含每个参数张量的列表,并将所有张量分为两组:

    1.1 在 allreduce 之前应该被压缩的张量,因为压缩可以显著节省带宽。

    1.2 剩余的张量将不经压缩直接进行 allreduce,包括所有向量张量(用于偏置)。

  2. 处理未压缩的张量:

    2.1. 为这些未压缩张量分配连续内存,并将所有未压缩张量作为一个批次进行 allreduce,不进行压缩;

    2.2. 将单个未压缩张量从未压缩张量的连续内存复制回输入张量。

  3. 处理应该通过 PowerSGD 压缩的张量:

    3.1. 对于每个张量 M,创建两个低秩张量 P 和 Q 来分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;

    3.2. 计算 Ps 中的每个 P,其等于 MQ;

    3.3. 将 Ps 作为一个批次进行 allreduce;

    3.4. 对 Ps 中的每个 P 进行正交化;

    3.5. 计算 Qs 中的每个 Q,其约等于 M^TP;

    3.6. 将 Qs 作为一个批次进行 allreduce;

    3.7. 计算所有被压缩张量中的每个 M,其约等于 PQ^T。

请注意,这个通信钩子在前 state.start_powerSGD_iter 次迭代中强制执行原版 allreduce。这不仅让用户可以更好地控制速度提升和精度之间的权衡,还有助于未来的通信钩子开发者抽象 DDP 内部优化的一些复杂性。

参数
  • state (PowerSGDState) – 用于配置压缩率并支持误差反馈、热启动等的状态信息。要调整压缩配置,主要需要调整 matrix_approximation_rankstart_powerSGD_itermin_compression_rate

  • bucket (dist.GradBucket) – 存储展平的一维梯度张量的桶,该张量批处理了多个按变量划分的张量。请注意,由于 DDP 通信钩子仅支持单进程单设备模式,因此此桶中仅存储一个张量。

返回值

通信的 Future handler,它会就地更新梯度。

返回值类型

Future[Tensor]

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                          start_powerSGD_iter=10, min_compression_rate=0.5)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[source][source]

实现简化的 PowerSGD 算法。

这个 DDP 通信钩子实现了 论文中描述的简化的 PowerSGD 梯度压缩算法。这个变体不按层压缩梯度,而是压缩批处理所有梯度的展平输入张量。因此,它比 powerSGD_hook() **更快**,但通常会导致**精度大大降低**,除非 matrix_approximation_rank 为 1。

警告

在这里增加 matrix_approximation_rank 可能不一定能提高精度,因为在没有列/行对齐的情况下批处理每个参数张量可能会破坏低秩结构。因此,用户应始终首先考虑 powerSGD_hook(),仅当 matrix_approximation_rank 为 1 时可以达到令人满意的精度时,才考虑使用此变体。

一旦梯度张量在所有工作进程中聚合完毕,这个钩子会按如下方式应用压缩:

  1. 将输入的展平一维梯度张量视为一个带 0 填充的方形张量 M;

  2. 创建两个低秩张量 P 和 Q 来分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;

  3. 计算 P,其等于 MQ;

  4. 对 P 进行 allreduce;

  5. 对 P 进行正交化;

  6. 计算 Q,其约等于 M^TP;

  7. 对 Q 进行 allreduce;

  8. 计算 M,其约等于 PQ^T。

  9. 将输入张量截断到原始长度。

请注意,这个通信钩子在前 state.start_powerSGD_iter 次迭代中强制执行原版 allreduce。这不仅让用户可以更好地控制速度提升和精度之间的权衡,还有助于未来的通信钩子开发者抽象 DDP 内部优化的一些复杂性。

参数
  • state (PowerSGDState) – 用于配置压缩率并支持误差反馈、热启动等的状态信息。要调整压缩配置,主要需要调整 matrix_approximation_rankstart_powerSGD_iter

  • bucket (dist.GradBucket) – 存储展平的一维梯度张量的桶,该张量批处理了多个按变量划分的张量。请注意,由于 DDP 通信钩子仅支持单进程单设备模式,因此此桶中仅存储一个张量。

返回值

通信的 Future handler,它会就地更新梯度。

返回值类型

Future[Tensor]

示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)

调试通信钩子

顾名思义,调试通信钩子**仅**用于调试和性能优化目的。

警告

调试通信钩子不一定会输出正确的结果。

torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[source][source]

返回一个包装输入的 Future,因此它是一个不产生任何通信开销的空操作(no-op)。

这个钩子**仅**应用于 allreduce 优化上限的分析,而不是正常的梯度同步。例如,如果在注册此钩子后训练时间仅提速不到 10%,通常意味着在这种情况下 allreduce 不是性能瓶颈。如果 GPU 轨迹难以获取,或者轨迹分析因 allreduce 与计算的重叠或跨 rank 的不同步等因素而变得复杂,这种检测会特别有用。

示例:
>>> ddp_model.register_comm_hook(None, noop_hook)
返回值类型

Future[Tensor]

通信钩子的检查点保存

有状态的通信钩子可以作为模型检查点的一部分保存,以支持训练器重启。要使钩子可序列化,应该定义 __setstate____getstate__

警告

__getstate__ 应该从返回的字典中排除不可序列化的属性。

警告

__setstate__ 应该正确初始化从提供的 state 中排除的不可序列化属性。

PowerSGDState 实现了 __setstate____getstate__,可以作为参考。

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[source][source]
__getstate__()[source][source]

返回一个 Dict[str, Any],它将被 pickle 化并保存。

process_group 不可序列化,因此从返回的状态中排除。

__setstate__(state)[source][source]

接受一个提供的 state 并设置到此 PowerSGDState 实例。

process_group 被设置为默认值。

这里是一个简单、端到端保存和重新加载 PowerSGD state 和 hook 的示例。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(24,24)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(24,12)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run_demo(demo_fn, world_size):
    mp.spawn(
        demo_fn,
        args=(world_size,),
        nprocs=world_size,
        join=True)

def demo_serialization(rank, world_size):
    setup(rank, world_size)

    CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

    model = SimpleModel().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    powersgd_hook = powerSGD.powerSGD_hook
    powersgd_state = powerSGD.PowerSGDState(process_group=None)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    state = {
        'state_dict': ddp_model.state_dict(),
        'comm_hook': powersgd_hook,
        'comm_hook_state': powersgd_state}

    if rank == 0:
        torch.save(state, CHECKPOINT)

    dist.barrier()
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint = torch.load(CHECKPOINT, map_location=map_location)

    new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
    new_ddp_model.load_state_dict(checkpoint['state_dict'])
    powersgd_hook = checkpoint['comm_hook']
    powersgd_state = checkpoint['comm_hook_state']

    new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    if rank == 0:
        os.remove(CHECKPOINT)

    cleanup()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_serialization, world_size)

致谢

非常感谢 PowerSGD 论文作者 **Thijs Vogels** 对 PowerSGD 通信钩子代码进行了审查,以及提供了对比实验,这些实验表明 PowerSGD 通信钩子的性能与原始论文中的实现相当。

文档

访问 PyTorch 全面的开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源