快捷方式

DataParallel

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[源代码][源代码]

在模块级别实现数据并行。

此容器通过沿批处理维度(其他对象将在每个设备上复制一次)对输入进行分块,将给定的 module 的应用并行化到指定的设备上。在前向传播中,模块在每个设备上复制,每个副本处理一部分输入。在后向传播中,每个副本的梯度会求和到原始模块中。

批处理大小应大于所使用的 GPU 数量。

警告

建议使用 DistributedDataParallel 类进行多 GPU 训练,即使只有一个节点。参见:使用 nn.parallel.DistributedDataParallel 而非多进程或 nn.DataParallel 以及 分布式数据并行

允许将任意位置和关键字输入传递给 DataParallel,但某些类型会特殊处理。张量会沿指定的维度(默认为 0)进行 分散。tuple、list 和 dict 类型会进行浅复制。其他类型将在不同的线程之间共享,如果在模型的正向传播中写入,则可能损坏。

并行化的 module 在运行此 DataParallel 模块之前,必须将其参数和缓冲区放置在 device_ids[0] 上。

警告

在前向传播时, module 在每个设备上被 复制,因此在 forward 中对正在运行的模块进行的任何更新都将丢失。例如,如果 module 有一个在每次 forward 中递增的计数器属性,它将始终保持初始值,因为更新是在副本上进行的,而副本在 forward 后即被销毁。然而,DataParallel 保证 device[0] 上的副本将其参数和缓冲区与基础并行化 module 共享存储。因此,对 device[0] 上的参数或缓冲区的 就地 更新将被记录下来。例如, BatchNorm2dspectral_norm() 依赖此行为来更新缓冲区。

警告

module 及其子模块上定义的前向和后向钩子将被调用 len(device_ids) 次,每次输入位于特定的设备上。特别是,钩子只保证按照与相应设备上的操作相关的正确顺序执行。例如,不保证通过 register_forward_pre_hook() 设置的钩子在所有 len(device_ids)forward() 调用之前执行,但保证每个此类钩子在该设备相应的 forward() 调用之前执行。

警告

moduleforward() 中返回一个标量(即 0 维张量)时,此封装器将返回一个长度等于数据并行所用设备数量的向量,其中包含来自每个设备的结果。

注意

在使用 Module 封装在 DataParallel 中时,使用 打包序列 -> 循环网络 -> 解包序列 模式时存在一些细微之处。详情请参阅 FAQ 中的 我的循环网络无法与数据并行一起使用 部分。

参数
  • module (Module) – 要并行化的模块

  • device_ids (list of int or torch.device) – CUDA 设备 (默认: 所有设备)

  • output_device (int or torch.device) – 输出设备的位置 (默认: device_ids[0])

变量

module (Module) – 要并行化的模块

示例

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU

文档

访问 PyTorch 全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源