快捷方式

概述

TensorDict 使得组织数据和编写可重用、通用的 PyTorch 代码变得容易。它最初是为 TorchRL 开发的,我们已经将其剥离为一个单独的库。

TensorDict 主要是一个字典,但也是一个类张量:它支持多个张量操作,这些操作主要是与形状和存储相关的。它旨在高效地序列化或在节点之间或进程之间传输。最后,它附带了自己的 tensordict.nn 模块,该模块与 functorch 兼容,旨在使模型集成和参数操作更容易。

在本页,我们将介绍 TensorDict 并提供一些关于其功能的示例。

动机

TensorDict 允许您编写在不同范式中可重用的通用代码模块。例如,以下循环可以在大多数 SL、SSL、UL 和 RL 任务中重复使用。

>>> for i, tensordict in enumerate(dataset):
...     # the model reads and writes tensordicts
...     tensordict = model(tensordict)
...     loss = loss_module(tensordict)
...     loss.backward()
...     optimizer.step()
...     optimizer.zero_grad()

通过其 tensordict.nn 模块,该包提供了许多工具,可以轻松地在代码库中使用 TensorDict

在多进程或分布式环境中,tensordict 允许您将数据无缝地分派到每个工作器。

>>> # creates batches of 10 datapoints
>>> splits = torch.arange(tensordict.shape[0]).split(10)
>>> for worker in range(workers):
...     idx = splits[worker]
...     pipe[worker].send(tensordict[idx])

TensorDict 提供的一些操作也可以通过 tree_map 完成,但复杂程度更高。

>>> td = TensorDict(
...     {"a": torch.randn(3, 11), "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": td["a"], "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
...     {"a": regular_dicts["a"][i], "b": regular_dicts["b"][i]}
...     for i in range(3)]

嵌套情况更具说服力。

>>> td = TensorDict(
...     {"a": {"c": torch.randn(3, 11)}, "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": {"c": td["a", "c"]}, "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
...     {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]}
...     for i in range(3)

在应用解绑操作后,将输出字典分解为三个结构类似的字典,在使用 pytree 进行简单操作时会很快变得相当麻烦。使用 tensordict,我们为想要解绑或拆分嵌套结构的用户提供了简单的 API,而不是计算嵌套拆分/解绑的嵌套结构。

特点

一个 TensorDict 是一个用于张量的字典式容器。要实例化一个 TensorDict,您必须指定键值对以及批次大小。任何 TensorDict 中的值的领先维度必须与批次大小兼容。

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict(
...     {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
...     batch_size=[2, 3],
... )

设置或检索值的语法与普通字典的语法非常相似。

>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)

人们还可以沿着其 batch_size 索引 tensordict,这使得只需几个字符就可以获得一致的数据切片(请注意,使用省略号使用 tree_map 索引第 n 个领先维度需要更多代码)。

>>> sub_tensordict = tensordict[..., :2]

人们还可以使用 set 方法(inplace=True)或 set_ 方法来进行内容的原地更新。前者是后者的容错版本:如果找不到匹配的键,它将写入一个新的键。

现在可以集体地操作 TensorDict 的内容。例如,要将所有内容放置在特定设备上,只需执行以下操作。

>>> tensordict = tensordict.to("cuda:0")

要重塑批次维度,可以执行以下操作。

>>> tensordict = tensordict.reshape(6)

该类支持许多其他操作,包括 squeeze、unsqueeze、view、permute、unbind、stack、cat 等等。如果缺少操作,TensorDict.apply 方法通常可以提供所需的解决方案。

命名维度

TensorDict 和相关类还支持维度名称。名称可以在构造时给出,也可以稍后进行细化。语义类似于 torch.Tensor 维度名称功能。

>>> tensordict = TensorDict({}, batch_size=[3, 4], names=["a", None])
>>> tensordict.refine_names(..., "b")
>>> tensordict.names = ["z", "y"]
>>> tensordict.rename("m", "n")
>>> tensordict.rename(m="h")

嵌套 TensorDicts

一个 TensorDict 中的值本身可以是 TensorDicts(以下示例中的嵌套字典将被转换为嵌套 TensorDicts)。

>>> tensordict = TensorDict(
...     {
...         "inputs": {
...             "image": torch.rand(100, 28, 28),
...             "mask": torch.randint(2, (100, 28, 28), dtype=torch.uint8)
...         },
...         "outputs": {"logits": torch.randn(100, 10)},
...     },
...     batch_size=[100],
... )

访问或设置嵌套键可以使用字符串元组。

>>> image = tensordict["inputs", "image"]
>>> logits = tensordict.get(("outputs", "logits"))  # alternative way to access
>>> tensordict["outputs", "probabilities"] = torch.sigmoid(logits)

延迟评估

TensorDict 的一些操作会延迟执行,直到访问项目为止。例如,堆叠、压缩、解压缩、置换批次维度和创建视图不会立即在 TensorDict 的所有内容上执行。而是当访问 TensorDict 中的值时,它们会被延迟执行。如果 TensorDict 包含许多值,这可以节省大量不必要的计算。

>>> tensordicts = [TensorDict({
...     "a": torch.rand(10),
...     "b": torch.rand(10, 1000, 1000)}, [10])
...     for _ in range(3)]
>>> stacked = torch.stack(tensordicts, 0)  # no stacking happens here
>>> stacked_a = stacked["a"]  # we stack the a values, b values are not stacked

它还具有我们可以操作堆栈中原始 tensordict 的优点。

>>> stacked["a"] = torch.zeros_like(stacked["a"])
>>> assert (tensordicts[0]["a"] == 0).all()

需要注意的是,get 方法现在已成为一项昂贵的操作,如果多次重复执行,可能会导致一些开销。可以通过在执行 stack 后简单地调用 tensordict.contiguous() 来避免这种情况。为了进一步减轻这种情况,TensorDict 自带了自己的元数据类(MetaTensor),它可以跟踪字典中每个条目的类型、形状、dtype 和设备,而无需执行昂贵的操作。

延迟预分配

假设我们有一些函数 foo() -> TensorDict,并且我们执行以下操作。

>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in range(N):
...     tensordict[i] = foo()

i == 0 时,空的 TensorDict 将自动填充批次大小为 N 的空张量。在循环的后续迭代中,所有更新都将在原地写入。

TensorDictModule

为了方便将 TensorDict 集成到您的代码库中,我们提供了一个 tensordict.nn 包,允许用户将 TensorDict 实例传递给 nn.Module 对象。

TensorDictModule 包装了 nn.Module 并接受单个 TensorDict 作为输入。您可以指定底层模块应该从哪里获取其输入,以及应该将其输出写入哪里。这是我们可以编写可重用、通用高级代码(例如动机部分中的训练循环)的关键原因。

>>> from tensordict.nn import TensorDictModule
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.LazyLinear(1)
...
...     def forward(self, x):
...         logits = self.linear(x)
...         return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
...     Net(),
...     in_keys=["input"],
...     out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> tensordict = module(tensordict)
>>> # outputs can now be retrieved from the tensordict
>>> logits = tensordict["outputs", "logits"]
>>> probabilities = tensordict.get(("outputs", "probabilities"))

为了方便使用此类,还可以将张量作为关键字参数传递。

>>> tensordict = module(input=torch.randn(32, 100))

这将返回一个与前一个代码框中相同的 TensorDict

多个 PyTorch 用户的一个关键痛点是 nn.Sequential 无法处理具有多个输入的模块。使用基于键的图形可以轻松解决此问题,因为序列中的每个节点都知道需要读取哪些数据以及应该写入哪里。

为此,我们提供了 TensorDictSequential 类,该类将数据通过一系列 TensorDictModules 传递。序列中的每个模块都从原始 TensorDict 获取其输入,并将输出写入原始 TensorDict,这意味着序列中的模块可以忽略其前驱的输出,或者根据需要从 tensordict 获取额外的输入。以下是一个示例。

>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
...
... class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> tensordict = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> tensordict = module(tensordict)
>>> intermediate_x = tensordict["intermediate", "x"]
>>> probabilities = tensordict["output", "probabilities"]

在此示例中,第二个模块将第一个模块的输出与存储在 TensorDict 中 (“inputs”, “mask”) 下的掩码相结合。

TensorDictSequential 提供了许多其他功能:可以通过查询 in_keys 和 out_keys 属性来访问输入和输出键列表。还可以通过使用所需的输入和输出键集查询 select_subsequence() 来请求子图。这将返回另一个 TensorDictSequential,其中只包含满足这些要求所必需的模块。 TensorDictModule 还与 vmap 和其他 functorch 功能兼容。

函数式编程

我们提供了一个 API,可以在 conjunction with functorch 中使用 TensorDict。例如,TensorDict 使得将模型权重连接起来以进行模型集成变得容易。

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(separator=".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(separator=".")
>>> params = make_functional(model)
>>> # params provided by make_functional match state_dict:
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

函数式 API 与目前在 functorch 中实现的 FunctionalModule 相比,速度不慢,甚至更快。

文档

访问 PyTorch 的全面的开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源