torch.Tensor.masked_scatter
Tensor.masked_scatter(mask, tensor) → Tensor
Out-of-place version of torch.Tensor.masked_scatter_(): returns a new tensor and leaves self unchanged.
Note
The inputs self and mask are broadcast against each other.
Example:
>>> import torch
>>> self = torch.tensor([0, 0, 0, 0, 0])
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])
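The output above can look surprising at first, because source values are not taken from the positions where mask is True. Instead they are consumed consecutively, in row-major order, one per True entry of the mask. The following is a minimal pure-Python sketch of that behavior, not PyTorch's implementation. The helper name masked_scatter_reference is hypothetical, and the sketch assumes 2-D inputs that have already been broadcast to a common shape.

```python
def masked_scatter_reference(self_rows, mask_rows, source_flat):
    """Hypothetical pure-Python model of masked_scatter for 2-D inputs.

    self_rows and mask_rows are lists of rows with identical shapes
    (broadcasting is assumed to have been done already); source_flat
    is the source flattened in row-major order.
    """
    out = [list(row) for row in self_rows]  # copy, so the input is untouched
    next_src = 0  # index of the next unused source element
    for i, mask_row in enumerate(mask_rows):
        for j, selected in enumerate(mask_row):
            if selected:
                # Take the next source element, regardless of (i, j).
                out[i][j] = source_flat[next_src]
                next_src += 1
    return out


self_1d = [0, 0, 0, 0, 0]
mask = [[False, False, False, True, True],
        [True, True, False, True, True]]
source_flat = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Broadcast the 1-D self against the 2-D mask by repeating it per mask row.
self_2d = [list(self_1d) for _ in mask]

result = masked_scatter_reference(self_2d, mask, source_flat)
print(result)  # [[0, 0, 0, 0, 1], [2, 3, 0, 4, 5]]
```

The mask has six True entries, so only the first six source values (0 through 5) are used; the remaining values are ignored. The copy at the top of the helper mirrors the out-of-place behavior: self_2d is not modified.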