• 教程 >
  • 使用异步执行实现批处理 RPC 处理
快捷方式

使用异步执行实现批处理 RPC 处理

作者: 沈立

注意

editgithub 中查看和编辑本教程。

先决条件

本教程演示了如何使用 @rpc.functions.async_execution 装饰器构建批处理 RPC 应用程序,该装饰器有助于通过减少阻塞 RPC 线程数量和合并被调用者的 CUDA 操作来加快训练速度。这与 TorchServe 的批处理推理 的理念相同。

注意

本教程需要 PyTorch v1.6.0 或更高版本。

基础

之前的教程展示了使用 torch.distributed.rpc 构建分布式训练应用程序的步骤,但它们没有详细说明在处理 RPC 请求时被调用方会发生什么。从 PyTorch v1.5 开始,每个 RPC 请求都会阻塞被调用方的一个线程来执行该请求中的函数,直到该函数返回。这适用于许多用例,但有一个警告。如果用户函数阻塞在 IO 上,例如嵌套 RPC 调用或信号,例如等待另一个 RPC 请求解除阻塞,被调用方上的 RPC 线程将不得不空闲等待直到 IO 完成或信号事件发生。结果,RPC 被调用方可能会使用比必要更多的线程。这个问题的原因是,RPC 将用户函数视为黑盒,并且对函数中发生的事情知之甚少。为了允许用户函数屈服并释放 RPC 线程,需要向 RPC 系统提供更多提示。

从 v1.6.0 开始,PyTorch 通过引入两个新概念来解决这个问题

  • 一个 torch.futures.Future 类型,它封装了异步执行,也支持安装回调函数。

  • 一个 @rpc.functions.async_execution 装饰器,允许应用程序告诉被调用方目标函数将返回一个 future 并且可以在执行期间多次暂停和屈服。

使用这两个工具,应用程序代码可以将用户函数分解为多个较小的函数,将它们链接在一起作为 Future 对象上的回调,并返回包含最终结果的 Future。在被调用方,当获取 Future 对象时,它还会安装后续的 RPC 响应准备和通信作为回调,这些回调将在最终结果准备就绪时触发。这样,被调用方就不需要再阻塞一个线程,并等待直到最终返回值准备就绪。有关简单示例,请参阅 @rpc.functions.async_execution 的 API 文档。

除了减少被调用方上的空闲线程数量之外,这些工具还有助于使批处理 RPC 处理更容易和更快。本教程的以下两节将演示如何使用 @rpc.functions.async_execution 装饰器构建分布式批处理更新参数服务器和批处理强化学习应用程序。

批处理更新参数服务器

考虑一个具有一个参数服务器 (PS) 和多个训练器的同步参数服务器训练应用程序。在这个应用程序中,PS 保存参数并等待所有训练器报告梯度。在每次迭代中,它会等到从所有训练器收到梯度,然后一次性更新所有参数。以下代码显示了 PS 类实现。update_and_fetch_model 方法使用 @rpc.functions.async_execution 进行装饰,并将被训练器调用。每次调用都会返回一个 Future 对象,该对象将用更新的模型填充。大多数训练器启动的调用只是将梯度累积到 .grad 字段,立即返回并屈服于 PS 上的 RPC 线程。最后一个到达的训练器将触发优化器步骤并消耗所有先前报告的梯度。然后它使用更新的模型设置 future_model,这反过来又通过 Future 对象通知其他训练器之前的所有请求,并将更新的模型发送给所有训练器。

import threading
import torchvision
import torch
import torch.distributed.rpc as rpc
from torch import optim

num_classes, batch_update_size = 30, 5

class BatchUpdateParameterServer(object):
    def __init__(self, batch_update_size=batch_update_size):
        self.model = torchvision.models.resnet50(num_classes=num_classes)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution
    def update_and_fetch_model(ps_rref, grads):
        # Using the RRef to retrieve the local PS instance
        self = ps_rref.local_value()
        with self.lock:
            self.curr_update_size += 1
            # accumulate gradients into .grad field
            for p, g in zip(self.model.parameters(), grads):
                p.grad += g

            # Save the current future_model and return it to make sure the
            # returned Future object holds the correct model even if another
            # thread modifies future_model before this thread returns.
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                # update the model
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step()
                self.optimizer.zero_grad()
                # by settiing the result on the Future object, all previous
                # requests expecting this updated model will be notified and
                # the their responses will be sent accordingly.
                fut.set_result(self.model)
                self.future_model = torch.futures.Future()

        return fut

对于训练器,它们都使用 PS 中的相同参数集进行初始化。在每次迭代中,每个训练器首先运行前向和反向传递以在本地生成梯度。然后,每个训练器使用 RPC 将其梯度报告给 PS,并通过相同 RPC 请求的返回值获取更新的参数。在训练器的实现中,目标函数是否使用 @rpc.functions.async_execution 进行标记并没有什么区别。训练器只需使用 rpc_sync 调用 update_and_fetch_model,这将阻塞训练器,直到返回更新的模型。

batch_size, image_w, image_h  = 20, 64, 64

class Trainer(object):
    def __init__(self, ps_rref):
        self.ps_rref, self.loss_fn = ps_rref, torch.nn.MSELoss()
        self.one_hot_indices = torch.LongTensor(batch_size) \
                                    .random_(0, num_classes) \
                                    .view(batch_size, 1)

    def get_next_batch(self):
        for _ in range(6):
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                        .scatter_(1, self.one_hot_indices, 1)
            yield inputs.cuda(), labels.cuda()

    def train(self):
        name = rpc.get_worker_info().name
        # get initial model parameters
        m = self.ps_rref.rpc_sync().get_model().cuda()
        # start training
        for inputs, labels in self.get_next_batch():
            self.loss_fn(m(inputs), labels).backward()
            m = rpc.rpc_sync(
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            ).cuda()

在本教程中,我们跳过了启动多个进程的代码,请参阅 examples 仓库以获取完整的实现。请注意,可以在没有 @rpc.functions.async_execution 装饰器的情况下实现批处理。但是,这将需要在 PS 上阻塞更多 RPC 线程,或者使用另一轮 RPC 来获取更新的模型,而后者将增加更多代码复杂性和更多通信开销。

本节使用简单的参数服务器训练示例来展示如何使用 @rpc.functions.async_execution 装饰器实现批处理 RPC 应用程序。在下一节中,我们将使用批处理重新实现先前 Getting started with Distributed RPC Framework 教程中的强化学习示例,并展示它对训练速度的影响。

批处理 CartPole 求解器

本节使用来自 OpenAI Gym 的 CartPole-v1 作为示例来展示批处理 RPC 的性能影响。请注意,由于目标是演示 @rpc.functions.async_execution 的使用,而不是构建最好的 CartPole 求解器或解决大多数不同的 RL 问题,因此我们使用非常简单的策略和奖励计算策略,并专注于多观察者单代理批处理 RPC 实现。我们使用类似于先前教程中的 Policy 模型,如下所示。与先前教程相比,区别在于它的构造函数接受一个额外的 batch 参数,它控制 F.softmaxdim 参数,因为批处理后,forward 函数中的 x 参数包含来自多个观察者的状态,因此维度需要正确更改。其他一切保持不变。

import argparse
import torch.nn as nn
import torch.nn.functional as F

parser = argparse.ArgumentParser(description='PyTorch RPC Batch RL example')
parser.add_argument('--gamma', type=float, default=1.0, metavar='G',
                    help='discount factor (default: 1.0)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--num-episode', type=int, default=10, metavar='E',
                    help='number of episodes (default: 10)')
args = parser.parse_args()

torch.manual_seed(args.seed)

class Policy(nn.Module):
    def __init__(self, batch=True):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)
        self.dim = 2 if batch else 1

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=self.dim)

Observer 的构造函数也相应调整。它还接受一个 batch 参数,它控制它使用哪个 Agent 函数来选择动作。在批处理模式下,它调用 Agent 上的 select_action_batch 函数,该函数将在稍后介绍,并且该函数将使用 @rpc.functions.async_execution 进行装饰。

import gym
import torch.distributed.rpc as rpc

class Observer:
    def __init__(self, batch=True):
        self.id = rpc.get_worker_info().id - 1
        self.env = gym.make('CartPole-v1')
        self.env.seed(args.seed)
        self.select_action = Agent.select_action_batch if batch else Agent.select_action

与先前的教程 Getting started with Distributed RPC Framework 相比,观察者行为略有不同。它不会在环境停止时退出,而是始终在每一集中运行 n_steps 次迭代。当环境返回时,观察者只需重置环境并重新开始。通过这种设计,代理将从每个观察者接收固定数量的状态,因此可以将它们打包到一个固定大小的张量中。在每一步中,Observer 使用 RPC 将其状态发送给 Agent 并通过返回值获取动作。在每集结束时,它将所有步骤的奖励返回给 Agent。请注意,此 run_episode 函数将被 Agent 使用 RPC 调用。因此,此函数中的 rpc_sync 调用将是嵌套的 RPC 调用。我们可以将此函数标记为 @rpc.functions.async_execution,以避免阻塞 Observer 上的一个线程。但是,由于瓶颈是 Agent 而不是 Observer,因此在 Observer 进程上阻塞一个线程应该没问题。

import torch

class Observer:
    ...

    def run_episode(self, agent_rref, n_steps):
        state, ep_reward = self.env.reset(), NUM_STEPS
        rewards = torch.zeros(n_steps)
        start_step = 0
        for step in range(n_steps):
            state = torch.from_numpy(state).float().unsqueeze(0)
            # send the state to the agent to get an action
            action = rpc.rpc_sync(
                agent_rref.owner(),
                self.select_action,
                args=(agent_rref, self.id, state)
            )

            # apply the action to the environment, and get the reward
            state, reward, done, _ = self.env.step(action)
            rewards[step] = reward

            if done or step + 1 >= n_steps:
                curr_rewards = rewards[start_step:(step + 1)]
                R = 0
                for i in range(curr_rewards.numel() -1, -1, -1):
                    R = curr_rewards[i] + args.gamma * R
                    curr_rewards[i] = R
                state = self.env.reset()
                if start_step == 0:
                    ep_reward = min(ep_reward, step - start_step + 1)
                start_step = step + 1

        return [rewards, ep_reward]

Agent 的构造函数也接受一个 batch 参数,它控制动作概率的批处理方式。在批处理模式下,saved_log_probs 包含一个张量列表,其中每个张量包含一步中所有观察者的动作概率。在没有批处理的情况下,saved_log_probs 是一个字典,其中键是观察者 ID,值是该观察者的动作概率列表。

import threading
from torch.distributed.rpc import RRef

class Agent:
    def __init__(self, world_size, batch=True):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.policy = Policy(batch).cuda()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
        self.running_reward = 0

        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
            self.ob_rrefs.append(rpc.remote(ob_info, Observer, args=(batch,)))
            self.rewards[ob_info.id] = []

        self.states = torch.zeros(len(self.ob_rrefs), 1, 4)
        self.batch = batch
        self.saved_log_probs = [] if batch else {k:[] for k in range(len(self.ob_rrefs))}
        self.future_actions = torch.futures.Future()
        self.lock = threading.Lock()
        self.pending_states = len(self.ob_rrefs)

非批处理 select_acion 只需将状态运行到策略中,保存动作概率并立即将动作返回给观察者。

from torch.distributions import Categorical

class Agent:
    ...

    @staticmethod
    def select_action(agent_rref, ob_id, state):
        self = agent_rref.local_value()
        probs = self.policy(state.cuda())
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item()

使用批处理,状态将存储在一个二维张量 self.states 中,使用观察者 ID 作为行 ID。然后,它通过安装一个回调函数到批处理生成的 self.future_actions Future 对象来链接一个 Future,该对象将使用该观察者的 ID 填充特定行索引。最后一个到达的观察者将所有批处理状态一次性运行到策略中,并相应地设置 self.future_actions。当发生这种情况时,安装在 self.future_actions 上的所有回调函数将被触发,它们的返回值将用于填充链接的 Future 对象,这反过来又通知 Agent 为所有先前来自其他观察者的 RPC 请求准备和通信响应。

class Agent:
    ...

    @staticmethod
    @rpc.functions.async_execution
    def select_action_batch(agent_rref, ob_id, state):
        self = agent_rref.local_value()
        self.states[ob_id].copy_(state)
        future_action = self.future_actions.then(
            lambda future_actions: future_actions.wait()[ob_id].item()
        )

        with self.lock:
            self.pending_states -= 1
            if self.pending_states == 0:
                self.pending_states = len(self.ob_rrefs)
                probs = self.policy(self.states.cuda())
                m = Categorical(probs)
                actions = m.sample()
                self.saved_log_probs.append(m.log_prob(actions).t()[0])
                future_actions = self.future_actions
                self.future_actions = torch.futures.Future()
                future_actions.set_result(actions.cpu())
        return future_action

现在让我们定义如何将不同的 RPC 函数缝合在一起。Agent 控制每一集的执行。它首先使用 rpc_async 在所有观察者上启动集并阻塞返回的 futures,这些 futures 将用观察者奖励填充。请注意,以下代码使用 RRef 帮助程序 ob_rref.rpc_async()ob_rref RRef 的所有者上启动 run_episode 函数,并提供给定的参数。然后它将保存的动作概率和返回的观察者奖励转换为预期的格式,并启动训练步骤。最后,它重置所有状态并返回当前集的奖励。此函数是运行一集的入口点。

class Agent:
    ...

    def run_episode(self, n_steps=0):
        futs = []
        for ob_rref in self.ob_rrefs:
            # make async RPC to kick off an episode on all observers
            futs.append(ob_rref.rpc_async().run_episode(self.agent_rref, n_steps))

        # wait until all obervers have finished this episode
        rets = torch.futures.wait_all(futs)
        rewards = torch.stack([ret[0] for ret in rets]).cuda().t()
        ep_rewards = sum([ret[1] for ret in rets]) / len(rets)

        # stack saved probs into one tensor
        if self.batch:
            probs = torch.stack(self.saved_log_probs)
        else:
            probs = [torch.stack(self.saved_log_probs[i]) for i in range(len(rets))]
            probs = torch.stack(probs)

        policy_loss = -probs * rewards / len(rets)
        policy_loss.sum().backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        # reset variables
        self.saved_log_probs = [] if self.batch else {k:[] for k in range(len(self.ob_rrefs))}
        self.states = torch.zeros(len(self.ob_rrefs), 1, 4)

        # calculate running rewards
        self.running_reward = 0.5 * ep_rewards + 0.5 * self.running_reward
        return ep_rewards, self.running_reward

其余代码为正常的进程启动和日志记录,与其他 RPC 教程类似。在本教程中,所有观察者都被动地等待来自代理的命令。有关完整实现,请参阅 examples 仓库。

def run_worker(rank, world_size, n_episode, batch, print_log=True):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 0:
        # rank0 is the agent
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)

        agent = Agent(world_size, batch)
        for i_episode in range(n_episode):
            last_reward, running_reward = agent.run_episode(n_steps=NUM_STEPS)

            if print_log:
                print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                    i_episode, last_reward, running_reward))
    else:
        # other ranks are the observer
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
        # observers passively waiting for instructions from agents
    rpc.shutdown()


def main():
    for world_size in range(2, 12):
        delays = []
        for batch in [True, False]:
            tik = time.time()
            mp.spawn(
                run_worker,
                args=(world_size, args.num_episode, batch),
                nprocs=world_size,
                join=True
            )
            tok = time.time()
            delays.append(tok - tik)

        print(f"{world_size}, {delays[0]}, {delays[1]}")


if __name__ == '__main__':
    main()

批量 RPC 有助于将动作推断合并成更少的 CUDA 操作,从而降低了摊销开销。上面的 main 函数使用不同的观察者数量(从 1 到 10)在批量和非批量模式下运行相同的代码。下图显示了使用默认参数值的不同世界规模的执行时间。结果证实了我们的预期,即批量处理有助于加速训练。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的答案

查看资源