注意
转到结尾 下载完整的示例代码。
切片、索引和掩码¶
**作者**:Tom Begley
在本教程中,您将学习如何切片、索引和掩码 TensorDict
。
如教程 操作 TensorDict 的形状 中所述,当我们创建 TensorDict
时,我们会指定一个 batch_size
,它必须与 TensorDict
中所有条目的前导维度一致。由于我们保证所有条目共享这些共同的维度,因此我们能够以与索引 torch.Tensor
相同的方式索引和掩码批处理维度。索引应用于 TensorDict
中所有条目的批处理维度。
例如,给定一个具有两个批处理维度的 TensorDict
,tensordict[0]
返回一个具有相同结构的新 TensorDict
,其值对应于原始 TensorDict
中每个条目第一“行”。
import torch
from tensordict import TensorDict
tensordict = TensorDict(
{"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)
print(tensordict[0])
TensorDict(
fields={
a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([4]),
device=None,
is_shared=False)
与常规张量相同的语法适用。例如,如果我们想要删除每个条目的第一行,我们可以按如下方式索引
print(tensordict[1:])
TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 4]),
device=None,
is_shared=False)
我们可以同时索引多个维度
print(tensordict[:, 2:])
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 2]),
device=None,
is_shared=False)
我们还可以使用 Ellipsis
来表示尽可能多的 :
,以使选择元组的长度与 tensordict.batch_dims
相同。
print(tensordict[..., 2:])
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 2]),
device=None,
is_shared=False)
使用索引设置值¶
通常,只要批处理大小兼容,tensordict[index] = new_tensordict
就可以工作。
tensordict = TensorDict(
{"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)
td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
print(tensordict["a"], tensordict["b"])
tensor([[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]) tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]])
掩码¶
我们掩码 TensorDict
的方式与掩码张量相同。
TensorDict(
fields={
a: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([6]),
device=None,
is_shared=False)
**脚本的总运行时间:**(0 分钟 0.005 秒)