torch.cuda.comm.gather¶
- torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[源代码][源代码]¶
从多个 GPU 设备收集张量。
- 参数
tensors (Iterable[Tensor]) – 要收集的张量的可迭代对象。除了
dim
之外,所有维度上的张量大小都必须匹配。dim (int, 可选) – 张量将沿其连接的维度。默认值:
0
。destination (torch.device, str, 或 int, 可选) – 输出设备。可以是 CPU 或 CUDA。默认值:当前的 CUDA 设备。
out (Tensor, 可选, 仅关键字参数) – 用于存储收集结果的张量。其大小必须与
tensors
的大小匹配,除了dim
维度,该维度的大小必须等于sum(tensor.size(dim) for tensor in tensors)
。可以在 CPU 或 CUDA 上。
注意
当指定
out
时,不得指定destination
。- 返回
- 如果指定了
destination
, 位于
destination
设备上的张量,它是沿dim
连接tensors
的结果。
- 如果指定了
- 如果指定了
out
, out
张量,现在包含沿dim
连接tensors
的结果。
- 如果指定了