快捷方式

torch.Tensor.masked_scatter_

Tensor.masked_scatter_(mask, source)

mask 为 True 的位置,将 source 中的元素复制到 self 张量中。从 source 的位置 0 开始,对于 mask 为 True 的每一个位置,按顺序逐一将 source 中的元素复制到 self 中。mask 的形状必须与底层张量的形状可广播 (broadcastable)source 中的元素数量应至少与 mask 中值为 1 的数量相等。

参数
  • mask (BoolTensor) – 布尔掩码

  • source (Tensor) – 要从中复制元素的张量

注意

mask 是作用在 self 张量上,而不是给定的 source 张量上。

示例

>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])

文档

查阅 PyTorch 全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源