快捷方式

分布式优化器

警告

当前使用 CUDA 张量时不支持分布式优化器

torch.distributed.optim 提供了 DistributedOptimizer,它接受远程参数列表 (RRef) 并在参数所在的 workers 上本地运行优化器。分布式优化器可以使用任何本地优化器 基类 来在每个 worker 上应用梯度。

class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)[source][source]

DistributedOptimizer 接受分散在 workers 上的参数的远程引用,并为每个参数在本地应用给定的优化器。

此类使用 get_gradients() 来检索特定参数的梯度。

来自相同或不同客户端的 step() 的并发调用将在每个 worker 上序列化——因为每个 worker 的优化器一次只能处理一组梯度。但是,不能保证完整的正向-反向-优化器序列会为单个客户端一次性执行。这意味着应用的梯度可能与给定 worker 上执行的最新正向传播不对应。此外,worker 之间没有顺序保证。

DistributedOptimizer 默认启用 TorchScript 创建本地优化器,这样在多线程训练(例如分布式模型并行)的情况下,优化器更新不会被 Python 全局解释器锁 (GIL) 阻塞。此功能目前已为大多数优化器启用。你还可以按照 PyTorch 教程中的 实用代码示例 来为自己的自定义优化器启用 TorchScript 支持。

参数
  • optimizer_class (optim.Optimizer) – 要在每个 worker 上实例化的优化器类。

  • params_rref (list[RRef]) – 需要优化的本地或远程参数的 RRefs 列表。

  • args – 传递给每个 worker 上优化器构造函数的参数。

  • kwargs – 传递给每个 worker 上优化器构造函数的参数。

示例:
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>>   # Forward pass.
>>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>   loss = rref1.to_here() + rref2.to_here()
>>>
>>>   # Backward pass.
>>>   dist_autograd.backward(context_id, [loss.sum()])
>>>
>>>   # Optimizer.
>>>   dist_optim = DistributedOptimizer(
>>>      optim.SGD,
>>>      [rref1, rref2],
>>>      lr=0.05,
>>>   )
>>>   dist_optim.step(context_id)
step(context_id)[source][source]

执行单个优化步骤。

这将在包含待优化参数的每个 worker 上调用 torch.optim.Optimizer.step(),并阻塞直到所有 worker 返回。提供的 context_id 将用于检索包含应应用于参数的梯度的相应 context

参数

context_id – 我们应运行优化器步骤的 autograd 上下文 ID。

class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)[source][source]

封装任意 torch.optim.Optimizer 并运行 post-local SGD。此优化器在每一步都运行本地优化器。在热身阶段后,它会在应用本地优化器后定期平均参数。

参数
  • optim (Optimizer) – 本地优化器。

  • averager (ModelAverager) – 用于运行 post-localSGD 算法的模型平均器实例。

示例

>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>>   PostLocalSGDState,
>>>   post_localSGD_hook,
>>> )
>>>
>>> model = nn.parallel.DistributedDataParallel(
>>>    module, device_ids=[rank], output_device=rank
>>> )
>>>
>>> # Register a post-localSGD communication hook.
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
>>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>>     optim=local_optim,
>>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>>
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
>>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
>>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
>>> for step in range(0, 200):
>>>    opt.zero_grad()
>>>    loss = loss_fn(output, labels)
>>>    loss.backward()
>>>    opt.step()
load_state_dict(state_dict)[source][source]

这与 torch.optim.Optimizerload_state_dict() 相同,但也会将模型平均器的步数值恢复到提供的 state_dict 中保存的值。

如果 state_dict 中没有 "step" 条目,则会引发警告并将模型平均器的步数初始化为 0。

state_dict()[source][source]

这与 torch.optim.Optimizerstate_dict() 相同,但会添加一个额外的条目来将模型平均器的步数记录到检查点,以确保重新加载不会再次导致不必要的热身。

step()[source][source]

执行单个优化步骤(参数更新)。

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)[source][source]

封装任意 optim.Optimizer 并将其状态分片到组中的 ranks。

分享方式如 ZeRO 所述。

每个 rank 中的本地优化器实例仅负责更新大约 1 / world_size 参数,因此只需要保留 1 / world_size 的优化器状态。在本地更新参数后,每个 rank 会将其参数广播给所有其他 peers,以使所有模型副本保持相同状态。ZeroRedundancyOptimizer 可以与 torch.nn.parallel.DistributedDataParallel 结合使用,以减少每个 rank 的峰值内存消耗。

ZeroRedundancyOptimizer 使用排序贪心算法在每个 rank 上打包多个参数。每个参数属于单个 rank,不会在 ranks 之间分割。分区是任意的,可能与参数注册或使用顺序不匹配。

参数

params (Iterable) – 提供所有参数的 Iterable,可以是 torch.Tensordict,这些参数将被分片到 ranks。

关键字参数
  • optimizer_class (torch.nn.Optimizer) – 本地优化器的类。

  • process_group (ProcessGroup, optional) – torch.distributedProcessGroup(默认为由 torch.distributed.init_process_group() 初始化的 dist.group.WORLD)。

  • parameters_as_bucket_view (bool, optional) – 如果为 True,则参数会被打包到 bucket 中以加速通信,并且 param.data 字段指向不同偏移量的 bucket 视图;如果为 False,则每个单独的参数单独通信,并且每个 params.data 保持不变(默认为 False)。

  • overlap_with_ddp (bool, optional) – 如果为 True,则 step()DistributedDataParallel 的梯度同步重叠;这要求 (1) optimizer_class 参数是函数式优化器或具有函数式等效项,并且 (2) 注册一个由 ddp_zero_hook.py 中的某个函数构造的 DDP 通信钩子;参数会被打包到与 DistributedDataParallel 中匹配的 bucket 中,这意味着将忽略 parameters_as_bucket_view 参数。如果为 False,则 step() 在反向传播之后独立运行(按正常情况)。(默认为 False

  • **defaults – 任何尾随参数,这些参数将被转发给本地优化器。

示例

>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

警告

当前,ZeroRedundancyOptimizer 要求传递进来的所有参数具有相同的密集类型。

警告

如果你传递 overlap_with_ddp=True,请注意以下事项:考虑到当前实现 DistributedDataParallelZeroRedundancyOptimizer 重叠的方式,前两次或三次训练迭代不会在优化器步骤中执行参数更新,具体取决于 static_graph=False 还是 static_graph=True。这是因为它需要关于 DistributedDataParallel 使用的梯度 bucketing 策略的信息,该策略直到第二次正向传播(如果 static_graph=False)或直到第三次正向传播(如果 static_graph=True)才最终确定。为了对此进行调整,一种选择是预置 dummy inputs。

警告

ZeroRedundancyOptimizer 是实验性的,未来可能发生变化。

add_param_group(param_group)[source][source]

将参数组添加到 Optimizerparam_groups 中。

在对预训练网络进行微调时,这很有用,因为冻结的层可以被设置为可训练,并随着训练的进行添加到 Optimizer 中。

参数

param_group (dict) – 指定要优化的参数和组特定的优化选项。

警告

此方法处理更新所有分片上的 shards,但需要在所有 ranks 上调用。在 ranks 的子集上调用此方法将导致训练挂起,因为通信原语根据管理的参数被调用,并期望所有 ranks 参与相同的参数集。

consolidate_state_dict(to=0)[source][source]

在目标 rank 上整合 state_dict 列表(每个 rank 一个)。

参数

to (int) – 接收优化器状态的 rank(默认为 0)。

抛出

RuntimeError – 如果 overlap_with_ddp=True 并且此方法在 ZeroRedundancyOptimizer 实例完全初始化之前被调用(初始化在 DistributedDataParallel 梯度 bucket 重建后发生)。

警告

这需要在所有 ranks 上调用。

property join_device: device

返回默认设备。

join_hook(**kwargs)[source][source]

返回 ZeRO join 钩子。

它通过在优化器步骤中模仿集体通信来启用对非均匀输入的训练。

在调用此钩子之前,必须正确设置梯度。

参数

kwargs (dict) – 一个 dict,包含任何用于在运行时修改 join 钩子行为的关键字参数;所有共享相同 join 上下文管理器的 Joinable 实例都转发相同的 kwargs 值。

此钩子不支持任何关键字参数;即 kwargs 未使用。

property join_process_group: Any

返回进程组。

load_state_dict(state_dict)[source][source]

从输入的 state_dict 加载与给定 rank 相关的状态,并根据需要更新本地优化器。

参数

state_dict (dict) – 优化器状态;应为调用 state_dict() 返回的对象。

抛出

RuntimeError – 如果 overlap_with_ddp=True 并且此方法在 ZeroRedundancyOptimizer 实例完全初始化之前被调用(初始化在 DistributedDataParallel 梯度 bucket 重建后发生)。

state_dict()[source][source]

返回此 rank 已知的最后一个全局优化器状态。

抛出

RuntimeError – 如果 overlap_with_ddp=True 并且此方法在 ZeroRedundancyOptimizer 实例完全初始化之前被调用(初始化在 DistributedDataParallel 梯度 bucket 重建后发生);或者如果在调用此方法之前没有调用 consolidate_state_dict()

返回类型

dict[str, Any]

step(closure=None, **kwargs)[source][source]

执行单个优化器步骤并跨所有 rank 同步参数。

参数

closure (可调用对象) – 一个闭包,用于重新评估模型并返回损失;对于大多数优化器是可选的。

返回值

可选的损失值,取决于底层的本地优化器。

返回类型

Optional[float]

注意

任何额外参数将按原样传递给基础优化器。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并解答你的问题

查看资源