• 文档 >
  • 切片、索引和掩码
快捷方式

切片、索引和掩码

作者: Tom Begley

在本教程中,你将学习如何对 TensorDict 进行切片、索引和掩码操作。

正如教程 操作 TensorDict 的形状 中所讨论的,当我们创建一个 TensorDict 时,我们指定了 batch_size,它必须与 TensorDict 中所有条目的前导维度一致。由于我们保证所有条目都共享这些维度,因此我们可以像对 torch.Tensor 进行索引一样对批处理维度进行索引和掩码操作。这些索引将应用于 TensorDict 中所有条目的批处理维度。

例如,给定一个具有两个批处理维度的 TensorDicttensordict[0] 将返回一个结构相同的新 TensorDict,其值对应于原始 TensorDict 中每个条目的第一个“行”。

import torch
from tensordict import TensorDict

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)

print(tensordict[0])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

与常规张量一样,语法是相同的。例如,如果我们想删除每个条目的第一行,我们可以如下进行索引

print(tensordict[1:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 4]),
    device=None,
    is_shared=False)

我们可以同时索引多个维度

print(tensordict[:, 2:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

我们还可以使用 Ellipsis (...) 来表示任意数量的 :,以使选择元组的长度与 tensordict.batch_dims 相同。

print(tensordict[..., 2:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

使用索引设置值

通常,只要批处理大小兼容,tensordict[index] = new_tensordict 即可正常工作。

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)

td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
print(tensordict["a"], tensordict["b"])
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]]) tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.]])

掩码

我们像对张量进行掩码一样对 TensorDict 进行掩码。

mask = torch.BoolTensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]])
tensordict[mask]
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([6]),
    device=None,
    is_shared=False)

脚本总运行时间: (0 分 0.004 秒)

画廊由 Sphinx-Gallery 生成

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源