快捷方式

张量索引 API

在 PyTorch C++ API 中索引张量与 Python API 的工作原理非常相似。所有索引类型,例如 None / ... / 整数 / 布尔值 / 切片 / 张量,在 C++ API 中均可用,从而使从 Python 索引代码到 C++ 的转换变得非常简单。主要区别在于,不像 Python API 语法那样使用 [] 运算符,在 C++ API 中,索引方法是

  • torch::Tensor::index (链接)

  • torch::Tensor::index_put_ (链接)

还要注意,例如 None / Ellipsis / Slice 这样的索引类型位于 torch::indexing 命名空间中,建议在任何索引代码之前添加 using namespace torch::indexing,以便方便地使用这些索引类型。

以下是一些将 Python 索引代码转换为 C++ 的示例

获取器

Python

C++(假设 using namespace torch::indexing

tensor[None]

tensor.index({None})

tensor[Ellipsis, ...]

tensor.index({Ellipsis, "..."})

tensor[1, 2]

tensor.index({1, 2})

tensor[True, False]

tensor.index({true, false})

tensor[1::2]

tensor.index({Slice(1, None, 2)})

tensor[torch.tensor([1, 2])]

tensor.index({torch::tensor({1, 2})})

tensor[..., 0, True, 1::2, torch.tensor([1, 2])]

tensor.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})

设置器

Python

C++(假设 using namespace torch::indexing

tensor[None] = 1

tensor.index_put_({None}, 1)

tensor[Ellipsis, ...] = 1

tensor.index_put_({Ellipsis, "..."}, 1)

tensor[1, 2] = 1

tensor.index_put_({1, 2}, 1)

tensor[True, False] = 1

tensor.index_put_({true, false}, 1)

tensor[1::2] = 1

tensor.index_put_({Slice(1, None, 2)}, 1)

tensor[torch.tensor([1, 2])] = 1

tensor.index_put_({torch::tensor({1, 2})}, 1)

tensor[..., 0, True, 1::2, torch.tensor([1, 2])] = 1

tensor.index_put_({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}, 1)

Python/C++ 索引类型之间的转换

Python 和 C++ 索引类型之间的对应关系如下

Python

C++(假设 using namespace torch::indexing

None

None

Ellipsis

Ellipsis

...

"..."

123

123

True

true

False

false

:::

Slice()Slice(None, None)Slice(None, None, None)

1:1::

Slice(1, None)Slice(1, None, None)

:3:3:

Slice(None, 3)Slice(None, 3, None)

::2

Slice(None, None, 2)

1:3

Slice(1, 3)

1::2

Slice(1, None, 2)

:3:2

Slice(None, 3, 2)

1:3:2

Slice(1, 3, 2)

torch.tensor([1, 2])

torch::tensor({1, 2})

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源