
TensorDictSequential

class tensordict.nn.TensorDictSequential(*args, **kwargs)

A sequence of TensorDictModules.

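Similarly to nn.Sequential, which passes a tensor through a chain of mappings that each read and write a single tensor, this module reads and writes over a tensordict by querying each of the input modules in turn. When a TensorDictSequential instance is called with a functional module, the parameter lists (and buffers) are expected to be concatenated into a single list.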
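The read/write chaining described above can be sketched with a plain Python dict standing in for the tensordict and `(function, in_keys, out_key)` tuples standing in for TensorDictModules. This is a toy model under those assumptions, not tensordict's implementation; `toy_sequential` is a hypothetical name.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical stand-in for a TensorDictModule: (function, in_keys, out_key).
ToyModule = Tuple[Callable[..., Any], List[str], str]


def toy_sequential(modules: List[ToyModule], td: Dict[str, Any]) -> Dict[str, Any]:
    """Run each toy module in order over one shared dict (the 'tensordict')."""
    for fn, in_keys, out_key in modules:
        # Read the inputs this module declares...
        inputs = [td[key] for key in in_keys]
        # ...and write its result back, so later modules can read it.
        td[out_key] = fn(*inputs)
    return td


modules = [
    (lambda x: x + 1, ["x"], "x+1"),
    (lambda y: 2 * y, ["x+1"], "2*(x+1)"),
]
print(toy_sequential(modules, {"x": 3}))
# {'x': 3, 'x+1': 4, '2*(x+1)': 8}
```

As with the real module, intermediate entries ("x+1" here) stay in the output alongside the final one.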

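Args:
  • modules (iterable of TensorDictModules) – an ordered sequence of TensorDictModule instances to be run sequentially.

  • partial_tolerant (bool, optional) – if True, the input tensordict may be missing some of the input keys. In that case, only the modules that can be executed with the keys that are present will be run. In addition, if the input tensordict is a lazy stack of tensordicts, partial_tolerant is True, and the stack does not have the required keys, TensorDictSequential will scan the sub-tensordicts for those that do have the required keys, if any.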
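The partial_tolerant behaviour can be illustrated with a toy model: a dict plays the tensordict, a list of dicts plays the lazy stack, and modules are `(function, in_keys, out_key)` tuples. This is a simplified sketch, not tensordict's code; in particular it always scans the sub-dicts of a "stack" instead of first checking the stack as a whole. `toy_partial_tolerant` is a hypothetical name.

```python
from typing import Any, Dict, List, Union

Td = Dict[str, Any]


def _run(module, td: Td) -> None:
    fn, in_keys, out_key = module
    td[out_key] = fn(*(td[key] for key in in_keys))


def toy_partial_tolerant(modules, td: Union[Td, List[Td]]):
    """Toy partial_tolerant=True behaviour; a list of dicts plays the lazy stack."""
    for module in modules:
        in_keys = module[1]
        if isinstance(td, dict):
            # Plain "tensordict": skip modules whose inputs are missing.
            if all(key in td for key in in_keys):
                _run(module, td)
        else:
            # "Lazy stack": run the module on each sub-dict that has its inputs.
            for sub_td in td:
                if all(key in sub_td for key in in_keys):
                    _run(module, sub_td)
    return td


modules = [
    (lambda a: a + 1, ["a"], "b"),
    (lambda c: c * 10, ["c"], "d"),  # "c" may be absent
]
print(toy_partial_tolerant(modules, {"a": 1}))
# {'a': 1, 'b': 2}
print(toy_partial_tolerant(modules, [{"a": 1}, {"a": 2, "c": 3}]))
# [{'a': 1, 'b': 2}, {'a': 2, 'c': 3, 'b': 3, 'd': 30}]
```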

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from torch import nn
>>> _ = torch.manual_seed(0)
>>> module = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["x+1"]),
...     TensorDictModule(nn.Linear(3, 4), in_keys=["x+1"], out_keys=["w*(x+1)+b"]),
... )
>>> # with tensordict input
>>> print(module(TensorDict({"x": torch.zeros(3)}, [])))
TensorDict(
    fields={
        w*(x+1)+b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        x+1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # with tensor input: returns the values of all the output keys in the order of the modules, i.e. "x+1" and "w*(x+1)+b"
>>> module(x=torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748,  0.1571, -0.1138], grad_fn=<AddBackward0>))
>>> module(torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748,  0.1571, -0.1138], grad_fn=<AddBackward0>))

TensorDictSequential supports functional, modular and vmap coding:

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
...     TensorDictSequential,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from torch.distributions import Normal
>>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
>>> net1 = torch.nn.Linear(4, 8)
>>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["params"])
>>> normal_params = TensorDictModule(
...     NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
... )
>>> td_module1 = ProbabilisticTensorDictSequential(
...     module1,
...     normal_params,
...     ProbabilisticTensorDictModule(
...         in_keys=["loc", "scale"],
...         out_keys=["hidden"],
...         distribution_class=Normal,
...         return_log_prob=True,
...     )
... )
>>> module2 = torch.nn.Linear(4, 8)
>>> td_module2 = TensorDictModule(
...    module=module2, in_keys=["hidden"], out_keys=["output"]
... )
>>> td_module = TensorDictSequential(td_module1, td_module2)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
In the vmap case:
>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
forward(tensordict: TensorDictBase, tensordict_out: TensorDictBase | None = None, **kwargs: Any) → TensorDictBase

When the tensordict argument is not set, the kwargs are used to create a TensorDict instance.
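A toy model of this rule, assuming dicts in place of TensorDicts and `(function, in_keys, out_key)` tuples in place of modules: with no tensordict, the keyword arguments seed a new dict, and (mirroring the tensor-input example above) the output values are returned in module order. `toy_forward` and `_run_all` are hypothetical names, not tensordict API.

```python
from typing import Any, Dict, Optional


def _run_all(modules, td: Dict[str, Any]) -> Dict[str, Any]:
    # Apply each toy module in order, writing its output back into td.
    for fn, in_keys, out_key in modules:
        td[out_key] = fn(*(td[key] for key in in_keys))
    return td


def toy_forward(modules, tensordict: Optional[Dict[str, Any]] = None, **kwargs: Any):
    if tensordict is not None:
        return _run_all(modules, tensordict)
    # No tensordict given: build one from the keyword arguments...
    td = _run_all(modules, dict(kwargs))
    # ...and return the output values in the order of the modules.
    return tuple(td[out_key] for _, _, out_key in modules)


modules = [
    (lambda x: x + 1, ["x"], "x+1"),
    (lambda y: 2 * y, ["x+1"], "2*(x+1)"),
]
print(toy_forward(modules, x=0))       # (1, 2)
print(toy_forward(modules, {"x": 5}))  # {'x': 5, 'x+1': 6, '2*(x+1)': 12}
```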

select_subsequence(in_keys: Iterable[NestedKey] | None = None, out_keys: Iterable[NestedKey] | None = None) → TensorDictSequential

Returns a new TensorDictSequential containing only the modules that are necessary to compute the given output keys from the given input keys.

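Args:
  • in_keys – input keys of the subsequence to select. All keys absent from in_keys are considered irrelevant, and modules that take only such keys as inputs are discarded. The resulting sequential module follows the pattern "all the modules whose output would be affected by a different value of any key in <in_keys>". If not provided, the module's in_keys are assumed.

  • out_keys – output keys of the subsequence to select. Only the modules needed to obtain the out_keys are kept in the resulting sequence. The resulting sequential module follows the pattern "all the modules that condition the value or entries of <out_keys>". If not provided, the module's out_keys are assumed.

Returns:

A new TensorDictSequential containing only the modules required by the given input and output keys.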
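The selection rule can be sketched as a forward pass (keep modules that read a key affected by in_keys) followed by a backward pass (keep modules that write a key needed for out_keys). This is a toy model on a flat list of `(in_keys, out_keys)` tuples, not tensordict's implementation; when in_keys is None the forward pass is simply skipped, a simplification of the documented default. `toy_select_subsequence` is a hypothetical name.

```python
from typing import Iterable, List, Optional, Sequence, Tuple

# Hypothetical flat representation: each module is (in_keys, out_keys).
Spec = Tuple[Tuple[str, ...], Tuple[str, ...]]


def toy_select_subsequence(
    modules: Sequence[Spec],
    in_keys: Optional[Iterable[str]] = None,
    out_keys: Optional[Iterable[str]] = None,
) -> List[Spec]:
    selected = list(modules)
    if in_keys is not None:
        # Forward pass: keep modules reading any affected key, and mark
        # their outputs as affected too.
        affected, kept = set(in_keys), []
        for m_in, m_out in selected:
            if affected.intersection(m_in):
                affected.update(m_out)
                kept.append((m_in, m_out))
        selected = kept
    if out_keys is not None:
        # Backward pass: keep modules writing any needed key, and mark
        # their inputs as needed too.
        needed, kept = set(out_keys), []
        for m_in, m_out in reversed(selected):
            if needed.intersection(m_out):
                needed.update(m_in)
                kept.append((m_in, m_out))
        selected = kept[::-1]
    return selected


# Same graph as the examples below: a->b, b->c, c->d, a->e.
mods = [(("a",), ("b",)), (("b",), ("c",)), (("c",), ("d",)), (("a",), ("e",))]
print(toy_select_subsequence(mods, in_keys=["c"]))   # [(('c',), ('d',))]
print(toy_select_subsequence(mods, out_keys=["e"]))  # [(('a',), ('e',))]
```

On this graph the toy reproduces the module selections shown in the examples below.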

Examples

>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> idn = lambda x: x
>>> module = Seq(
...     Mod(idn, in_keys=["a"], out_keys=["b"]),
...     Mod(idn, in_keys=["b"], out_keys=["c"]),
...     Mod(idn, in_keys=["c"], out_keys=["d"]),
...     Mod(idn, in_keys=["a"], out_keys=["e"]),
... )
>>> # select all modules whose output depends on "a"
>>> module.select_subsequence(in_keys=["a"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
      (2): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
      (3): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c', 'd', 'e'])
>>> # select all modules whose output depends on "c"
>>> module.select_subsequence(in_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
    ),
    device=cpu,
    in_keys=['c'],
    out_keys=['d'])
>>> # select all modules that affect the value of "c"
>>> module.select_subsequence(out_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c'])
>>> # select all modules that affect the value of "e"
>>> module.select_subsequence(out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['e'])

This method propagates to nested sequential modules:

>>> module = Seq(
...     Seq(
...         Mod(idn, in_keys=["a"], out_keys=["b"]),
...         Mod(idn, in_keys=["b"], out_keys=["c"]),
...     ),
...     Seq(
...         Mod(idn, in_keys=["b"], out_keys=["d"]),
...         Mod(idn, in_keys=["d"], out_keys=["e"]),
...     ),
... )
>>> # select submodules whose output will be affected by a change in "b" or "d" AND whose output is "e"
>>> module.select_subsequence(in_keys=["b", "d"], out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictSequential(
          module=ModuleList(
            (0): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['b'],
                out_keys=['d'])
            (1): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['d'],
                out_keys=['e'])
          ),
          device=cpu,
          in_keys=['b'],
          out_keys=['d', 'e'])
    ),
    device=cpu,
    in_keys=['b'],
    out_keys=['d', 'e'])
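The propagation into nested sequences can be sketched by extending the forward/backward pruning toy model with recursion. Here a leaf module is an `(in_keys, out_keys)` tuple and a nested sequence is a Python list; a nested sequence's external in_keys exclude keys produced earlier inside it. This is an illustrative assumption-based sketch, not tensordict's implementation; `keys_of` and `toy_select` are hypothetical names.

```python
from typing import Iterable, List, Optional, Sequence, Set, Tuple, Union

# Hypothetical layout: leaves are (in_keys, out_keys) tuples, nested
# sequences are lists of leaves and/or further lists.
Leaf = Tuple[Tuple[str, ...], Tuple[str, ...]]
Item = Union[Leaf, list]


def keys_of(item: Item) -> Tuple[Set[str], Set[str]]:
    """External (in_keys, out_keys) of a leaf or of a nested sequence."""
    if isinstance(item, tuple):
        return set(item[0]), set(item[1])
    ins: Set[str] = set()
    outs: Set[str] = set()
    for sub in item:
        sub_in, sub_out = keys_of(sub)
        ins |= sub_in - outs  # inputs not produced earlier in this sequence
        outs |= sub_out
    return ins, outs


def toy_select(
    modules: Sequence[Item],
    in_keys: Optional[Iterable[str]] = None,
    out_keys: Optional[Iterable[str]] = None,
) -> List[Item]:
    kept = list(modules)
    if in_keys is not None:
        # Forward pass: keep items reading an affected key, recursing into lists.
        affected, result = set(in_keys), []
        for item in kept:
            if affected & keys_of(item)[0]:
                if isinstance(item, list):
                    item = toy_select(item, in_keys=affected)
                affected |= keys_of(item)[1]
                result.append(item)
        kept = result
    if out_keys is not None:
        # Backward pass: keep items writing a needed key, recursing into lists.
        needed, result = set(out_keys), []
        for item in reversed(kept):
            if needed & keys_of(item)[1]:
                if isinstance(item, list):
                    item = toy_select(item, out_keys=needed)
                needed |= keys_of(item)[0]
                result.append(item)
        kept = result[::-1]
    return kept


# Same nesting as the example above.
modules = [
    [(("a",), ("b",)), (("b",), ("c",))],
    [(("b",), ("d",)), (("d",), ("e",))],
]
print(toy_select(modules, in_keys=["b", "d"], out_keys=["e"]))
# [[(('b',), ('d',)), (('d',), ('e',))]]
```

As in the example output, the first nested sequence is dropped entirely and the second is kept with both of its modules.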
