torch.kthvalue¶
- torch.kthvalue(input, k, dim=None, keepdim=False, *, out=None)¶
返回一个名为元组
(values, indices)
,其中values
是给定维度dim
中input
张量每一行的第k
个最小元素。而indices
是找到的每个元素的索引位置。如果未给出
dim
,则选择 input 的最后一个维度。如果
keepdim
为True
,则values
和indices
张量的大小与input
相同,除了维度dim
为大小 1。否则,dim
会被压缩(参见torch.squeeze()
),导致values
和indices
张量的维度比input
张量少 1 个。注意
当
input
是 CUDA 张量并且存在多个有效的第k
个值时,此函数可能会不确定地返回其中任何一个的indices
。- 参数
- 关键字参数
out (元组, 可选) – 可以选择提供 (Tensor, LongTensor) 的输出元组用作输出缓冲区
示例
>>> x = torch.arange(1., 6.) >>> x tensor([ 1., 2., 3., 4., 5.]) >>> torch.kthvalue(x, 4) torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) >>> x=torch.arange(1.,7.).resize_(2,3) >>> x tensor([[ 1., 2., 3.], [ 4., 5., 6.]]) >>> torch.kthvalue(x, 2, 0, True) torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]]))