快捷方式

torch.unravel_index

torch.unravel_index(indices, shape)[源代码][源代码]

将扁平索引张量转换为坐标张量的元组,这些坐标张量索引到指定形状的任意张量中。

参数
  • indices (Tensor) – 一个整数张量,包含扁平化版本的任意形状为 shape 的张量的索引。所有元素必须在范围 [0, prod(shape) - 1] 内。

  • shape (int, 整数序列, 或 torch.Size) – 任意张量的形状。所有元素必须是非负数。

返回

输出中的每个第 i 个张量对应于 shape 的第 i 维。每个张量都具有与 indices 相同的形状,并为 indices 给出的每个扁平索引包含一个维度 i 的索引。

返回类型

tuple of Tensors

示例

>>> import torch
>>> torch.unravel_index(torch.tensor(4), (3, 2))
(tensor(2),
 tensor(0))

>>> torch.unravel_index(torch.tensor([4, 1]), (3, 2))
(tensor([2, 0]),
 tensor([0, 1]))

>>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2))
(tensor([0, 0, 1, 1, 2, 2]),
 tensor([0, 1, 0, 1, 0, 1]))

>>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10))
(tensor([1, 5]),
 tensor([2, 6]),
 tensor([3, 7]),
 tensor([4, 8]))

>>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10))
(tensor([[1], [5]]),
 tensor([[2], [6]]),
 tensor([[3], [7]]),
 tensor([[4], [8]]))

>>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100))
(tensor([[12], [56]]),
 tensor([[34], [78]]))

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源