torch.nn.functional.torch.nn.parallel.data_parallel¶
- torch.nn.parallel.data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None)[源代码]¶
在 device_ids 中给定的 GPU 上并行评估 module(input)。
这是 DataParallel 模块的功能版本。
- 参数
module (Module) – 要并行评估的模块
inputs (Tensor) – 模块的输入
device_ids (list of int or torch.device) – 要复制模块的 GPU ID
output_device (list of int or torch.device) – 输出的 GPU 位置 使用 -1 表示 CPU。 (默认值:device_ids[0])
- 返回值
位于 output_device 上的 module(input) 结果的张量
- 返回类型