快捷方式

tensordict.nn 包

tensordict.nn 包使得能够在 ML 管道中灵活地使用 TensorDict。

由于 TensorDict 将代码的一部分转换为基于键的结构,因此现在可以使用这些键作为钩子来构建复杂的图结构。基本构建块是 TensorDictModule,它使用输入和输出键列表包装 torch.nn.Module 实例

>>> from torch.nn import Transformer
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> import torch
>>> module = TensorDictModule(Transformer(), in_keys=["feature", "target"], out_keys=["prediction"])
>>> data = TensorDict({"feature": torch.randn(10, 11, 512), "target": torch.randn(10, 11, 512)}, [10, 11])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        feature: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
        prediction: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
        target: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32)},
    batch_size=torch.Size([10, 11]),
    device=None,
    is_shared=False)

不需要一定使用 TensorDictModule,一个带有输入和输出键的有序列表(命名为 module.in_keysmodule.out_keys)的自定义 torch.nn.Module 就足够了。

许多 PyTorch 用户的一个关键痛点是 nn.Sequential 无法处理具有多个输入的模块。使用基于键的图可以轻松解决此问题,因为序列中的每个节点都知道需要读取哪些数据以及在哪里写入数据。

为此,我们提供了 TensorDictSequential 类,它将数据通过一系列 TensorDictModules 传递。序列中的每个模块都从原始 TensorDict 中获取输入并向其写入输出,这意味着序列中的模块可以忽略其前驱的输出,或者根据需要从 tensordict 中获取其他输入。这是一个示例

>>> from tensordict.nn import TensorDictSequential
>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
...
>>> class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
...
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>>
>>> td = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> td = module(td)
>>> print(td)
TensorDict(
    fields={
        input: TensorDict(
            fields={
                mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
                x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        intermediate: TensorDict(
            fields={
                x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                probabilities: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)

我们还可以通过 select_subsequence() 方法轻松选择子图

>>> sub_module = module.select_subsequence(out_keys=[("intermediate", "x")])
>>> td = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> sub_module(td)
>>> print(td)  # the "output" has not been computed
TensorDict(
    fields={
        input: TensorDict(
            fields={
                mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
                x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        intermediate: TensorDict(
            fields={
                x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)

最后,tensordict.nn 带有一个 ProbabilisticTensorDictModule,它允许从网络输出构建分布并从中获取汇总统计信息或样本(以及分布参数)

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn.distributions import NormalParamWrapper
>>> from tensordict.nn.functional_modules import make_functional
>>> from tensordict.nn.prototype import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
... )
>>> from torch.distributions import Normal
>>> td = TensorDict(
...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> module = TensorDictModule(
...     NormalParamWrapper(net), in_keys=["input", "hidden"], out_keys=["loc", "scale"]
... )
>>> prob_module = ProbabilisticTensorDictModule(
...     in_keys=["loc", "scale"],
...     out_keys=["sample"],
...     distribution_class=Normal,
...     return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(module, prob_module)
>>> td_module(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),
        input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        sample_log_prob: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

TensorDictModuleBase(*args, **kwargs)

TensorDict 模块的基类。

TensorDictModule(*args, **kwargs)

TensorDictModule 是 nn.Module 的 Python 包装器,用于读取和写入 TensorDict。

ProbabilisticTensorDictModule(*args, **kwargs)

一个概率 TD 模块。

TensorDictSequential(*args, **kwargs)

一系列 TensorDictModules。

TensorDictModuleWrapper(*args, **kwargs)

TensorDictModule 对象的包装器类。

函数式

tensordict 包与大多数 functorch 功能兼容。我们还提供了一个专门的函数式 API,它利用 tensordict 的优势来处理函数式程序中的参数。

make_functional() 方法会将模块转换为函数式模块。模块将在原地修改,并返回包含模块参数的 tensordict.TensorDict。此 tensordict 的结构完全反映了模型的结构。在以下示例中,我们展示了

  1. make_functional() 提取模块的参数;

  2. 这些参数的结构与模型的结构完全匹配(尽管可以使用 params.flatten_keys(".") 将其展平)。

  3. 它将模块及其所有子模块转换为函数式。

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = make_functional(model)
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> intermediate = model[0](x, params["0"])
>>> out2 = model[1](intermediate, params["1"])
>>> torch.testing.assert_close(out, out2)

或者,也可以使用以下方法构造参数

>>> params = TensorDict({name: param for name, param in model.named_parameters()}, []).unflatten_keys(".")
>>> params = TensorDict(model.state_dict(), [])  # provided that the state_dict() just returns params and buffer tensors

与 functorch 的做法不同,make_functional() 在高级别上不区分参数和缓冲区(它们都打包在一起)。

注意

Tensordict 函数式模块可以通过多种方式使用,参数可以作为参数或关键字参数传递。

>>> params = make_functional(model)
>>> model(input_td, params)
>>> # alternatively
>>> model(input_td, params=params)

但是,目前这将无法工作

>>> get_functional(model)
>>> model(input_td, params)  # breaks!
>>> model(input_td, params=params)  # works

因为 get_functional() 会使用其参数重新填充模块,我们依靠关键字参数 "params" 作为函数式调用的签名。

get_functional(module[, funs_to_decorate])

将 nn.Module 原地转换为函数式模块,并返回此模块的状态版本,该版本可用于函数式设置。

is_functional(module)

检查是否已对模块调用 make_functional()

make_functional(module[, funs_to_decorate, ...])

将 nn.Module 原地转换为函数式模块,并返回其参数。

repopulate_module(model, tensordict)

使用作为嵌套 TensorDict 提供的参数重新填充模块。

集成

函数式方法能够实现简单的集成。我们可以使用 tensordict.nn.EnsembleModule 复制和重新初始化模型副本

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 10]),
    device=None,
    is_shared=False)

EnsembleModule(*args, **kwargs)

包装模块并重复该模块以形成集成的模块。

编译 TensorDictModules

从 v0.5 版本开始,TensorDict 组件与 compile() 兼容。例如,一个 TensorDictSequential 模块可以使用 torch.compile 进行编译,并达到与常规 PyTorch 模块封装在 TensorDictModule 中类似的运行时性能。

分布

AddStateIndependentNormalScale(scale_shape)

一个 nn.Module,添加可训练的与状态无关的尺度参数。

CompositeDistribution(params, ...[, ...])

分布的组合。

Delta(param[, atol, rtol, batch_shape, ...])

Delta 分布。

OneHotCategorical([logits, probs])

独热分类分布。

TruncatedNormal(loc, scale, a, b[, ...])

截断正态分布。

实用程序

make_tensordict([input_dict, batch_size, device])

根据关键字参数或输入字典创建一个 TensorDict。

dispatch([separator, source, dest, ...])

允许使用 kwargs 调用期望 TensorDict 的函数。

set_interaction_type([type])

将所有 ProbabilisticTDModules 采样设置为所需的类型。

inv_softplus(bias)

逆 softplus 函数。

biased_softplus(bias[, min_val])

带偏置的 softplus 模块。

set_skip_existing([mode, in_key_attr, ...])

用于跳过 TensorDict 图中现有节点的上下文管理器。

skip_existing()

返回是否应该由模块重新计算 tensordict 中的现有条目。

TensorDictParams(parameters, *[, ...])

保存一个包含参数的 TensorDictBase 实例。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源