快捷方式

使用 ZeroRedundancyOptimizer 分片优化器状态

创建于:2021 年 2 月 26 日 | 上次更新:2021 年 10 月 20 日 | 上次验证:未验证

在本食谱中,您将学习

要求

ZeroRedundancyOptimizer 是什么?

ZeroRedundancyOptimizer 的想法来自 DeepSpeed/ZeRO 项目Marian,它们跨分布式数据并行进程分片优化器状态,以减少每个进程的内存占用。在分布式数据并行入门教程中,我们展示了如何使用 DistributedDataParallel (DDP) 来训练模型。在该教程中,每个进程都保留优化器的专用副本。由于 DDP 已经在反向传播中同步了梯度,因此所有优化器副本将在每次迭代中对相同的参数和梯度值进行操作,这就是 DDP 如何保持模型副本处于相同状态的原因。通常,优化器还会维护本地状态。例如,Adam 优化器使用每参数 exp_avgexp_avg_sq 状态。因此,Adam 优化器的内存消耗至少是模型大小的两倍。鉴于这一观察结果,我们可以通过跨 DDP 进程分片优化器状态来减少优化器内存占用。更具体地说,每个 DDP 进程中的每个优化器实例不是为所有参数创建每参数状态,而仅保留所有模型参数分片的优化器状态。优化器 step() 函数仅更新其分片中的参数,然后将其更新后的参数广播到所有其他对等 DDP 进程,以便所有模型副本仍然处于相同的状态。

如何使用 ZeroRedundancyOptimizer

下面的代码演示了如何使用 ZeroRedundancyOptimizer。大部分代码与分布式数据并行注释中介绍的简单 DDP 示例类似。主要区别在于 example 函数中的 if-else 子句,它包装了优化器构造,在 ZeroRedundancyOptimizerAdam 优化器之间切换。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")



def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

输出如下所示。当使用 Adam 启用 ZeroRedundancyOptimizer 时,优化器 step() 峰值内存消耗是普通 Adam 内存消耗的一半。这与我们的预期一致,因为我们正在跨两个进程分片 Adam 优化器状态。输出还显示,使用 ZeroRedundancyOptimizer 后,模型参数在一次迭代后仍然以相同的值结束(参数总和在使用和不使用 ZeroRedundancyOptimizer 的情况下是相同的)。

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源