快捷方式

torch.kthvalue

torch.kthvalue(input, k, dim=None, keepdim=False, *, out=None)

返回一个名为元组 (values, indices),其中 values 是给定维度 diminput 张量每一行的第 k 个最小元素。而 indices 是找到的每个元素的索引位置。

如果未给出 dim,则选择 input 的最后一个维度。

如果 keepdimTrue,则 valuesindices 张量的大小与 input 相同,除了维度 dim 为大小 1。否则,dim 会被压缩(参见 torch.squeeze()),导致 valuesindices 张量的维度比 input 张量少 1 个。

注意

input 是 CUDA 张量并且存在多个有效的第 k 个值时,此函数可能会不确定地返回其中任何一个的 indices

参数
  • input (张量) – 输入张量。

  • k (int) – 第 k 个最小元素的 k

  • dim (int, 可选) – 沿其查找第 k 个值的维度

  • keepdim (bool) – 输出张量是否保留 dim

关键字参数

out (元组, 可选) – 可以选择提供 (Tensor, LongTensor) 的输出元组用作输出缓冲区

示例

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.kthvalue(x, 4)
torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3))

>>> x=torch.arange(1.,7.).resize_(2,3)
>>> x
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.]])
>>> torch.kthvalue(x, 2, 0, True)
torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]]))

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源