快捷方式

torch.argwhere

torch.argwhere(input) Tensor

返回一个张量,其中包含 input 中所有非零元素的索引。结果中的每一行都包含 input 中一个非零元素的索引。结果按字典顺序排序,最后一个索引变化最快(C 样式)。

如果 inputnn 个维度,则结果索引张量 out 的大小为 (z×n)(z \times n),其中 zzinput 张量中非零元素的总数。

注意

此函数类似于 NumPy 的 argwhere

input 位于 CUDA 上时,此函数会导致主机-设备同步。

参数

{input}

示例

>>> t = torch.tensor([1, 0, 1])
>>> torch.argwhere(t)
tensor([[0],
        [2]])
>>> t = torch.tensor([[1, 0, 1], [0, 1, 1]])
>>> torch.argwhere(t)
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [1, 2]])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源