快捷方式

TensorDictModule

class tensordict.nn.TensorDictModule(*args, **kwargs)

TensorDictModule 是一个 Python 包装器,它包装了 nn.Module,用于从 TensorDict 中读取和写入数据。

参数:
  • module (Callable) – 一个可调用对象,通常是 torch.nn.Module,用于将输入映射到输出参数空间。其 forward 方法可以返回单个张量、张量元组甚至字典。在后一种情况下,TensorDictModule 的输出键将用于填充输出 tensordict(即 out_keys 中存在的键应存在于 module forward 方法返回的字典中)。

  • in_keys (iterable of NestedKeys, Dict[NestedStr, str]) – 要从输入 tensordict 中读取并传递给 module 的键。如果它包含多个元素,这些值将按照 in_keys 可迭代对象给出的顺序传递。如果 in_keys 是一个字典,它的键必须对应于 tensordict 中要读取的键,而它的值必须匹配函数签名中的关键字参数名称。如果 out_to_in_mapTrue,则映射被反转,使得键对应于函数签名中的关键字参数。

  • out_keys (iterable of str) – 要写入输入 tensordict 的键。out_keys 的长度必须与嵌入模块返回的张量数量匹配。使用 “_” 作为键可以避免将张量写入输出。

关键字参数:
  • out_to_in_map (bool, optional) –

    如果为 True,则 in_keys 被读取时,其键被视为 forward() 方法的参数键,值则是输入 TensorDict 中的键。如果为 FalseNone(默认),则键被视为输入键,值被视为方法的参数键。

    警告

    out_to_in_map 的默认值将在 v0.9 版本中从 False 更改为 True

  • inplace (bool or string, optional) –

    如果为 True(默认),模块的输出将写入提供给 forward() 方法的 tensordict 中。如果为 False,则会创建一个新的 TensorDict 实例,其批大小为空且没有设备。如果为 "empty",将使用 empty() 创建输出 tensordict。

    注意

    如果 inplace=False 并且传递给模块的 tensordict 是 TensorDict 以外的 TensorDictBase 子类,输出仍将是 TensorDict 实例。它的批大小将为空,并且没有设备。设置为 "empty" 可以获得相同的 TensorDictBase 子类型、相同的批大小和设备。在运行时使用 tensordict_out(见下文)可以对输出进行更细粒度的控制。

    注意

    如果 inplace=False 并且 tensordict_out 被传递给 forward() 方法,则 tensordict_out 将优先。这是获取传递给模块的 tensordict 是 TensorDictBase 子类(而不是 TensorDict)的 tensordict_out 的方法,输出仍将是 TensorDict 实例。

在 TensorDictModule 中嵌入神经网络只需要指定输入和输出键。TensorDictModule 支持函数式和常规的 nn.Module 对象。在函数式情况下,必须指定 'params'(和 'buffers')关键字参数

示例

>>> from tensordict import TensorDict
>>> # one can wrap regular nn.Module
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"])
>>> input = torch.ones(2, 3, 128)
>>> tgt = torch.zeros(2, 3, 128)
>>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=None,
    is_shared=False)

我们也可以直接传递张量

示例

>>> out = module(input, tgt)
>>> assert out.shape == input.shape
>>> # we can also wrap regular functions
>>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")])
>>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[]))
TensorDict(
    fields={
        input: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

我们可以使用 TensorDictModule 来填充 tensordict

示例

>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"])
>>> print(module(TensorDict({}, batch_size=[])))
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

另一个特性是传递字典作为输入键,以控制值到特定关键字参数的分派。

示例

>>> module = TensorDictModule(lambda x, *, y: x+y,
...     in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)

如果 out_to_in_map 设置为 True,则 in_keys 映射将被反转。这样,同一个输入键可以用于不同的关键字参数。

示例

>>> module = TensorDictModule(lambda x, *, y, z: x+y+z,
...     in_keys={'x': '1', 'y': '2', z: '2'}, out_keys=['t'], out_to_in_map=True
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['t']
tensor(5.)

tensordict 模块的函数式调用很简单

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> params = TensorDict.from_module(td_module)
>>> # functional API
>>> with params.to_module(td_module):
...     td_functional = td_module(td.clone())
>>> print(td_functional)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
在有状态情况下
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> td_stateful = td_module(td.clone())
>>> print(td_stateful)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
forward(tensordict: TensorDictBase = None, args=None, *, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) TensorDictBase

当未设置 tensordict 参数时,kwargs 用于创建 TensorDict 实例。

文档

查阅 PyTorch 开发者完整文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源