torch.cuda.comm.gather¶
- torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None)[source][source]¶
从多个 GPU 设备收集张量。
- 参数
tensors (Iterable[Tensor]) – 要收集的张量可迭代对象。除了
dim
之外的所有维度的张量大小必须匹配。dim (int, optional) – 张量将沿其连接的维度。默认值:
0
。destination (torch.device, str, or int, optional) – 输出设备。可以是 CPU 或 CUDA。默认值:当前的 CUDA 设备。
out (Tensor, optional, keyword-only) – 用于存储收集结果的张量。除了
dim
外,其大小必须与tensors
的大小匹配,在dim
维度上,大小必须等于sum(tensor.size(dim) for tensor in tensors)
。可以在 CPU 或 CUDA 上。
注意
指定
out
时,不得指定destination
。- 返回值
- 如果指定了
destination
, 一个位于
destination
设备上的张量,它是将tensors
沿dim
连接的结果。
- 如果指定了
- 如果指定了
out
, out
张量,现在包含将tensors
沿dim
连接的结果。
- 如果指定了