快捷方式

分布式 Autograd 设计

本文档将详细介绍分布式 autograd 的设计,并深入探讨其内部工作原理。在继续之前,请确保您熟悉 Autograd 机制分布式 RPC 框架

背景

假设您有两个节点,并且有一个跨这两个节点划分的非常简单的模型。这可以使用 torch.distributed.rpc 来实现,如下所示:

import torch
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)

# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)

# Compute some loss.
loss = t5.sum()

分布式 autograd 的主要动机是能够在此类分布式模型上,使用我们计算出的 loss 执行反向传播,并为所有需要梯度的张量记录相应的梯度。

前向传播期间的 Autograd 记录

PyTorch 在前向传播期间构建 autograd 图,该图用于执行反向传播。有关更多详细信息,请参阅 autograd 如何编码历史记录

对于分布式 autograd,我们需要在前向传播期间跟踪所有 RPC,以确保正确执行反向传播。为此,当执行 RPC 时,我们将 sendrecv 函数附加到 autograd 图中。

  • The send 函数附加到 RPC 的源节点,其输出边指向 RPC 输入张量的 autograd 函数。反向传播期间此函数的输入从目标节点接收,作为相应 recv 函数的输出。

  • The recv 函数附加到 RPC 的目标节点,其输入是从目标节点使用输入张量执行的算子中检索的。此函数的输出梯度在反向传播期间发送到源节点,发送给相应的 send 函数。

  • 每个 send-recv 对都分配一个全局唯一的 autograd_message_id 来唯一标识该对。这对于在反向传播期间查找远程节点上的相应函数非常有用。

  • 对于 RRef,每当我们调用 torch.distributed.rpc.RRef.to_here() 时,都会为相关的张量附加一个相应的 send-recv 对。

例如,上面示例的 autograd 图看起来像这样(为了简单起见,t5.sum() 已省略)

../_images/send_recv_functions.png

分布式 Autograd 上下文

每个使用分布式 autograd 的前向传播和反向传播都会分配一个唯一的 torch.distributed.autograd.context,并且此上下文具有一个全局唯一的 autograd_context_id。此上下文在每个节点上按需创建。

此上下文具有以下用途

  1. 运行分布式反向传播的多个节点可能会在同一个张量上累积梯度,因此在有机会运行优化器之前,张量的 .grad 字段将包含来自各种分布式反向传播的梯度。这类似于在本地多次调用 torch.autograd.backward()。为了提供一种分离每个反向传播梯度的方法,梯度会累积到每个反向传播的 torch.distributed.autograd.context 中。

  2. 在前向传播期间,我们将每个 autograd 传递的 sendrecv 函数存储在此上下文中。这确保我们持有 autograd 图中相应节点的引用,以保持其存活。除此之外,在反向传播期间,很容易查找相应的 sendrecv 函数。

  3. 通常,我们还使用此上下文存储每个分布式 autograd 传递的一些元数据。


从用户的角度来看,autograd 上下文设置如下

import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
  loss = model.forward()
  dist_autograd.backward(context_id, loss)

重要的是要注意,您的模型的前向传播必须在分布式 autograd 上下文管理器中调用,因为需要一个有效的上下文来确保所有 sendrecv 函数都正确存储,以便在所有参与节点上运行反向传播。

分布式反向传播

在本节中,我们概述了在分布式反向传播期间准确计算依赖关系的挑战,并描述了几种算法(包含权衡)来执行分布式反向传播。

计算依赖关系

考虑在单机上运行的以下代码片段

import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()

上述代码的 autograd 图如下所示

../_images/local_dependencies.png

autograd 引擎在反向传播中执行的第一步是计算 autograd 图中每个节点的依赖关系数量。这有助于 autograd 引擎了解图中的节点何时可以执行。add(1)mul(0) 中的括号中的数字表示依赖关系数量。正如您所见,这意味着在反向传播期间,add 节点需要 1 个输入,而 mul 节点不需要任何输入(换句话说,不需要执行)。本地 autograd 引擎通过从根节点(在本例中为 d)遍历图来计算这些依赖关系。

autograd 图中的某些节点在反向传播中可能不执行,这给分布式 autograd 带来了挑战。考虑这个使用 RPC 的代码片段。

import torch
import torch.distributed.rpc as rpc

a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)

d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()

上述代码关联的 autograd 图如下所示

../_images/distributed_dependencies.png

计算此分布式 autograd 图的依赖关系更具挑战性,并且需要一些开销(无论是计算还是网络通信)。

对于对性能敏感的应用,我们可以通过假设每个 sendrecv 函数在反向传播中都有效(大多数应用不会执行未使用的 RPC)来避免大量开销。这简化了分布式 autograd 算法,并且效率更高,但代价是应用需要了解其局限性。该算法称为 FAST 模式算法,下面将详细描述。

在一般情况下,并非每个 sendrecv 函数在反向传播中都是必需的。为了解决这个问题,我们提出了 SMART 模式算法,在后面的章节中进行描述。请注意,目前仅实现了 FAST 模式算法。

FAST 模式算法

该算法的关键假设是,在运行反向传播时,每个 send 函数的依赖关系数量为 1。换句话说,我们假设将从另一个节点通过 RPC 接收到梯度。

算法如下

  1. 我们从拥有反向传播根节点的 worker 开始(所有根节点必须是本地的)。

  2. 查找当前 分布式 Autograd 上下文的所有 send 函数。

  3. 从提供的根节点和所有检索到的 send 函数开始,在本地计算依赖关系。

  4. 计算依赖关系后,使用提供的根节点启动本地 autograd 引擎。

  5. 当 autograd 引擎执行 recv 函数时,recv 函数通过 RPC 将输入梯度发送到相应的 worker。每个 recv 函数都知道目标 worker ID,因为它在前向传播中被记录下来。recv 函数还会将 autograd_context_idautograd_message_id 发送到远程主机。

  6. 当远程主机收到此请求时,我们使用 autograd_context_idautograd_message_id 查找相应的 send 函数。

  7. 如果这是 worker 第一次收到给定 autograd_context_id 的请求,它将按照上面第 1-3 点所述在本地计算依赖关系。

  8. 在第 6 点中检索到的 send 函数随后在该 worker 的本地 autograd 引擎上排队等待执行。

  9. 最后,我们将梯度分别累积到每个 分布式 Autograd 上下文中,而不是累积到 Tensor 的 .grad 字段上。梯度存储在一个 Dict[Tensor, Tensor] 中,这基本上是将 Tensor 映射到其关联梯度的映射,并且可以使用 get_gradients() API 检索此映射。


例如,使用分布式 autograd 的完整代码如下所示

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:

# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
  t1 = torch.rand((3, 3), requires_grad=True)
  t2 = torch.rand((3, 3), requires_grad=True)

  # Perform some computation remotely.
  t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

  # Perform some computation locally based on remote result.
  t4 = torch.rand((3, 3), requires_grad=True)
  t5 = torch.mul(t3, t4)

  # Compute some loss.
  loss = t5.sum()

  # Run the backward pass.
  dist_autograd.backward(context_id, [loss])

  # Retrieve the gradients from the context.
  dist_autograd.get_gradients(context_id)

带有依赖关系的分布式 autograd 图如下所示(为了简单起见,t5.sum() 已省略)

../_images/distributed_dependencies_computed.png

FAST 模式算法应用于上述示例,步骤如下

  1. Worker 0 上,我们从根节点 losssend1 开始计算依赖关系。因此,send1 的依赖关系标记为 1,Worker 0 上的 mul 的依赖关系标记为 1。

  2. 现在,我们在 Worker 0 上启动本地 autograd 引擎。我们首先执行 mul 函数,将其输出在 autograd 上下文中累积为 t4 的梯度。然后,我们执行 recv2,它将梯度发送到 Worker 1

  3. 由于这是 Worker 1 第一次听说此反向传播,它开始计算依赖关系并适当地标记 send2addrecv1 的依赖关系。

  4. 接下来,我们在 Worker 1 的本地 autograd 引擎上将 send2 入队,send2 随后执行 addrecv1

  5. 当执行 recv1 时,它会将梯度发送到 Worker 0

  6. 由于 Worker 0 已经计算了此反向传播的依赖关系,它只是在本地将 send1 入队并执行。

  7. 最后,t1t2t4 的梯度累积到 分布式 Autograd 上下文中。

SMART 模式算法

该算法的详细信息仍在完善中,但基本思想可以参考 RFC 中的“分布式 Autograd 算法智能模式”部分。

分布式优化器

The DistributedOptimizer 的工作原理如下

  1. 接受一个要优化的远程参数列表(RRef)。这些也可以是包含在本地 RRef 中的本地参数。

  2. 接受一个 Optimizer 类作为本地优化器,在所有不同的 RRef 所有者上运行。

  3. 分布式优化器在每个 worker 节点上创建本地 Optimizer 的实例,并持有它们的 RRef

  4. 调用 torch.distributed.optim.DistributedOptimizer.step() 时,分布式优化器使用 RPC 远程执行所有本地优化器在相应的远程 worker 上。必须将分布式 autograd context_id 作为输入提供给 torch.distributed.optim.DistributedOptimizer.step()。本地优化器使用它来应用存储在相应上下文中的梯度。

  5. 如果多个并发分布式优化器正在更新 worker 上的相同参数,这些更新会通过锁进行序列化。

简单端到端示例

将所有部分组合起来,以下是一个使用分布式 autograd 和分布式优化器的简单端到端示例。如果将代码保存在名为“dist_autograd_simple.py”的文件中,可以使用命令 MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py 运行。

import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

def random_tensor():
    return torch.rand((3, 3), requires_grad=True)

def _run_process(rank, dst_rank, world_size):
    name = "worker{}".format(rank)
    dst_name = "worker{}".format(dst_rank)

    # Initialize RPC.
    rpc.init_rpc(
        name=name,
        rank=rank,
        world_size=world_size
    )

    # Use a distributed autograd context.
    with dist_autograd.context() as context_id:
        # Forward pass (create references on remote nodes).
        rref1 = rpc.remote(dst_name, random_tensor)
        rref2 = rpc.remote(dst_name, random_tensor)
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass (run distributed autograd).
        dist_autograd.backward(context_id, [loss.sum()])

        # Build DistributedOptimizer.
        dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
        )

        # Run the distributed optimizer step.
        dist_optim.step(context_id)

def run_process(rank, world_size):
    dst_rank = (rank + 1) % world_size
    _run_process(rank, dst_rank, world_size)
    rpc.shutdown()

if __name__ == '__main__':
  # Run world_size workers
  world_size = 2
  mp.spawn(run_process, args=(world_size,), nprocs=world_size)

文档

访问 PyTorch 全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源