快捷方式

多 GPU 示例

数据并行是指我们将样本的小批量数据分成多个更小的批量数据,并并行运行每个小批量数据的计算。

数据并行是使用 torch.nn.DataParallel 实现的。可以使用 DataParallel 包装一个模块,它将在批量维度上并行化多个 GPU。

DataParallel

import torch
import torch.nn as nn


class DataParallelModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)

        # wrap block2 in DataParallel
        self.block2 = nn.Linear(20, 20)
        self.block2 = nn.DataParallel(self.block2)

        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

CPU 模式下的代码不需要更改。

DataParallel 的文档可以在 这里 找到。

包装模块的属性

使用 DataParallel 包装一个模块后,模块的属性(例如自定义方法)将无法访问。这是因为 DataParallel 定义了一些新的成员,而允许其他属性可能会导致名称冲突。对于仍然需要访问属性的用户,可以使用以下子类的解决方案。

class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

DataParallel 基于的原语

通常,pytorch 的 nn.parallel 原语可以独立使用。我们已经实现了简单的 MPI 类原语

  • 复制:在多个设备上复制一个模块

  • 散列:在第一维度上分配输入

  • 聚合:在第一维度上聚合和连接输入

  • 并行应用:将一组已分配的输入应用于一组已分配的模型。

为了更好地说明,这里使用这些集合函数编写的 data_parallel 函数

def data_parallel(module, input, device_ids, output_device=None):
    if not device_ids:
        return module(input)

    if output_device is None:
        output_device = device_ids[0]

    replicas = nn.parallel.replicate(module, device_ids)
    inputs = nn.parallel.scatter(input, device_ids)
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    return nn.parallel.gather(outputs, output_device)

模型的一部分在 CPU 上,一部分在 GPU 上

让我们来看一个将网络的一部分放在 CPU 上,一部分放在 GPU 上的简单示例。

device = torch.device("cuda:0")

class DistributedModel(nn.Module):

    def __init__(self):
        super().__init__(
            embedding=nn.Embedding(1000, 10),
            rnn=nn.Linear(10, 10).to(device),
        )

    def forward(self, x):
        # Compute embedding on CPU
        x = self.embedding(x)

        # Transfer to GPU
        x = x.to(device)

        # Compute RNN on GPU
        x = self.rnn(x)
        return x

这是针对之前使用 Torch 的用户提供的一个关于 PyTorch 的简短介绍。还有很多需要学习的内容。

查看我们更全面的入门教程,介绍了 optim 包、数据加载器等:使用 PyTorch 进行深度学习:60 分钟速成.

还可以查看

脚本总运行时间: ( 0 分钟 0.002 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源