DDP 通信钩子¶
DDP 通信钩子是一个通用接口,用于通过覆盖 DistributedDataParallel 中的普通 allreduce 来控制如何在工作进程之间通信梯度。提供了一些内置的通信钩子,用户可以轻松地将这些钩子中的任何一个应用于优化通信。此外,钩子接口还可以支持用户定义的通信策略,以满足更高级的用例。
如何使用通信钩子?¶
要使用通信钩子,用户只需让 DDP 模型在训练循环之前注册钩子,如下所示。
torch.nn.parallel.DistributedDataParallel.register_comm_hook()
通信钩子对什么进行操作?¶
通信钩子提供了一种灵活的方式来对梯度进行 allreduce。因此,它主要在 allreduce 之前的每个副本上的梯度上进行操作,这些梯度被分桶以增加通信和计算之间的重叠。特别地,torch.distributed.GradBucket
代表要进行 allreduce 的一个梯度张量桶。
- class torch.distributed.GradBucket¶
此类主要将一个扁平化的梯度张量(由
buffer()
返回)传递给 DDP 通信钩子。此张量可以进一步分解为此桶中的每个参数张量列表(由get_per_parameter_tensors()
返回),以应用逐层操作。
- torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) int ¶
警告
由于桶在第一次迭代后会重建,因此不应该依赖于训练开始时的索引。
- 返回值
存储一些连续层的梯度的桶的索引。所有梯度都被分桶了。
- torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) torch.Tensor ¶
- 返回值
一个扁平化的 1D
torch.Tensor
缓冲区,可以进一步分解为此桶中的每个参数张量列表。
- torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor] ¶
- 返回值
一个
torch.Tensor
列表。列表中的每个张量对应一个梯度。
- torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) bool ¶
- 返回值
该桶是否是迭代中最后一个用于allreduce的桶。这也意味着该桶对应于前向传递中的前几层。
- torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) None ¶
用输入张量缓冲区替换桶中的张量。
- torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor] ¶
- 返回值
一个
torch.Tensor
列表。列表中的每个张量对应于模型参数。
默认通信钩子¶
默认通信钩子是简单的**无状态**钩子,因此 register_comm_hook
中的输入状态是进程组或 None
。输入 bucket
是一个 torch.distributed.GradBucket
对象。
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[source]¶
使用
GradBucket
张量调用allreduce
。一旦梯度张量在所有工作器上聚合,它的
then
回调将计算平均值并返回结果。如果用户注册了这个 DDP 通信钩子,DDP 的结果应该与没有注册钩子的情况相同。因此,这不会改变 DDP 的行为,用户可以将其用作参考或修改这个钩子来记录有用的信息或任何其他目的,同时不会影响 DDP 的行为。
- 示例:
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[source]¶
通过将
GradBucket
转换为torch.float16
并除以进程组大小来压缩。此 DDP 通信钩子实现了一种简单的梯度压缩方法,它将
GradBucket
张量转换为半精度浮点格式 (torch.float16
),然后除以进程组大小。它对这些float16
梯度张量进行 allreduce。一旦压缩后的梯度张量完成 allreduce,链式回调decompress
将其转换回输入数据类型(例如float32
)。- 示例:
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[source]¶
警告:此 API 处于实验阶段,需要 NCCL 版本高于 2.9.6。
此 DDP 通信钩子实现了一种简单的梯度压缩方法,它将
GradBucket
张量转换为半精度 脑浮点格式 (torch.bfloat16
),然后除以进程组大小。它对这些bfloat16
梯度张量进行 allreduce。一旦压缩后的梯度张量完成 allreduce,链式回调decompress
将其转换回输入数据类型(例如float32
)。- 示例:
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
此外,还提供了一个通信钩子包装器来支持 fp16_compress_hook()
或 bf16_compress_hook()
作为包装器,可以与其他通信钩子组合。
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[source]¶
将输入张量转换为
torch.float16
,将钩子的结果转换回输入数据类型。此包装器将给定 DDP 通信钩子的输入梯度张量转换为半精度浮点格式 (
torch.float16
),并将给定钩子的结果张量转换回输入数据类型,例如float32
。因此,fp16_compress_hook
等效于fp16_compress_wrapper(allreduce_hook)
。- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
- 返回类型
Callable[[Any, GradBucket], Future[Tensor]]
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[source]¶
警告:此 API 处于实验阶段,需要 NCCL 版本高于 2.9.6。
此包装器将给定 DDP 通信钩子的输入梯度张量转换为半精度 脑浮点格式 <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format> `_ (``torch.bfloat16`),并将给定钩子的结果张量转换回输入数据类型,例如
float32
。因此,
bf16_compress_hook
等效于bf16_compress_wrapper(allreduce_hook)
。- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
- 返回类型
Callable[[Any, GradBucket], Future[Tensor]]
PowerSGD 通信钩子¶
PowerSGD (Vogels 等人,NeurIPS 2019) 是一种梯度压缩算法,它可以提供非常高的压缩率并加速带宽受限的分布式训练。该算法需要维护一些超参数和内部状态。因此,PowerSGD 通信钩子是**有状态**的钩子,用户需要提供如下定义的状态对象。
PowerSGD 状态¶
- class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source]¶
在训练期间,存储算法的超参数和所有梯度的内部状态。
特别是,
matrix_approximation_rank
和start_powerSGD_iter
是用户应该调整的主要超参数。为了提高性能,建议将二进制超参数use_error_feedback
和warm_start
打开。matrix_approximation_rank
控制压缩低秩张量的尺寸,从而决定压缩率。秩越低,压缩越强。1.1. 如果
matrix_approximation_rank
太低,完整模型质量需要更多训练步骤才能达到或永远无法达到,导致精度损失。1.2.
matrix_approximation_rank
的增加会大幅增加压缩的计算成本,并且在超过一定matrix_approximation_rank
阈值后,精度可能不会进一步提高。
为了调整
matrix_approximation_rank
,建议从 1 开始,以 2 的倍数递增(类似于指数网格搜索,1、2、4、…),直到达到满意的精度。通常只使用 1-4 的小值。对于某些 NLP 任务(如原始论文的附录 D 所示),该值已增加到 32。start_powerSGD_iter
将 PowerSGD 压缩推迟到步骤start_powerSGD_iter
,并在步骤start_powerSGD_iter
之前运行普通 allreduce。这种 **普通 allreduce + PowerSGD** 的混合方案可以有效地提高精度,即使使用相对较小的matrix_approximation_rank
。这是因为,训练阶段的开始通常对不准确的梯度非常敏感,过早地压缩梯度可能会使训练很快走上一条次优的轨迹,从而对精度造成不可逆转的影响。
为了调整
start_powerSGD_iter
,建议从总训练步骤的 10% 开始,并逐渐增加,直到达到满意的精度。如果训练中存在预热阶段,start_powerSGD_iter
通常不应小于预热步骤的数量。min_compression_rate
是层压缩时所需的最小压缩率。由于压缩带来的计算开销,只有在带宽节省足够的情况下,才值得对张量进行压缩,其中(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols
。如果指定的压缩率阈值无法满足,则将直接对张量进行 allreduce,而不进行压缩。
压缩统计信息在 PowerSGD 压缩开始后每
compression_stats_logging_frequency
次迭代记录一次。orthogonalization_epsilon
可以是一个非常小的值(例如,1e-8),在正交化步骤中加到每个归一化矩阵列上,以防止任何列都包含全 0 时出现除以零错误。如果这种情况已经可以避免(例如,通过批次归一化),建议将 epsilon 设置为 0 以提高精度。batch_tensors_with_same_shape
控制是否在批次操作中压缩和解压缩形状相同的张量,以实现更高的并行性。请注意,您还应该增加桶的大小(即 DDP 构造函数中的bucket_cap_mb
参数),以使更多形状相同的张量出现在同一个桶中,但这可能会减少计算和通信之间的重叠,并由于堆叠形状相同的张量而增加内存占用。如果压缩/解压缩计算是瓶颈,则将其设置为True
。
警告
如果启用了错误反馈或预热,DDP 中允许的
start_powerSGD_iter
的最小值为 2。这是因为在 DDP 中,还有一个内部优化在迭代 1 处重建桶,这可能会与在重建过程之前记忆的任何张量发生冲突。
PowerSGD 钩子¶
警告
PowerSGD 通常需要与模型梯度大小相同的额外内存来启用错误反馈,这可以弥补压缩通信带来的偏差,并提高精度。
警告
PowerSGD 钩子可能会与 Apex 自动混合精度包 发生冲突。请改用 PyTorch 本机自动混合精度包。
- torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[source]¶
实现 PowerSGD 算法。
这个 DDP 通信钩子实现了 论文 中描述的 PowerSGD 梯度压缩算法。一旦梯度张量在所有 worker 上聚合,此钩子将应用压缩,如下所示
将输入扁平化的 1D 梯度张量视为每个参数张量列表,并将所有张量分成两组
1.1. 应该在 allreduce 之前进行压缩的张量,因为压缩可以节省足够的带宽。
1.2. 其余的张量将直接进行 allreduce,而不进行压缩,包括所有向量张量(用于偏差)。
处理未压缩的张量
2.1. 为那些未压缩的张量分配连续内存,并将所有未压缩的张量作为一个批次进行 allreduce,不进行压缩;
2.2. 将各个未压缩的张量从连续内存复制回输入张量。
处理应该通过 PowerSGD 压缩进行压缩的张量
3.1. 对于每个张量 M,创建两个低秩张量 P 和 Q 来分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;
3.2. 计算 Ps 中的每个 P,它等于 MQ;
3.3. 将 Ps 作为一个批次进行 allreduce;
3.4. 将 Ps 中的每个 P 正交化;
3.5. 计算 Qs 中的每个 Q,它近似等于 M^TP;
3.6. 将 Qs 作为一个批次进行 allreduce;
3.7. 计算所有压缩张量中的每个 M,它近似等于 PQ^T。
请注意,此通信钩子在第一个
state.start_powerSGD_iter
次迭代中强制执行普通 allreduce。这不仅使用户能够更好地控制速度提升和精度之间的权衡,而且还有助于抽象化 DDP 内部优化的某些复杂性,以便未来的通信钩子开发人员使用。- 参数
state (PowerSGDState) – 状态信息,用于配置压缩率并支持错误反馈、预热等。为了调整压缩配置,主要需要调整
matrix_approximation_rank
、start_powerSGD_iter
和min_compression_rate
。bucket (dist.GradBucket) – 桶,用于存储一个 1D 扁平化的梯度张量,该张量对多个每个变量张量进行批处理。请注意,由于 DDP 通信钩子只支持单进程单设备模式,因此此桶中只存储一个张量。
- 返回值
通信的未来处理程序,用于就地更新梯度。
- 返回类型
- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10, min_compression_rate=0.5) >>> ddp_model.register_comm_hook(state, powerSGD_hook)
- torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[source]¶
实现简化的 PowerSGD 算法。
这个 DDP 通信钩子实现了 论文 中描述的简化的 PowerSGD 梯度压缩算法。这种变体不逐层压缩梯度,而是压缩对所有梯度进行批处理的扁平化输入张量。因此,它比
powerSGD_hook()
**更快**,但通常会导致 **更低的精度**,除非matrix_approximation_rank
为 1。警告
在这里增加
matrix_approximation_rank
不一定会提高精度,因为对每个参数张量进行批处理而没有列/行对齐会破坏低秩结构。因此,用户应该始终首先考虑powerSGD_hook()
,只有在matrix_approximation_rank
为 1 时能够达到满意的精度的情况下,才考虑这种变体。一旦梯度张量在所有 worker 上聚合,此钩子将应用压缩,如下所示
将输入扁平化的 1D 梯度张量视为具有 0 填充的方形张量 M;
创建两个低秩张量 P 和 Q 来分解 M,使得 M = PQ^T,其中 Q 从标准正态分布初始化并正交化;
计算 P,它等于 MQ;
将 P 进行 allreduce;
将 P 正交化;
计算 Q,它近似等于 M^TP;
将 Q 进行 allreduce;
计算 M,它近似等于 PQ^T。
将输入张量截断为原始长度。
请注意,此通信钩子在第一个
state.start_powerSGD_iter
次迭代中强制执行普通 allreduce。这不仅使用户能够更好地控制速度提升和精度之间的权衡,而且还有助于抽象化 DDP 内部优化的某些复杂性,以便未来的通信钩子开发人员使用。- 参数
state (PowerSGDState) – 状态信息,用于配置压缩率并支持错误反馈、预热等。为了调整压缩配置,主要需要调整
matrix_approximation_rank
和start_powerSGD_iter
。bucket (dist.GradBucket) – 桶,用于存储一个 1D 扁平化的梯度张量,该张量对多个每个变量张量进行批处理。请注意,由于 DDP 通信钩子只支持单进程单设备模式,因此此桶中只存储一个张量。
- 返回值
通信的未来处理程序,用于就地更新梯度。
- 返回类型
- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
调试通信钩子¶
顾名思义,调试通信钩子 **仅** 用于调试和性能优化目的。
警告
调试通信钩子不一定输出正确的结果。
- torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[source]¶
返回一个包装输入的未来对象,使其成为一个不产生任何通信开销的空操作。
此钩子应该**仅**用于所有减少优化空间分析,而不是正常的梯度同步。例如,如果注册此钩子后,训练时间的加速不到 10%,通常意味着在这个情况下,所有减少不是性能瓶颈。这种检测特别有用,如果无法轻松获取 GPU 跟踪,或者跟踪分析因某些因素(如所有减少与计算的重叠或跨秩的非同步)而变得复杂。
- 示例:
>>> ddp_model.register_comm_hook(None, noop_hook)
通信钩子的检查点¶
有状态的通信钩子可以作为模型检查点的一部分保存,以启用训练器重启。为了使钩子可序列化,需要定义 __setstate__
和 __getstate__
。
警告
__getstate__
应该从返回的字典中排除不可序列化的属性。
警告
__setstate__
应该正确初始化不可序列化的属性,这些属性从提供的 state
中排除。
PowerSGDState
已经实现了 __setstate__
和 __getstate__
,可以作为参考。
- class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[source]
以下是一个保存和重新加载 PowerSGD 状态和钩子的简单端到端示例。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(24,24)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(24,12)
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def run_demo(demo_fn, world_size):
mp.spawn(
demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
def demo_serialization(rank, world_size):
setup(rank, world_size)
CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"
model = SimpleModel().to(rank)
ddp_model = DistributedDataParallel(model, device_ids=[rank])
powersgd_hook = powerSGD.powerSGD_hook
powersgd_state = powerSGD.PowerSGDState(process_group=None)
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
ddp_model.register_comm_hook(powersgd_state, powersgd_hook)
state = {
'state_dict': ddp_model.state_dict(),
'comm_hook': powersgd_hook,
'comm_hook_state': powersgd_state}
if rank == 0:
torch.save(state, CHECKPOINT)
dist.barrier()
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
checkpoint = torch.load(CHECKPOINT, map_location=map_location)
new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
new_ddp_model.load_state_dict(checkpoint['state_dict'])
powersgd_hook = checkpoint['comm_hook']
powersgd_state = checkpoint['comm_hook_state']
new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)
if rank == 0:
os.remove(CHECKPOINT)
cleanup()
if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
run_demo(demo_serialization, world_size)