DataParallel¶
- class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source][source]¶
在模块级别实现数据并行。
此容器通过在批处理维度中分块,跨指定设备拆分输入,从而并行化给定
module
的应用(其他对象将在每个设备上复制一次)。在前向传播中,模块在每个设备上复制,每个副本处理一部分输入。在后向传播期间,来自每个副本的梯度被求和到原始模块中。批处理大小应大于使用的 GPU 数量。
警告
建议使用
DistributedDataParallel
,而不是此类,来进行多 GPU 训练,即使只有一个节点也是如此。请参阅:使用 nn.parallel.DistributedDataParallel 而不是多进程或 nn.DataParallel 和 分布式数据并行。允许将任意位置和关键字输入传递到 DataParallel,但某些类型会得到特殊处理。张量将在指定的维度(默认为 0)上进行分散。元组、列表和字典类型将进行浅复制。其他类型将在不同线程之间共享,如果在模型的前向传播中写入,则可能会损坏。
并行化的
module
必须在运行此DataParallel
模块之前,将其参数和缓冲区放在device_ids[0]
上。警告
在每次前向传播中,
module
都会在每个设备上复制,因此在forward
中对运行模块的任何更新都将丢失。例如,如果module
具有一个计数器属性,该属性在每次forward
中递增,则它将始终保持初始值,因为更新是在副本上完成的,而副本在forward
后被销毁。但是,DataParallel
保证device[0]
上的副本将使其参数和缓冲区与基本并行化的module
共享存储。因此,对device[0]
上的参数或缓冲区的原地更新将被记录。例如,BatchNorm2d
和spectral_norm()
依赖此行为来更新缓冲区。警告
在
module
及其子模块上定义的前向和后向钩子将被调用len(device_ids)
次,每次都使用位于特定设备上的输入。特别是,钩子仅保证相对于相应设备上的操作以正确的顺序执行。例如,不保证通过register_forward_pre_hook()
设置的钩子在 所有len(device_ids)
forward()
调用之前执行,但每个此类钩子在相应设备的forward()
调用之前执行。警告
当
module
在forward()
中返回标量(即 0 维张量)时,此包装器将返回一个向量,其长度等于数据并行中使用的设备数量,其中包含来自每个设备的结果。注意
在包装在
DataParallel
中的Module
中使用pack sequence -> recurrent network -> unpack sequence
模式时存在细微之处。有关详细信息,请参阅 FAQ 中的 我的循环神经网络无法与数据并行一起工作 部分。- 参数
module (Module) – 要并行化的模块
device_ids (list of int or torch.device) – CUDA 设备(默认:所有设备)
output_device (int or torch.device) – 输出的设备位置(默认:device_ids[0])
- 变量
module (Module) – 要并行化的模块
示例
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var can be on any device, including CPU