torch.sort¶
- torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)¶
按值沿给定维度对
input
张量中的元素进行升序排序。如果未指定
dim
,则选择 input 的最后一个维度。如果
descending
为True
,则按值对元素进行降序排序。如果
stable
为True
,则排序例程将是稳定的,保留等价元素的原始顺序。返回一个命名元组 (values, indices),其中 values 是排序后的值,indices 是原始 input 张量中元素的索引。
- 参数
- 关键字参数
out (元组, 可选) – 可选提供的用于作输出缓冲区的 (Tensor, LongTensor) 输出元组
示例
>>> x = torch.randn(3, 4) >>> sorted, indices = torch.sort(x) >>> sorted tensor([[-0.2162, 0.0608, 0.6719, 2.3332], [-0.5793, 0.0061, 0.6058, 0.9497], [-0.5071, 0.3343, 0.9553, 1.0960]]) >>> indices tensor([[ 1, 0, 2, 3], [ 3, 1, 0, 2], [ 0, 3, 1, 2]]) >>> sorted, indices = torch.sort(x, 0) >>> sorted tensor([[-0.5071, -0.2162, 0.6719, -0.5793], [ 0.0608, 0.0061, 0.9497, 0.3343], [ 0.6058, 0.9553, 1.0960, 2.3332]]) >>> indices tensor([[ 2, 0, 0, 1], [ 0, 1, 1, 2], [ 1, 2, 2, 0]]) >>> x = torch.tensor([0, 1] * 9) >>> x.sort() torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) >>> x.sort(stable=True) torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17]))