torch.argwhere¶
- torch.argwhere(input) Tensor ¶
返回一个张量,其中包含
input
中所有非零元素的索引。结果中的每一行都包含input
中一个非零元素的索引。结果按字典顺序排序,最后一个索引变化最快(C 样式)。如果
input
有 个维度,则结果索引张量out
的大小为 ,其中 是input
张量中非零元素的总数。注意
此函数类似于 NumPy 的 argwhere。
当
input
位于 CUDA 上时,此函数会导致主机-设备同步。- 参数
{input} –
示例
>>> t = torch.tensor([1, 0, 1]) >>> torch.argwhere(t) tensor([[0], [2]]) >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) >>> torch.argwhere(t) tensor([[0, 0], [0, 2], [1, 1], [1, 2]])