torch.Tensor.masked_scatter_¶
- Tensor.masked_scatter_(mask, source)¶
将
source中的元素复制到self张量中,位置为mask为 True 的位置。从source中的元素被复制到self中,从source的位置 0 开始,并依次逐个继续,对于mask为 True 的每个出现情况。mask的形状必须与基础张量的形状可广播。source应该至少包含与mask中 1 的数量一样多的元素。- 参数
mask (BoolTensor) – 布尔掩码
source (Tensor) – 要从中复制的张量
注意
mask对self张量起作用,而不是对给定的source张量起作用。示例
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter_(mask, source) tensor([[0, 0, 0, 0, 1], [2, 3, 0, 4, 5]])