• 文档 >
  • 多进程最佳实践
快捷方式

多进程最佳实践

torch.multiprocessing 是 Python 的 multiprocessing 模块的直接替代品。它支持完全相同的操作,但进行了扩展,因此通过 multiprocessing.Queue 发送的所有张量都将把数据移动到共享内存中,并且只会向另一个进程发送句柄。

注意

当一个 Tensor 被发送到另一个进程时,Tensor 数据是共享的。如果 torch.Tensor.grad 不为 None,它也将被共享。在没有 torch.Tensor.grad 字段的 Tensor 被发送到另一个进程之后,它会创建一个标准的进程特定 .grad Tensor,它不像 Tensor 的数据那样,不会在所有进程之间自动共享。

这允许实现各种训练方法,例如 Hogwild、A3C 或任何其他需要异步操作的方法。

多进程中的 CUDA

CUDA 运行时不支持 fork 启动方法;在子进程中使用 CUDA 需要 spawnforkserver 启动方法。

注意

可以通过以下两种方法之一设置启动方法:使用 multiprocessing.get_context(...) 创建上下文,或者直接使用 multiprocessing.set_start_method(...)

与 CPU 张量不同,发送进程需要保留原始张量,只要接收进程保留该张量的副本。它是幕后实现的,但需要用户遵循最佳实践以确保程序能够正确运行。例如,发送进程必须在消费者进程具有对张量的引用时保持活动状态,并且如果消费者进程通过致命信号异常退出,则引用计数无法为您提供帮助。请参阅 此部分

另请参阅:使用 nn.parallel.DistributedDataParallel 而不是 multiprocessing 或 nn.DataParallel

最佳实践和技巧

避免和解决死锁

在生成新进程时,有很多事情可能会出错,死锁的最常见原因是后台线程。如果任何线程持有锁或导入模块,并且调用了 fork,则子进程很可能处于损坏状态,并且会死锁或以其他方式失败。请注意,即使您没有这样做,Python 内置库也会这样做 - 不需要进一步研究 multiprocessingmultiprocessing.Queue 实际上是一个非常复杂的类,它会生成多个线程来序列化、发送和接收对象,它们也会导致上述问题。如果您发现自己处于这种情况,请尝试使用 SimpleQueue,它不使用任何额外的线程。

我们正在尽力为您提供便利,确保这些死锁不会发生,但有些事情超出了我们的控制范围。如果您遇到一些无法解决的错误,请尝试在论坛上寻求帮助,我们将看看是否可以解决它。

重复使用通过队列传递的缓冲区

请记住,每次将 Tensor 放入 multiprocessing.Queue 中时,它都必须移动到共享内存中。如果它已经共享,则这是一个无操作,否则它将导致额外的内存复制,这会减慢整个过程。即使您有一组进程将数据发送到单个进程,也要让它将缓冲区发送回来 - 这几乎是免费的,并且将避免在发送下一批时进行复制。

异步多进程训练(例如 Hogwild)

使用 torch.multiprocessing,可以异步训练模型,参数要么始终共享,要么定期同步。在第一种情况下,我们建议发送整个模型对象,而在第二种情况下,我们建议只发送 state_dict()

我们建议使用 multiprocessing.Queue 在进程之间传递各种 PyTorch 对象。例如,在使用 fork 启动方法时,可以继承已经存在于共享内存中的张量和存储,但是这种方法很容易出现错误,应该谨慎使用,并且只应该由高级用户使用。队列,即使它们有时不是一个优雅的解决方案,但在所有情况下都能正常工作。

警告

您应该注意全局语句,这些语句没有用 if __name__ == '__main__' 保护。如果使用不同于 fork 的启动方法,它们将在所有子进程中执行。

Hogwild

可以在 示例仓库 中找到一个具体的 Hogwild 实现,但为了展示代码的整体结构,下面也给出一个简化的示例

import torch.multiprocessing as mp
from model import MyModel

def train(model):
    # Construct data_loader, optimizer, etc.
    for data, labels in data_loader:
        optimizer.zero_grad()
        loss_fn(model(data), labels).backward()
        optimizer.step()  # This will update the shared parameters

if __name__ == '__main__':
    num_processes = 4
    model = MyModel()
    # NOTE: this is required for the ``fork`` method to work
    model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

多进程中的 CPU

不适当的多进程会导致 CPU 超额订阅,导致不同的进程争夺 CPU 资源,导致效率低下。

本教程将解释什么是 CPU 超额订阅以及如何避免它。

CPU 超额订阅

CPU 超额订阅是一个技术术语,指的是分配给系统的 vCPU 总数超过硬件上可用的 vCPU 总数的情况。

这会导致对 CPU 资源的严重争夺。在这种情况下,进程之间会频繁切换,这会增加进程切换开销并降低系统整体效率。

示例仓库 中的 Hogwild 实现中,查看代码示例了解 CPU 超额订阅。

在 CPU 上使用 4 个进程运行以下命令时,训练示例

python main.py --num-processes 4

假设机器上有 N 个 vCPU,执行上面的命令将生成 4 个子进程。每个子进程将为自身分配 N 个 vCPU,从而需要 4*N 个 vCPU。但是,机器上只有 N 个 vCPU 可用。因此,不同的进程将争夺资源,导致频繁的进程切换。

以下观察结果表明存在 CPU 超额订阅

  1. 高 CPU 利用率:通过使用 htop 命令,您可以观察到 CPU 利用率始终很高,通常达到或超过其最大容量。这表明对 CPU 资源的需求超过了可用的物理内核,导致进程之间争夺 CPU 时间。

  2. 频繁的上下文切换,系统效率低下:在 CPU 超额订阅的情况下,进程争夺 CPU 时间,操作系统需要快速地在不同进程之间切换以公平地分配资源。这种频繁的上下文切换会增加开销并降低系统整体效率。

避免 CPU 超额订阅

避免 CPU 超额订阅的一个好方法是适当的资源分配。确保同时运行的进程或线程数量不超过可用的 CPU 资源。

在这种情况下,一个解决方案是在子进程中指定适当的线程数量。这可以通过使用子进程中的 torch.set_num_threads(int) 函数为每个进程设置线程数量来实现。

假设机器上有 N 个 vCPU,将生成 M 个进程,每个进程使用的最大 num_threads 值将是 floor(N/M)。为了避免在 mnist_hogwild 示例中出现 CPU 超额订阅,需要对 示例仓库 中的 train.py 文件进行以下更改。

def train(rank, args, model, device, dataset, dataloader_kwargs):
    torch.manual_seed(args.seed + rank)

    #### define the num threads used in current sub-processes
    torch.set_num_threads(floor(N/M))

    train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    for epoch in range(1, args.epochs + 1):
        train_epoch(epoch, args, model, device, train_loader, optimizer)

使用 torch.set_num_threads(floor(N/M)) 为每个进程设置 num_thread。您需要用可用的 vCPU 数量替换 N,用选择的进程数量替换 M。适当的 num_thread 值将根据具体任务而有所不同。但是,作为一个一般性指导,num_thread 的最大值应该是 floor(N/M),以避免 CPU 超额订阅。在 mnist_hogwild 训练示例中,在避免 CPU 超额订阅后,您可以实现 30 倍的性能提升。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源