注意
转到末尾 下载完整的示例代码。
操作 TensorDict 的形状¶
作者: Tom Begley
在本教程中,您将学习如何操作 TensorDict
及其内容的形状。
当我们创建一个 TensorDict
时,我们指定一个 batch_size
,该值必须与 TensorDict
中所有条目的前导维度一致。由于我们保证所有条目都共享这些共同的维度,因此 TensorDict
能够公开一些方法,我们可以用这些方法来操作 TensorDict
及其内容的形状。
import torch
from tensordict.tensordict import TensorDict
索引 TensorDict
¶
由于批处理维度保证存在于所有条目上,因此我们可以随意索引它们,并且 TensorDict
的每个条目将以相同的方式被索引。
a = torch.rand(3, 4)
b = torch.rand(3, 4, 5)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])
indexed_tensordict = tensordict[:2, 1]
assert indexed_tensordict["a"].shape == torch.Size([2])
assert indexed_tensordict["b"].shape == torch.Size([2, 5])
重塑 TensorDict
¶
TensorDict.reshape
的工作原理与 torch.Tensor.reshape()
相同。它适用于 TensorDict
内容中的所有条目,沿着批处理维度 - 请注意以下示例中 b
的形状。它还会更新 batch_size
属性。
reshaped_tensordict = tensordict.reshape(-1)
assert reshaped_tensordict.batch_size == torch.Size([12])
assert reshaped_tensordict["a"].shape == torch.Size([12])
assert reshaped_tensordict["b"].shape == torch.Size([12, 5])
拆分 TensorDict
¶
TensorDict.split
与 torch.Tensor.split()
类似。它将 TensorDict
拆分为块。每个块都是一个 TensorDict
,其结构与原始结构相同,但其条目是原始 TensorDict
中相应条目的视图。
chunks = tensordict.split([3, 1], dim=1)
assert chunks[0].batch_size == torch.Size([3, 3])
assert chunks[1].batch_size == torch.Size([3, 1])
torch.testing.assert_close(chunks[0]["a"], tensordict["a"][:, :-1])
注意
只要函数或方法接受 dim
参数,负维度就会相对于调用函数或方法的 TensorDict
的 batch_size
进行解释。特别是,如果存在具有不同批处理大小的嵌套 TensorDict
值,则负维度始终相对于根的批处理维度进行解释。
tensordict = TensorDict(
{
"a": torch.rand(3, 4),
"nested": TensorDict({"b": torch.rand(3, 4, 5)}, [3, 4, 5])
},
[3, 4],
)
# dim = -2 will be interpreted as the first dimension throughout, as the root
# TensorDict has 2 batch dimensions, even though the nested TensorDict has 3
chunks = tensordict.split([2, 1], dim=-2)
assert chunks[0].batch_size == torch.Size([2, 4])
assert chunks[0]["nested"].batch_size == torch.Size([2, 4, 5])
正如您从本示例中看到的,TensorDict.split
方法的行为完全像我们在调用之前用 dim=tensordict.batch_dims - 2
替换了 dim=-2
一样。
解除绑定¶
TensorDict.unbind
与 torch.Tensor.unbind()
类似,并且在概念上与 TensorDict.split
类似。它会移除指定的维度,并返回沿着该维度所有切片的 tuple
。
slices = tensordict.unbind(dim=1)
assert len(slices) == 4
assert all(s.batch_size == torch.Size([3]) for s in slices)
torch.testing.assert_close(slices[0]["a"], tensordict["a"][:, 0])
堆叠和连接¶
TensorDict
可以与 torch.cat
和 torch.stack
结合使用。
堆叠 TensorDict
¶
堆叠可以延迟或连续进行。延迟堆叠只是一组作为 TensorDict 堆叠呈现的 TensorDict 列表。它允许用户携带一组具有不同内容形状、设备或键集的 TensorDict。另一个优点是堆叠操作可能很昂贵,如果只需要一小部分键,则延迟堆叠将比真正的堆叠快得多。它依赖于 LazyStackedTensorDict
类。在这种情况下,只有在访问时才会按需堆叠值。
from tensordict import LazyStackedTensorDict
cloned_tensordict = tensordict.clone()
stacked_tensordict = LazyStackedTensorDict.lazy_stack(
[tensordict, cloned_tensordict], dim=0
)
print(stacked_tensordict)
# Previously, torch.stack was always returning a lazy stack. For consistency with
# the regular PyTorch API, this behaviour will soon be adapted to deliver only
# dense tensordicts. To control which behaviour you are relying on, you can use
# the :func:`~tensordict.utils.set_lazy_legacy` decorator/context manager:
from tensordict.utils import set_lazy_legacy
with set_lazy_legacy(True): # old behaviour
lazy_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(lazy_stack, LazyStackedTensorDict)
with set_lazy_legacy(False): # new behaviour
dense_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(dense_stack, TensorDict)
LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 3, 4]),
device=None,
is_shared=False,
stack_dim=0)
如果我们在堆叠维度上索引 LazyStackedTensorDict
,我们将恢复原始的 TensorDict
。
assert stacked_tensordict[0] is tensordict
assert stacked_tensordict[1] is cloned_tensordict
访问 LazyStackedTensorDict
中的键会导致这些值被堆叠。如果键对应于嵌套的 TensorDict
,那么我们将恢复另一个 LazyStackedTensorDict
。
assert stacked_tensordict["a"].shape == torch.Size([2, 3, 4])
注意
由于值是按需堆叠的,多次访问同一个项目意味着它会被多次堆叠,这效率低下。如果您需要多次访问堆叠的 TensorDict
中的值,您可能需要考虑将 LazyStackedTensorDict
转换为连续的 TensorDict
,可以使用 LazyStackedTensorDict.to_tensordict
或 LazyStackedTensorDict.contiguous
方法。
调用这两个方法中的任何一个后,我们将得到一个包含堆叠值的常规 TensorDict
,并且在访问值时不会执行额外的计算。
连接 TensorDict
¶
连接不是延迟完成的,而是调用 torch.cat()
对 TensorDict
实例列表进行操作,它只返回一个 TensorDict
,其条目是列表中元素的连接条目。
concatenated_tensordict = torch.cat([tensordict, cloned_tensordict], dim=0)
assert isinstance(concatenated_tensordict, TensorDict)
assert concatenated_tensordict.batch_size == torch.Size([6, 4])
assert concatenated_tensordict["b"].shape == torch.Size([6, 4, 5])
扩展 TensorDict
¶
我们可以使用 TensorDict.expand
扩展 TensorDict
的所有条目。
exp_tensordict = tensordict.expand(2, *tensordict.batch_size)
assert exp_tensordict.batch_size == torch.Size([2, 3, 4])
torch.testing.assert_close(exp_tensordict["a"][0], exp_tensordict["a"][1])
压缩和解压缩 TensorDict
¶
我们可以使用 squeeze()
和 unsqueeze()
方法压缩或解压缩 TensorDict
的内容。
tensordict = TensorDict({"a": torch.rand(3, 1, 4)}, [3, 1, 4])
squeezed_tensordict = tensordict.squeeze()
assert squeezed_tensordict["a"].shape == torch.Size([3, 4])
print(squeezed_tensordict, end="\n\n")
unsqueezed_tensordict = tensordict.unsqueeze(-1)
assert unsqueezed_tensordict["a"].shape == torch.Size([3, 1, 4, 1])
print(unsqueezed_tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 1, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 1, 4, 1]),
device=None,
is_shared=False)
注意
到目前为止,unsqueeze()
、squeeze()
、view()
、permute()
、transpose()
这些操作都返回这些操作的延迟版本(即,一个容器,其中存储了原始的张量字典,并且在每次访问键时应用该操作)。这种行为将在未来被弃用,并且可以通过 set_lazy_legacy()
函数进行控制。
>>> with set_lazy_legacy(True):
... lazy_unsqueeze = tensordict.unsqueeze(0)
>>> with set_lazy_legacy(False):
... dense_unsqueeze = tensordict.unsqueeze(0)
请记住,这些方法始终只应用于批次维度。条目的任何非批次维度都不会受到影响。
tensordict = TensorDict({"a": torch.rand(3, 1, 1, 4)}, [3, 1])
squeezed_tensordict = tensordict.squeeze()
# only one of the singleton dimensions is dropped as the other
# is not a batch dimension
assert squeezed_tensordict["a"].shape == torch.Size([3, 1, 4])
查看 TensorDict¶
TensorDict
也支持 view
。这会创建一个 _ViewedTensorDict
,它会在访问内容时延迟创建其内容的视图。
tensordict = TensorDict({"a": torch.arange(12)}, [12])
# no views are created at this step
viewed_tensordict = tensordict.view((2, 3, 2))
# the view of "a" is created on-demand when we access it
assert viewed_tensordict["a"].shape == torch.Size([2, 3, 2])
置换批次维度¶
TensorDict.permute
方法可用于置换批次维度,就像 torch.permute()
一样。非批次维度保持不变。
此操作是延迟的,因此只有在我们尝试访问条目时才会置换批次维度。同样,如果您可能需要多次访问特定条目,请考虑转换为 TensorDict
。
tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
# swap the batch dimensions
permuted_tensordict = tensordict.permute([1, 0])
assert permuted_tensordict["a"].shape == torch.Size([4, 3])
assert permuted_tensordict["b"].shape == torch.Size([4, 3, 5])
使用张量字典作为装饰器¶
对于一堆可逆操作,张量字典可以用作装饰器。这些操作包括 to_module()
用于函数调用、unlock_()
和 lock_()
或者形状操作,例如 view()
、permute()
transpose()
、squeeze()
和 unsqueeze()
。以下是用 transpose
函数的示例。
tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
with tensordict.transpose(1, 0) as tdt:
tdt.set("c", torch.ones(4, 3)) # we have permuted the dims
# the ``"c"`` entry is now in the tensordict we used as decorator:
#
assert (tensordict.get("c") == 1).all()
在 TensorDict
中收集值¶
TensorDict.gather
方法可用于沿批次维度索引并将结果收集到单个维度中,就像 torch.gather()
一样。
index = torch.randint(4, (3, 4))
gathered_tensordict = tensordict.gather(dim=1, index=index)
print("index:\n", index, end="\n\n")
print("tensordict['a']:\n", tensordict["a"], end="\n\n")
print("gathered_tensordict['a']:\n", gathered_tensordict["a"], end="\n\n")
index:
tensor([[0, 3, 0, 0],
[0, 3, 2, 3],
[3, 0, 0, 0]])
tensordict['a']:
tensor([[0.3468, 0.9238, 0.2069, 0.3544],
[0.1446, 0.2779, 0.7292, 0.0551],
[0.8877, 0.2310, 0.7472, 0.3707]])
gathered_tensordict['a']:
tensor([[0.3468, 0.3544, 0.3468, 0.3468],
[0.1446, 0.0551, 0.7292, 0.0551],
[0.3707, 0.8877, 0.8877, 0.8877]])
脚本总运行时间:(0 分钟 0.008 秒)