快捷方式

torch.Tensor.scatter_add_

Tensor.scatter_add_(dim, index, src) Tensor

将张量 src 中的所有值添加到 self 中,索引在 index 张量中指定,类似于 scatter_()。对于 src 中的每个值,它将被添加到 self 中的索引,该索引由 src 中的索引指定,用于 dimension != dim,并由 index 中的对应值指定,用于 dimension = dim

对于一个 3 维张量,self 将更新为

self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

selfindexsrc 应具有相同的维度数量。还需要 index.size(d) <= src.size(d) 对于所有维度 d,并且 index.size(d) <= self.size(d) 对于所有维度 d != dim。请注意,indexsrc 不进行广播。

注意

此操作在 CUDA 设备上给出张量时可能表现出非确定性行为。有关详细信息,请参阅 可重复性

注意

反向传播仅在 src.shape == index.shape 时实现。

参数
  • dim (int) – 要索引的轴

  • index (LongTensor) – 要散列和添加的元素的索引,可以为空或与 src 的维度相同。如果为空,则操作将返回未更改的 self

  • src (Tensor) – 要散列和添加的源元素

示例

>>> src = torch.ones((2, 5))
>>> index = torch.tensor([[0, 1, 2, 0, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[1., 0., 0., 1., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])
>>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
tensor([[2., 0., 0., 1., 1.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 1., 1.]])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源