torch.cuda.comm.gather¶
- torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[源代码]¶
从多个 GPU 设备收集张量。
- 参数
tensors (Iterable[Tensor]) – 要收集的张量的可迭代对象。除了
dim
外,所有维度中的张量大小都必须匹配。dim (int, 可选) – 沿其连接张量的维度。默认值:
0
。destination (torch.device, str 或 int, 可选) – 输出设备。可以是 CPU 或 CUDA。默认值:当前 CUDA 设备。
out (Tensor, 可选, 仅限关键字) – 用于存储收集结果的张量。其大小必须与
tensors
的大小匹配,除了dim
,其大小必须等于sum(tensor.size(dim) for tensor in tensors)
。可以位于 CPU 或 CUDA 上。
注意
指定
out
时,不得指定destination
。- 返回值
- 如果指定了
destination
, 位于
destination
设备上的张量,它是沿dim
连接tensors
的结果。
- 如果指定了
- 如果指定了
out
, out
张量,现在包含沿dim
连接tensors
的结果。
- 如果指定了