快捷方式

带零冗余优化器的分片优化器状态

在本食谱中,您将学习

要求

什么是 ZeroRedundancyOptimizer

ZeroRedundancyOptimizer》的概念源于《DeepSpeed/ZeRO 项目》和《Marian》,它们将优化器状态跨分布式数据并行进程进行分片,以减少每个进程的内存占用。在《分布式数据并行入门》教程中,我们展示了如何使用《DistributedDataParallel》(DDP)来训练模型。在该教程中,每个进程都维护一个专门的优化器副本。由于 DDP 已经在反向传播过程中同步了梯度,因此所有优化器副本将在每次迭代中对相同的参数和梯度值进行操作,这就是 DDP 如何保持模型副本处于相同状态的方式。通常,优化器还会维护局部状态。例如,《Adam》优化器使用每个参数的《exp_avg》和《exp_avg_sq》状态。因此,《Adam》优化器的内存消耗至少是模型大小的两倍。根据此观察结果,我们可以通过跨 DDP 进程分片优化器状态来减少优化器内存占用。更具体地说,不是为所有参数创建每个参数状态,而是不同 DDP 进程中的每个优化器实例只为所有模型参数的一部分保留优化器状态。优化器《step()》函数只更新其分片中的参数,然后将其更新后的参数广播到所有其他对等 DDP 进程,这样所有模型副本仍然处于相同状态。

如何使用《ZeroRedundancyOptimizer》?

以下代码演示了如何使用《ZeroRedundancyOptimizer》。大部分代码与《分布式数据并行笔记》中介绍的简单 DDP 示例类似。主要区别在于《example》函数中的《if-else》子句,它包装了优化器构造,在《ZeroRedundancyOptimizer》和《Adam》优化器之间切换。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")



def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

输出结果如下所示。当启用《ZeroRedundancyOptimizer》和《Adam》时,优化器《step()》的峰值内存消耗是普通《Adam》内存消耗的一半。这符合我们的预期,因为我们正在跨两个进程对《Adam》优化器状态进行分片。输出结果还表明,使用《ZeroRedundancyOptimizer》,模型参数在一次迭代后仍然最终具有相同的值(使用和不使用《ZeroRedundancyOptimizer》的参数总和相同)。

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源