快捷方式

分布式 Autograd 设计

本笔记将介绍分布式 autograd 的详细设计,并逐步介绍其内部原理。请确保您已熟悉 Autograd 机制分布式 RPC 框架,然后再继续阅读。

背景

假设您有两个节点,并且有一个非常简单的模型分布在两个节点上。这可以使用 torch.distributed.rpc 按如下方式实现

import torch
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)

# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)

# Compute some loss.
loss = t5.sum()

分布式 autograd 背后的主要动机是能够对我们计算出的 loss 的这种分布式模型运行反向传播,并为所有需要梯度的张量记录适当的梯度。

前向传播期间的 Autograd 记录

PyTorch 在前向传播期间构建 autograd 图,该图用于执行反向传播。有关更多详细信息,请参阅 Autograd 如何编码历史记录

对于分布式 autograd,我们需要跟踪前向传播期间的所有 RPC,以确保反向传播得到适当的执行。为此,当执行 RPC 时,我们将 sendrecv 函数附加到 autograd 图。

  • send 函数附加到 RPC 的源,其输出边指向 RPC 输入张量的 autograd 函数。此函数在反向传播期间的输入是从目标接收的,作为相应 recv 函数的输出。

  • recv 函数附加到 RPC 的目标,其输入是从目标上使用输入张量执行的算子中检索的。此函数的输出梯度在反向传播期间发送到源节点,发送到相应的 send 函数。

  • 每个 send-recv 对都分配有一个全局唯一的 autograd_message_id,以唯一标识该对。这在反向传播期间查找远程节点上的相应函数时非常有用。

  • 对于 RRef,每当我们调用 torch.distributed.rpc.RRef.to_here() 时,我们都会为涉及的张量附加相应的 send-recv 对。

例如,以下是上述示例的 autograd 图的样子(为简单起见,排除了 t5.sum())

../_images/send_recv_functions.png

分布式 Autograd 上下文

每个使用分布式 autograd 的前向和反向传播都分配有一个唯一的 torch.distributed.autograd.context,此上下文具有全局唯一的 autograd_context_id。此上下文在每个节点上根据需要创建。

此上下文具有以下目的

  1. 运行分布式反向传播的多个节点可能会在同一张量上累积梯度,因此,在我们有机会运行优化器之前,张量的 .grad 字段将具有来自各种分布式反向传播的梯度。这类似于在本地多次调用 torch.autograd.backward()。为了提供一种分离每个反向传播梯度的途径,梯度会在每个反向传播的 torch.distributed.autograd.context 中累积。

  2. 在前向传播期间,我们将每个 autograd 传播的 sendrecv 函数存储在此上下文中。这确保我们持有对 autograd 图中适当节点的引用,以保持其存活。除此之外,在反向传播期间也很容易查找适当的 sendrecv 函数。

  3. 通常,我们还使用此上下文来存储每个分布式 autograd 传播的一些元数据。


从用户的角度来看,autograd 上下文的设置方式如下

import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
  loss = model.forward()
  dist_autograd.backward(context_id, loss)

重要的是要注意,模型的正向传播必须在分布式 autograd 上下文管理器中调用,因为需要有效的上下文以确保所有 sendrecv 函数都得到正确存储,以便在所有参与节点上运行反向传播。

分布式反向传播

在本节中,我们将概述在分布式反向传播期间准确计算依赖关系的挑战,并描述几种关于如何执行分布式反向传播的算法(具有权衡)。

计算依赖关系

考虑以下代码片段在单台机器上运行

import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()

以下是上述代码的 autograd 图的样子

../_images/local_dependencies.png

autograd 引擎作为反向传播一部分执行的第一步是计算 autograd 图中每个节点的依赖关系数量。这有助于 autograd 引擎了解图中何时可以执行节点。add(1)mul(0) 括号中的数字表示依赖关系的数量。如您所见,这意味着在反向传播期间,add 节点需要 1 个输入,而 mul 节点不需要任何输入(换句话说,不需要执行)。本地 autograd 引擎通过从根节点(在本例中为 d)遍历图来计算这些依赖关系。

autograd 图中的某些节点可能不会在反向传播中执行,这一事实对分布式 autograd 提出了挑战。考虑下面使用 RPC 的代码片段。

import torch
import torch.distributed.rpc as rpc

a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)

d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()

上述代码的关联 autograd 图将是

../_images/distributed_dependencies.png

计算此分布式 autograd 图的依赖关系更具挑战性,并且需要一些开销(计算或网络通信方面)。

对于性能敏感型应用,我们可以通过假设每个 sendrecv 函数在反向传播中都有效(大多数应用不会执行未使用的 RPC)来避免大量开销。这简化了分布式 autograd 算法,并且效率更高,但代价是应用需要了解其局限性。此算法称为 FAST 模式算法,并在下面详细介绍。

在一般情况下,并非每个 sendrecv 函数都必须在反向传播中有效。为了解决这个问题,我们提出了一种 SMART 模式算法,该算法将在后面的章节中介绍。请注意,目前仅实现了 FAST 模式算法。

FAST 模式算法

此算法的关键假设是,当我们运行反向传播时,每个 send 函数的依赖关系为 1。换句话说,我们假设我们将通过 RPC 从另一个节点接收梯度。

该算法如下:

  1. 我们从具有反向传播根(所有根都必须是本地的)的工作节点开始。

  2. 查找当前 分布式 Autograd 上下文的所有 send 函数。

  3. 从提供的根和我们检索到的所有 send 函数开始,在本地计算依赖关系。

  4. 计算依赖关系后,使用提供的根启动本地 autograd 引擎。

  5. 当 autograd 引擎执行 recv 函数时,recv 函数通过 RPC 将输入梯度发送到相应的工作节点。每个 recv 函数都知道目标工作节点 ID,因为它是在前向传播期间记录的。recv 函数还会将 autograd_context_idautograd_message_id 发送到远程主机。

  6. 当远程主机收到此请求时,我们使用 autograd_context_idautograd_message_id 来查找相应的 send 函数。

  7. 如果这是工作节点第一次收到给定 autograd_context_id 的请求,它将按照上面第 1-3 点所述在本地计算依赖关系。

  8. 然后在该工作节点的本地 autograd 引擎上将步骤 6 中检索到的 send 函数排队等待执行。

  9. 最后,我们不是在张量的 .grad 字段上累积梯度,而是为每个 分布式 Autograd 上下文单独累积梯度。梯度存储在 Dict[Tensor, Tensor] 中,这基本上是从张量到其关联梯度的映射,可以使用 get_gradients() API 检索此映射。


例如,带有分布式 autograd 的完整代码如下所示

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:

# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
  t1 = torch.rand((3, 3), requires_grad=True)
  t2 = torch.rand((3, 3), requires_grad=True)

  # Perform some computation remotely.
  t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

  # Perform some computation locally based on remote result.
  t4 = torch.rand((3, 3), requires_grad=True)
  t5 = torch.mul(t3, t4)

  # Compute some loss.
  loss = t5.sum()

  # Run the backward pass.
  dist_autograd.backward(context_id, [loss])

  # Retrieve the gradients from the context.
  dist_autograd.get_gradients(context_id)

带有依赖关系的分布式 autograd 图如下所示(为简单起见,排除了 t5.sum())

../_images/distributed_dependencies_computed.png

应用于上述示例的 FAST 模式算法 如下所示

  1. Worker 0 上,我们从根 losssend1 开始计算依赖关系。因此,send1 被标记为依赖关系为 1,并且 Worker 0 上的 mul 被标记为依赖关系为 1。

  2. 现在,我们在 Worker 0 上启动本地 autograd 引擎。我们首先执行 mul 函数,将其输出累积在 autograd 上下文中作为 t4 的梯度。然后,我们执行 recv2,它将梯度发送到 Worker 1

  3. 由于这是 Worker 1 第一次听到关于此反向传播的信息,因此它开始依赖关系计算,并适当地标记 send2addrecv1 的依赖关系。

  4. 接下来,我们在 Worker 1 的本地 autograd 引擎上将 send2 排队,这反过来又会执行 addrecv1

  5. 当执行 recv1 时,它会将梯度发送到 Worker 0

  6. 由于 Worker 0 已经计算了此反向传播的依赖关系,因此它只是在本地排队并执行 send1

  7. 最后,t1t2t4 的梯度在 分布式 Autograd 上下文中累积。

SMART 模式算法

此算法的完整细节仍在制定中,但对于一般概念,您可以参考 RFC 中的分布式 Autograd 算法 Smart 模式部分。

分布式优化器

DistributedOptimizer 的工作方式如下

  1. 接受要优化的远程参数列表(RRef)。这些也可能是包装在本地 RRef 中的本地参数。

  2. Optimizer 类作为要在所有不同的 RRef 所有者上运行的本地优化器。

  3. 分布式优化器在每个工作节点上创建本地 Optimizer 的实例,并持有对它们的 RRef

  4. 当调用 torch.distributed.optim.DistributedOptimizer.step() 时,分布式优化器使用 RPC 远程执行适当远程工作节点上的所有本地优化器。分布式 autograd context_id 必须作为 torch.distributed.optim.DistributedOptimizer.step() 的输入提供。本地优化器使用它来应用存储在相应上下文中的梯度。

  5. 如果多个并发分布式优化器正在更新工作节点上的相同参数,则这些更新将通过锁进行序列化。

简单端到端示例

将所有内容放在一起,以下是使用分布式 autograd 和分布式优化器的简单端到端示例。如果代码放在名为“dist_autograd_simple.py”的文件中,则可以使用命令 MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py 运行

import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

def random_tensor():
    return torch.rand((3, 3), requires_grad=True)

def _run_process(rank, dst_rank, world_size):
    name = "worker{}".format(rank)
    dst_name = "worker{}".format(dst_rank)

    # Initialize RPC.
    rpc.init_rpc(
        name=name,
        rank=rank,
        world_size=world_size
    )

    # Use a distributed autograd context.
    with dist_autograd.context() as context_id:
        # Forward pass (create references on remote nodes).
        rref1 = rpc.remote(dst_name, random_tensor)
        rref2 = rpc.remote(dst_name, random_tensor)
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass (run distributed autograd).
        dist_autograd.backward(context_id, [loss.sum()])

        # Build DistributedOptimizer.
        dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
        )

        # Run the distributed optimizer step.
        dist_optim.step(context_id)

def run_process(rank, world_size):
    dst_rank = (rank + 1) % world_size
    _run_process(rank, dst_rank, world_size)
    rpc.shutdown()

if __name__ == '__main__':
  # Run world_size workers
  world_size = 2
  mp.spawn(run_process, args=(world_size,), nprocs=world_size)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源