torch.cuda.comm.scatter¶
- torch.cuda.comm.scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None)[source][source]¶
将张量分散到多个 GPU。
- 参数
tensor (Tensor) – 要分散的张量。可以在 CPU 或 GPU 上。
devices (Iterable[torch.device, str or int], optional) – 一个 GPU 设备的可迭代对象,用于分散张量。
chunk_sizes (Iterable[int], optional) – 要放置在每个设备上的块的大小。长度应与
devices
匹配,并且总和应等于tensor.size(dim)
。如果未指定,tensor
将被平均分成块。dim (int, optional) – 用于分割
tensor
的维度。默认值:0
。streams (Iterable[torch.cuda.Stream], optional) – 一个 Stream 的可迭代对象,用于执行分散操作。如果未指定,将使用默认 Stream。
out (Sequence[Tensor], optional, keyword-only) – 用于存储输出结果的 GPU 张量。这些张量的大小必须与
tensor
匹配,除了dim
维度,在该维度上,总大小必须等于tensor.size(dim)
。
注意
必须且只能指定
devices
和out
中的一个。当指定out
时,不得指定chunk_sizes
,块大小将从out
的大小推断。- 返回
- 如果指定了
devices
, 一个元组,包含
tensor
的块,这些块放置在devices
上。
- 如果指定了
- 如果指定了
out
, 一个元组,包含
out
张量,每个张量都包含tensor
的一个块。
- 如果指定了