快捷方式

概述

TensorDict 使组织数据和编写可重用、通用的 PyTorch 代码变得容易。它最初是为 TorchRL 开发的,后来我们将其分离出来成为一个独立的库。

TensorDict 主要是一个字典,但也像一个张量类:它支持多种主要与形状和存储相关的张量操作。它被设计成可以高效地从节点到节点或从进程到进程进行序列化或传输。最后,它带有自己的 nn 模块,该模块与 torch.func 兼容,旨在简化模型集成和参数操作。

在本页中,我们将阐述 TensorDict 的动机,并给出它的一些功能示例。

动机

TensorDict 允许您编写可在不同范例中重用的通用代码模块。例如,以下循环可用于大多数 SL、SSL、UL 和 RL 任务。

>>> for i, tensordict in enumerate(dataset):
...     # the model reads and writes tensordicts
...     tensordict = model(tensordict)
...     loss = loss_module(tensordict)
...     loss.backward()
...     optimizer.step()
...     optimizer.zero_grad()

凭借其 nn 模块,该包提供了许多工具,可以轻松地在代码库中使用 TensorDict

在多进程或分布式设置中,TensorDict 允许您将数据无缝地分派给每个工作进程

>>> # creates batches of 10 datapoints
>>> splits = torch.arange(tensordict.shape[0]).split(10)
>>> for worker in range(workers):
...     idx = splits[worker]
...     pipe[worker].send(tensordict[idx])

TensorDict 提供的一些操作也可以通过 tree_map 完成,但这会增加复杂性

>>> td = TensorDict(
...     {"a": torch.randn(3, 11), "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": td["a"], "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
...     {"a": regular_dicts["a"][i], "b": regular_dicts["b"][i]}
...     for i in range(3)]

嵌套的情况更加引人注目

>>> td = TensorDict(
...     {"a": {"c": torch.randn(3, 11)}, "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": {"c": td["a", "c"]}, "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
...     {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]}
...     for i in range(3)

在朴素地使用 pytree 时,应用 unbind 操作后将输出字典分解为三个结构相似的字典会迅速变得相当麻烦。使用 tensordict,我们为希望分解或分割嵌套结构的用户提供了一个简单的 API,而不是计算一个嵌套的分割 / 分解的嵌套结构。

功能特性

一个 TensorDict 是一个类似字典的张量容器。要实例化 TensorDict,您可以指定键值对以及批大小(可以通过 TensorDict() 创建一个空的 tensordict)。TensorDict 中任何值的前导维度必须与批大小兼容。

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict(
...     {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
...     batch_size=[2, 3],
... )

设置或检索值的语法与常规字典非常相似。

>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)

还可以沿着 batch_size 索引 tensordict,这样只需几个字符就能获得数据的一致切片(请注意,使用省略号通过 tree_map 索引第 n 个前导维度需要更多的编码)。

>>> sub_tensordict = tensordict[..., :2]

还可以使用带有 inplace=True 的 set 方法或 set_() 方法对内容进行原地更新。前者是后者的容错版本:如果找不到匹配的键,它将写入一个新键。

TensorDict 的内容现在可以集体操作。例如,要将所有内容放到特定设备上,只需执行

>>> tensordict = tensordict.to("cuda:0")

然后您可以断言 tensordict 的设备是 “cuda:0”

>>> assert tensordict.device == torch.device("cuda:0")

要重塑批维度,可以执行

>>> tensordict = tensordict.reshape(6)

该类支持许多其他操作,包括 squeeze()unsqueeze()view()permute()unbind()stack()cat() 等等。

如果某项操作不存在,通常可以使用 apply() 方法来解决问题。

避免形状操作

在某些情况下,可能需要将张量存储在 TensorDict 中,但在形状操作期间不强制要求批大小一致性。

这可以通过将张量包装在 UnbatchedTensor 实例中来实现。

UnbatchedTensor 在 TensorDict 上进行形状操作时会忽略其形状,从而可以灵活地存储和操作具有任意形状的张量。

>>> from tensordict import UnbatchedTensor
>>> tensordict = TensorDict({"zeros": UnbatchedTensor(torch.zeros(10))}, batch_size=[2, 3])
>>> reshaped_td = tensordict.reshape(6)
>>> reshaped_td["zeros"] is tensordict["zeros"]
True

非张量数据

Tensordict 是一个用于处理张量数据的强大库,但也支持非张量数据。本指南将向您展示如何使用 tensordict 处理非张量数据。

使用非张量数据创建 TensorDict

您可以使用 NonTensorData 类创建包含非张量数据的 TensorDict。

>>> from tensordict import TensorDict, NonTensorData
>>> import torch
>>> td = TensorDict(
...     a=NonTensorData("a string!"),
...     b=torch.zeros(()),
... )
>>> print(td)
TensorDict(
    fields={
        a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

如您所见,NonTensorData 对象像普通张量一样存储在 TensorDict 中。

访问非张量数据

您可以使用键或 get 方法访问非张量数据。常规的 getattr 调用将返回 NonTensorData 对象的内容,而 get() 将返回 NonTensorData 对象本身。

>>> print(td["a"])  # prints: a string!
>>> print(td.get("a"))  # prints: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None)

批处理非张量数据

如果您有一批非张量数据,可以将其存储在指定批大小的 TensorDict 中。

>>> td = TensorDict(
...     a=NonTensorData("a string!"),
...     b=torch.zeros(3),
...     batch_size=[3]
... )
>>> print(td)
TensorDict(
    fields={
        a: NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

在这种情况下,我们假设 tensordict 的所有元素都具有相同的非张量数据。

>>> print(td[0])
TensorDict(
    fields={
        a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

要为有形状的 tensordict 中的每个元素分配不同的非张量数据对象,您可以使用非张量数据堆栈。

堆叠非张量数据

如果您有一个非张量数据列表想要存储在 TensorDict 中,可以使用 NonTensorStack 类。

>>> td = TensorDict(
...     a=NonTensorStack("a string!", "another string!", "a third string!"),
...     b=torch.zeros(3),
...     batch_size=[3]
... )
>>> print(td)
TensorDict(
    fields={
        a: NonTensorStack(
            ['a string!', 'another string!', 'a third string!'...,
            batch_size=torch.Size([3]),
            device=None),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

您可以访问第一个元素,然后您将获得第一个字符串

>>> print(td[0])
TensorDict(
    fields={
        a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

相比之下,将 NonTensorData 与列表一起使用不会产生相同的结果,因为对于碰巧是列表的非张量数据,通常无法确定该如何处理

>>> td = TensorDict(
...     a=NonTensorData(["a string!", "another string!", "a third string!"]),
...     b=torch.zeros(3),
...     batch_size=[3]
... )
>>> print(td[0])
TensorDict(
    fields={
        a: NonTensorData(data=['a string!', 'another string!', 'a third string!'], batch_size=torch.Size([]), device=None),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

堆叠包含非张量数据的 TensorDict

要堆叠非张量数据,stack() 将检查非张量对象的标识,如果它们匹配,则生成单个 NonTensorData;否则,生成 NonTensorStack

>>> td = TensorDict(
...     a=NonTensorData("a string!"),
... b = torch.zeros(()),
... )
>>> print(torch.stack([td, td]))
TensorDict(
    fields={
        a: NonTensorData(data=a string!, batch_size=torch.Size([2]), device=None),
        b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

如果您想确保结果是堆栈,请改用 lazy_stack()

>>> print(TensorDict.lazy_stack([td, td]))
LazyStackedTensorDict(
    fields={
        a: NonTensorStack(
            ['a string!', 'a string!'],
            batch_size=torch.Size([2]),
            device=None),
        b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False,
    stack_dim=0)

命名维度

TensorDict 及相关类也支持维度命名。可以在构建时或稍后指定名称。其语义类似于 torch.Tensor 的维度命名功能

>>> tensordict = TensorDict({}, batch_size=[3, 4], names=["a", None])
>>> tensordict.refine_names(..., "b")
>>> tensordict.names = ["z", "y"]
>>> tensordict.rename("m", "n")
>>> tensordict.rename(m="h")

嵌套 TensorDict

TensorDict 中的值本身也可以是 TensorDict(下面示例中的嵌套字典将被转换为嵌套 TensorDict)。

>>> tensordict = TensorDict(
...     {
...         "inputs": {
...             "image": torch.rand(100, 28, 28),
...             "mask": torch.randint(2, (100, 28, 28), dtype=torch.uint8)
...         },
...         "outputs": {"logits": torch.randn(100, 10)},
...     },
...     batch_size=[100],
... )

访问或设置嵌套键可以使用字符串元组完成

>>> image = tensordict["inputs", "image"]
>>> logits = tensordict.get(("outputs", "logits"))  # alternative way to access
>>> tensordict["outputs", "probabilities"] = torch.sigmoid(logits)

惰性评估

TensorDict 的一些操作会推迟执行,直到访问其中的项。例如,堆叠、挤压 (squeezing)、非挤压 (unsqueezing)、置换批维度和创建视图等操作不会立即在 TensorDict 的所有内容上执行。相反,它们在访问 TensorDict 中的值时惰性执行。如果 TensorDict 包含许多值,这可以节省大量不必要的计算。

>>> tensordicts = [TensorDict({
...     "a": torch.rand(10),
...     "b": torch.rand(10, 1000, 1000)}, [10])
...     for _ in range(3)]
>>> stacked = torch.stack(tensordicts, 0)  # no stacking happens here
>>> stacked_a = stacked["a"]  # we stack the a values, b values are not stacked

它还有一个优点,就是我们可以操作堆栈中的原始 tensordict

>>> stacked["a"] = torch.zeros_like(stacked["a"])
>>> assert (tensordicts[0]["a"] == 0).all()

需要注意的是,get 方法现在已成为一个昂贵的操作,如果重复多次,可能会导致一些开销。只需在执行 stack 后调用 tensordict.contiguous() 即可避免这种情况。为了进一步缓解此问题,TensorDict 附带了自己的元数据类 (MetaTensor),该类可以跟踪字典中每个条目的类型、形状、dtype 和设备,而无需执行昂贵的操作。

惰性预分配

假设我们有一个函数 foo() -> TensorDict,然后我们执行以下操作

>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in range(N):
...     tensordict[i] = foo()

i == 0 时,空的 TensorDict 将自动填充批大小为 N 的空张量。在后续的循环迭代中,所有更新都将原地写入。

TensorDictModule

为了方便将 TensorDict 集成到代码库中,我们提供了 tensordict.nn 包,该包允许用户将 TensorDict 实例传递给 Module 对象(或任何可调用对象)。

TensorDictModule 包装了 Module 并接受单个 TensorDict 作为输入。您可以指定底层模块应从何处获取输入以及应将输出写入何处。这是我们可以编写可重用、通用的高级代码(例如动机部分中的训练循环)的关键原因。

>>> from tensordict.nn import TensorDictModule
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.LazyLinear(1)
...
...     def forward(self, x):
...         logits = self.linear(x)
...         return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
...     Net(),
...     in_keys=["input"],
...     out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> tensordict = module(tensordict)
>>> # outputs can now be retrieved from the tensordict
>>> logits = tensordict["outputs", "logits"]
>>> probabilities = tensordict.get(("outputs", "probabilities"))

为了方便采用此类,还可以将张量作为 kwargs 传递

>>> tensordict = module(input=torch.randn(32, 100))

这将返回一个与前一个代码框中完全相同的 TensorDict。有关此功能的更多背景信息,请参阅 导出教程

许多 PyTorch 用户面临的一个痛点是 nn.Sequential 无法处理具有多个输入的模块。使用基于键的图可以轻松解决此问题,因为序列中的每个节点都知道需要读取哪些数据以及写入到何处。

为此,我们提供了 TensorDictSequential 类,该类将数据通过一系列 TensorDictModules 进行传递。序列中的每个模块都从原始 TensorDict 获取输入,并将输出写入其中,这意味着序列中的模块可以忽略前一个模块的输出,或根据需要从 tensordict 中获取额外的输入。以下是一个示例

>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
...
... class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> tensordict = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> tensordict = module(tensordict)
>>> intermediate_x = tensordict["intermediate", "x"]
>>> probabilities = tensordict["output", "probabilities"]

在此示例中,第二个模块将第一个模块的输出与存储在 TensorDict 中 (“inputs”, “mask”) 键下的掩码组合。

TensorDictSequential 提供了一系列其他功能:可以通过查询 in_keys 和 out_keys 属性来访问输入和输出键列表。还可以通过使用所需的输入和输出键集查询 select_subsequence() 来请求子图。这将返回另一个 TensorDictSequential,其中只包含满足这些要求必不可少的模块。TensorDictModule 也与 vmap() 和其他 torch.func 功能兼容。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并解答您的问题

查看资源