快捷方式

DataParallel

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[源代码]

在模块级别实现数据并行。

此容器通过在批处理维度中进行分块(其他对象将在每个设备上复制一次)将给定module的应用并行化到指定的设备上。在正向传递中,模块在每个设备上进行复制,每个副本处理一部分输入。在反向传递过程中,来自每个副本的梯度将累加到原始模块中。

批处理大小应大于使用的 GPU 数量。

警告

建议使用DistributedDataParallel,而不是此类,来进行多 GPU 训练,即使只有一个节点也是如此。参见:使用 nn.parallel.DistributedDataParallel 代替多进程或 nn.DataParallel分布式数据并行

允许将任意位置和关键字输入传递到 DataParallel,但某些类型会进行特殊处理。张量将在指定的 dim 上进行散列(默认值为 0)。元组、列表和字典类型将被浅复制。其他类型将在不同线程之间共享,如果在模型的正向传递中写入,可能会损坏。

并行化的module必须在其参数和缓冲区在device_ids[0]上,然后才能运行此DataParallel模块。

警告

在每次正向传递中,module都会在每个设备上复制,因此对forward中正在运行的模块的任何更新都将丢失。例如,如果module具有在每个forward中递增的计数器属性,它将始终保持初始值,因为更新是在副本上完成的,这些副本在forward之后被销毁。但是,DataParallel保证device[0]上的副本的参数和缓冲区将与基本并行化的module共享存储。因此,对device[0]上的参数或缓冲区的就地更新将被记录。例如,BatchNorm2dspectral_norm() 依赖于此行为来更新缓冲区。

警告

module及其子模块上定义的前向和后向钩子将被调用len(device_ids)次,每次都使用位于特定设备上的输入。特别是,仅保证钩子以正确的顺序相对于对应设备上的操作执行。例如,不能保证通过register_forward_pre_hook()设置的钩子在所有len(device_ids)forward()调用之前执行,但保证每个此类钩子在对应设备的forward()调用之前执行。

警告

moduleforward()中返回标量(即 0 维张量)时,此包装器将返回一个长度等于数据并行中使用的设备数量的向量,其中包含来自每个设备的结果。

注意

在使用 pack sequence -> recurrent network -> unpack sequence 模式封装在 Module 中并使用 DataParallel 时存在一些细微差别。有关详细信息,请参阅常见问题解答中的 我的循环神经网络无法与数据并行一起使用 部分。

参数
变量

module (Module) – 要并行的模块

示例

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源