快捷方式

分布式自动微分设计

本笔记将介绍分布式自动微分的详细设计,并逐步介绍其内部机制。在继续之前,请确保您熟悉 自动微分机制分布式 RPC 框架

背景

假设您有两个节点,并且一个非常简单的模型分布在两个节点上。这可以使用 torch.distributed.rpc 如下实现

import torch
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)

# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)

# Compute some loss.
loss = t5.sum()

分布式自动微分的核心动机是,能够在计算得到的loss上对这种分布式模型执行反向传播,并为所有需要梯度的张量记录相应的梯度。

前向传播过程中的自动微分记录

PyTorch 在前向传播过程中构建自动微分图,该图用于执行反向传播。有关更多详细信息,请参阅自动微分如何编码历史

对于分布式自动微分,我们需要跟踪前向传播过程中的所有 RPC,以确保反向传播能够正确执行。为此,我们在执行 RPC 时将sendrecv函数附加到自动微分图。

  • send函数附加到 RPC 的源,其输出边指向 RPC 输入张量的自动微分函数。反向传播过程中该函数的输入从目标接收,作为相应recv函数的输出。

  • recv函数附加到 RPC 的目标,其输入从使用输入张量在目标上执行的操作符中获取。该函数的输出梯度在反向传播过程中发送到源节点的相应send函数。

  • 每个send-recv对都分配了一个全局唯一的autograd_message_id,以唯一标识该对。这对于在反向传播过程中查找远程节点上的相应函数很有用。

  • 对于RRef,每当我们调用torch.distributed.rpc.RRef.to_here()时,我们都会为所涉及的张量附加一个适当的send-recv对。

例如,这是我们上面示例的自动微分图的样子(为简化起见,排除了 t5.sum())

../_images/send_recv_functions.png

分布式自动微分上下文

每个使用分布式自动微分的正向和反向传递都被分配了一个唯一的 torch.distributed.autograd.context,并且此上下文具有全局唯一的 autograd_context_id。此上下文根据需要在每个节点上创建。

此上下文用于以下目的

  1. 多个运行分布式反向传递的节点可能会在同一个张量上累积梯度,因此在有机会运行优化器之前,张量的 .grad 字段将包含来自各种分布式反向传递的梯度。这类似于在本地多次调用 torch.autograd.backward()。为了提供一种分离每个反向传递梯度的方法,梯度在每个反向传递的 torch.distributed.autograd.context 中累积。

  2. 在正向传递期间,我们在此上下文中存储每个自动微分传递的 sendrecv 函数。这确保我们持有自动微分图中适当节点的引用以使其保持活动状态。除此之外,在反向传递期间很容易查找适当的 sendrecv 函数。

  3. 通常,我们也使用此上下文来存储每个分布式自动微分传递的一些元数据。


从用户的角度来看,自动微分上下文设置如下

import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
  loss = model.forward()
  dist_autograd.backward(context_id, loss)

需要注意的是,您的模型的前向传播必须在分布式自动微分上下文管理器中调用,因为需要一个有效的上下文来确保所有 sendrecv 函数被正确存储,以便在所有参与节点上运行反向传播。

分布式反向传播

在本节中,我们将概述在分布式反向传播期间准确计算依赖关系的挑战,并描述几种算法(及其权衡),说明如何执行分布式反向传播。

计算依赖关系

考虑以下在单台机器上运行的代码片段

import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()

以下是上述代码的自动微分图

../_images/local_dependencies.png

自动微分引擎作为反向传播的一部分执行的第一步是计算自动微分图中每个节点的依赖关系数量。这有助于自动微分引擎了解图中的某个节点何时准备好执行。add(1)mul(0) 中的括号中的数字表示依赖关系的数量。如您所见,这意味着在反向传播期间,add 节点需要 1 个输入,而 mul 节点不需要任何输入(换句话说,不需要执行)。本地自动微分引擎通过从根节点(本例中为 d)遍历图来计算这些依赖关系。

自动微分图中的某些节点可能不会在反向传播中执行这一事实给分布式自动微分带来了挑战。考虑以下使用 RPC 的代码片段。

import torch
import torch.distributed.rpc as rpc

a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)

d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()

上述代码的关联自动微分图将是

../_images/distributed_dependencies.png

计算此分布式自动微分图的依赖关系要困难得多,并且需要一些开销(无论是计算开销还是网络通信开销)。

对于对性能敏感的应用程序,我们可以通过假设每个 sendrecv 函数在反向传播中都是有效的来避免大量开销(大多数应用程序不会执行未使用的 RPC)。这简化了分布式自动微分算法,并且效率更高,但代价是应用程序需要了解这些限制。该算法称为 FAST 模式算法,将在下面详细介绍。

在一般情况下,并非每个 sendrecv 函数都需要在反向传播中有效。为了解决这个问题,我们提出了一种 SMART 模式算法,将在后面的章节中进行描述。请注意,目前只实现了 FAST 模式算法。

FAST 模式算法

该算法的关键假设是,每个 send 函数在反向传播时都具有 1 的依赖关系。换句话说,我们假设我们将通过 RPC 从另一个节点接收梯度。

该算法如下所示

  1. 我们从拥有反向传播根节点的 worker 开始(所有根节点都必须是本地的)。

  2. 查找当前 分布式自动微分上下文 中的所有 send 函数。

  3. 从提供的根节点和我们检索到的所有 send 函数开始,在本地计算依赖关系。

  4. 计算完依赖关系后,使用提供的根节点启动本地自动微分引擎。

  5. 当自动微分引擎执行 recv 函数时,recv 函数通过 RPC 将输入梯度发送到相应的 worker。每个 recv 函数都知道目标 worker ID,因为它是在正向传播过程中记录的。recv 函数还会将 autograd_context_idautograd_message_id 发送到远程主机。

  6. 当远程主机收到此请求时,我们使用 autograd_context_idautograd_message_id 来查找相应的 send 函数。

  7. 如果这是某个 worker 第一次收到针对给定 autograd_context_id 的请求,它将根据上面 1-3 点所述在本地计算依赖关系。

  8. 在第 6 步中检索到的 send 函数随后被排队以在该 worker 的本地 autograd 引擎上执行。

  9. 最后,我们不是在 Tensor 的 .grad 字段上累积梯度,而是根据每个 分布式 Autograd 上下文 分别累积梯度。梯度存储在一个 Dict[Tensor, Tensor] 中,它本质上是一个从 Tensor 到其关联梯度的映射,可以使用 get_gradients() API 检索此映射。


例如,使用分布式 autograd 的完整代码如下所示

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

def my_add(t1, t2):
  return torch.add(t1, t2)

# On worker 0:

# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
  t1 = torch.rand((3, 3), requires_grad=True)
  t2 = torch.rand((3, 3), requires_grad=True)

  # Perform some computation remotely.
  t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

  # Perform some computation locally based on remote result.
  t4 = torch.rand((3, 3), requires_grad=True)
  t5 = torch.mul(t3, t4)

  # Compute some loss.
  loss = t5.sum()

  # Run the backward pass.
  dist_autograd.backward(context_id, [loss])

  # Retrieve the gradients from the context.
  dist_autograd.get_gradients(context_id)

带有依赖关系的分布式 autograd 图如下所示(为简便起见,省略了 t5.sum())

../_images/distributed_dependencies_computed.png

FAST 模式算法 应用于上述示例,结果如下所示

  1. Worker 0 上,我们从根节点 losssend1 开始计算依赖关系。因此,send1 被标记为依赖关系 1,而 Worker 0 上的 mul 被标记为依赖关系 1。

  2. 现在,我们在 Worker 0 上启动本地 autograd 引擎。我们首先执行 mul 函数,将其输出累积到 autograd 上下文中,作为 t4 的梯度。然后,我们执行 recv2,它将梯度发送到 Worker 1

  3. 由于这是 Worker 1 第一次收到关于此反向传播的信息,它将开始依赖关系计算,并相应地标记 send2addrecv1 的依赖关系。

  4. 接下来,我们在 Worker 1 的本地 autograd 引擎上排队 send2,它依次执行 addrecv1

  5. 当执行 recv1 时,它将梯度发送到 Worker 0

  6. 由于 Worker 0 已经计算了此反向传播的依赖关系,它只是在本地排队并执行 send1

  7. 最后,t1t2t4 的梯度在 分布式自动微分上下文 中累积。

SMART 模式算法

该算法的完整细节仍在开发中,但对于一般概念,您可以参考 RFC 中的 分布式自动微分算法 SMART 模式 部分。

分布式优化器

DistributedOptimizer 的工作原理如下

  1. 接受要优化的远程参数列表 (RRef)。这些参数也可以是包装在本地 RRef 中的本地参数。

  2. 接受 Optimizer 类作为在所有不同的 RRef 所有者上运行的本地优化器。

  3. 分布式优化器在每个工作节点上创建一个本地 Optimizer 实例,并持有指向它们的 RRef

  4. 当调用 torch.distributed.optim.DistributedOptimizer.step() 时,分布式优化器使用 RPC 远程执行所有工作节点上的本地优化器。必须将分布式自动微分 context_id 作为输入提供给 torch.distributed.optim.DistributedOptimizer.step()。本地优化器使用它来应用存储在相应上下文中的梯度。

  5. 如果多个并发分布式优化器在工作节点上更新相同的参数,这些更新将通过锁进行序列化。

简单的端到端示例

将所有内容整合在一起,以下是一个使用分布式自动微分和分布式优化器的简单端到端示例。如果将代码放入名为“dist_autograd_simple.py”的文件中,则可以使用命令 MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py 运行它。

import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

def random_tensor():
    return torch.rand((3, 3), requires_grad=True)

def _run_process(rank, dst_rank, world_size):
    name = "worker{}".format(rank)
    dst_name = "worker{}".format(dst_rank)

    # Initialize RPC.
    rpc.init_rpc(
        name=name,
        rank=rank,
        world_size=world_size
    )

    # Use a distributed autograd context.
    with dist_autograd.context() as context_id:
        # Forward pass (create references on remote nodes).
        rref1 = rpc.remote(dst_name, random_tensor)
        rref2 = rpc.remote(dst_name, random_tensor)
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass (run distributed autograd).
        dist_autograd.backward(context_id, [loss.sum()])

        # Build DistributedOptimizer.
        dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
        )

        # Run the distributed optimizer step.
        dist_optim.step(context_id)

def run_process(rank, world_size):
    dst_rank = (rank + 1) % world_size
    _run_process(rank, dst_rank, world_size)
    rpc.shutdown()

if __name__ == '__main__':
  # Run world_size workers
  world_size = 2
  mp.spawn(run_process, args=(world_size,), nprocs=world_size)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源