TensorDictSequential¶
- class tensordict.nn.TensorDictSequential(*args, **kwargs)¶
TensorDictModules 的序列。
类似于
nn.Sequential
,它通过一个链式映射传递张量,每个映射读取并写入一个张量,此模块将通过查询每个输入模块来读写 tensordict。当使用函数式模块调用 TensorDictSequential
实例时,参数列表(和缓冲区)预计会被连接到单个列表中。
- 参数:
modules (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]) – 有序的可调用对象序列,它们以 TensorDictBase 作为输入并返回 TensorDictBase。这些可以是 TensorDictModuleBase 的实例,或任何符合此签名的其他函数。请注意,如果使用了非 TensorDictModuleBase 的可调用对象,其输入和输出键将不会被跟踪,因此不会影响 TensorDictSequential 的 in_keys 和 out_keys 属性。常规的
dict
输入如有必要将被转换为 OrderedDict
。
- 关键字参数:
partial_tolerant (bool, 可选) – 如果为 True,输入的 tensordict 可以缺少某些输入键。在这种情况下,将只执行那些根据现有键可以执行的模块。此外,如果输入的 tensordict 是 tensordict 的惰性堆叠,并且 partial_tolerant 为
True
,并且堆叠中缺少必需的键,那么 TensorDictSequential 将扫描子 tensordict,查找是否存在具有必需键的子 tensordict。默认为 False。
selected_out_keys (嵌套键的可迭代对象, 可选) – 要选择的输出键列表。如果未提供,将写入所有
out_keys
。
注意
一个
TensorDictSequential
实例可能有很多输出键,出于清晰性或内存目的,可能希望在执行后移除其中一些键。如果出现这种情况,可以在实例化后使用方法 select_out_keys()
,或者将 selected_out_keys 传递给构造函数。
示例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule, TensorDictSequential >>> torch.manual_seed(0) >>> module = TensorDictSequential( ... TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["x+1"]), ... TensorDictModule(nn.Linear(3, 4), in_keys=["x+1"], out_keys=["w*(x+1)+b"]), ... ) >>> # with tensordict input >>> print(module(TensorDict({"x": torch.zeros(3)}, []))) TensorDict( fields={ w*(x+1)+b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), x+1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> # with tensor input: returns all the output keys in the order of the modules, ie "x+1" and "w*(x+1)+b" >>> module(x=torch.zeros(3)) (tensor([1., 1., 1.]), tensor([-0.7214, -0.8748, 0.1571, -0.1138], grad_fn=<AddBackward0>)) >>> module(torch.zeros(3)) (tensor([1., 1., 1.]), tensor([-0.7214, -0.8748, 0.1571, -0.1138], grad_fn=<AddBackward0>))
TensorDictSequential 支持函数式、模块化和 vmap 编程。
示例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import ( ... ProbabilisticTensorDictModule, ... ProbabilisticTensorDictSequential, ... TensorDictModule, ... TensorDictSequential, ... ) >>> from tensordict.nn.distributions import NormalParamExtractor >>> from tensordict.nn.functional_modules import make_functional >>> from torch.distributions import Normal >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) >>> net1 = torch.nn.Linear(4, 8) >>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["params"]) >>> normal_params = TensorDictModule( ... NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"] ... ) >>> td_module1 = ProbabilisticTensorDictSequential( ... module1, ... normal_params, ... ProbabilisticTensorDictModule( ... in_keys=["loc", "scale"], ... out_keys=["hidden"], ... distribution_class=Normal, ... return_log_prob=True, ... ) ... ) >>> module2 = torch.nn.Linear(4, 8) >>> td_module2 = TensorDictModule( ... module=module2, in_keys=["hidden"], out_keys=["output"] ... ) >>> td_module = TensorDictSequential(td_module1, td_module2) >>> params = TensorDict.from_module(td_module) >>> with params.to_module(td_module): ... _ = td_module(td) >>> print(td) TensorDict( fields={ hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- 在 vmap 的情况下
>>> from torch import vmap >>> params = params.expand(4) >>> def func(td, params): ... with params.to_module(td_module): ... return td_module(td) >>> td_vmap = vmap(func, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ hidden: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), output: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False), params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 3]), device=None, is_shared=False)
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) TensorDictBase ¶
如果未设置 tensordict 参数,则使用 kwargs 创建 TensorDict 实例。
- reset_out_keys()¶
将
out_keys
属性重置为其原始值。返回值:同一个模块,其
out_keys
值恢复为原始值。示例
>>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule, TensorDictSequential >>> import torch >>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"]) >>> mod.select_out_keys("d") >>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []) >>> mod(td) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> mod.reset_out_keys() >>> mod(td) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- select_out_keys(*selected_out_keys) TensorDictSequential ¶
选择将在输出 tensordict 中找到的键。
这在想要移除复杂图中的中间键,或者当这些键的存在可能引发意外行为时很有用。
原始的
out_keys
仍然可以通过module.out_keys_source
访问。- 参数:
*selected_out_keys (字符串序列 或 字符串元组) – 应在输出 tensordict 中找到的输出键。
返回值:同一个模块,已就地修改并更新了
out_keys
。最简单的用法是结合
TensorDictModule
使用。示例
>>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule, TensorDictSequential >>> import torch >>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"]) >>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []) >>> mod(td) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> mod.select_out_keys("d") >>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []) >>> mod(td) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
此功能也适用于分派的参数: 示例
>>> mod(torch.zeros(()), torch.ones(())) tensor(2.)
此更改将就地发生 (即返回同一个模块,但
out_keys
列表已更新)。可以使用TensorDictModuleBase.reset_out_keys()
方法恢复此更改。示例
>>> mod.reset_out_keys() >>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])) TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
这也适用于其他类,例如 Sequential: 示例
>>> from tensordict.nn import TensorDictSequential >>> seq = TensorDictSequential( ... TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]), ... TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]), ... ) >>> td = TensorDict({"x": torch.zeros(())}, []) >>> seq(td) TensorDict( fields={ x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) >>> seq.select_out_keys("z") >>> td = TensorDict({"x": torch.zeros(())}, []) >>> seq(td) TensorDict( fields={ x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
- select_subsequence(in_keys: Optional[Iterable[NestedKey]] = None, out_keys: Optional[Iterable[NestedKey]] = None) TensorDictSequential ¶
返回一个新的 TensorDictSequential,其中仅包含根据给定输入键计算给定输出键所必需的模块。
- 参数:
in_keys – 我们要选择的子序列的输入键。所有不在
in_keys
中的键将被视为不相关,并且 *仅* 将这些键作为输入的模块将被丢弃。生成的 sequential 模块将遵循模式“所有输出会因 in_keys 中任一键取值不同而受到影响的模块”。如果未提供,则假定使用模块的 in_keys
。out_keys – 我们要选择的子序列的输出键。生成的序列中将只包含获取
out_keys
所必需的模块。生成的 sequential 模块将遵循模式“所有对 out_keys 条目的值构成条件的模块”。如果未提供,则假定使用模块的 out_keys
。
- 返回值:
一个新的 TensorDictSequential,其中仅包含根据给定的输入和输出键所需的模块。
示例
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> idn = lambda x: x >>> module = Seq( ... Mod(idn, in_keys=["a"], out_keys=["b"]), ... Mod(idn, in_keys=["b"], out_keys=["c"]), ... Mod(idn, in_keys=["c"], out_keys=["d"]), ... Mod(idn, in_keys=["a"], out_keys=["e"]), ... ) >>> # select all modules whose output depend on "a" >>> module.select_subsequence(in_keys=["a"]) TensorDictSequential( module=ModuleList( (0): TensorDictModule( module=<function <lambda> at 0x126ed1ca0>, device=cpu, in_keys=['a'], out_keys=['b']) (1): TensorDictModule( module=<function <lambda> at 0x126ed1ca0>, device=cpu, in_keys=['b'], out_keys=['c']) (2): TensorDictModule( module=<function <lambda> at 0x126ed1ca0>, device=cpu, in_keys=['c'], out_keys=['d']) (3): TensorDictModule( module=<function <lambda> at 0x126ed1ca0>, device=cpu, in_keys=['a'], out_keys=['e']) ), device=cpu, in_keys=['a'], out_keys=['b', 'c', 'd', 'e']) >>> # select all modules whose output depend on "c" >>> module.select_subsequence(in_keys=["c"]) TensorDictSequential( module=ModuleList( (0): TensorDictModule( module=<function <lambda> at 0x126ed1ca0>, device=cpu, in_keys=['c'], out_keys=['d']) ), device=cpu, in_keys=['c'], out_keys=['d']) >>> # select all modules that affect the value of "c" >>> module.select_subsequence(out_keys=["c"]) TensorDictSequential( module=ModuleList( (0): TensorDictModule( module=<function <lambda> at 0x126ed1ca0>, device=cpu, in_keys=['a'], out_keys=['b']) (1): TensorDictModule( module=<function <lambda> at 0x126ed1ca0>, device=cpu, in_keys=['b'], out_keys=['c']) ), device=cpu, in_keys=['a'], out_keys=['b', 'c']) >>> # select all modules that affect the value of "e" >>> module.select_subsequence(out_keys=["e"]) TensorDictSequential( module=ModuleList( (0): TensorDictModule( module=<function <lambda> at 0x126ed1ca0>, device=cpu, in_keys=['a'], out_keys=['e']) ), device=cpu, in_keys=['a'], out_keys=['e'])
此方法会传播到嵌套的 sequential
>>> module = Seq( ... Seq( ... Mod(idn, in_keys=["a"], out_keys=["b"]), ... Mod(idn, in_keys=["b"], out_keys=["c"]), ... ), ... Seq( ... Mod(idn, in_keys=["b"], out_keys=["d"]), ... Mod(idn, in_keys=["d"], out_keys=["e"]), ... ), ... ) >>> # select submodules whose output will be affected by a change in "b" or "d" AND which output is "e" >>> module.select_subsequence(in_keys=["b", "d"], out_keys=["e"]) TensorDictSequential( module=ModuleList( (0): TensorDictSequential( module=ModuleList( (0): TensorDictModule( module=<function <lambda> at 0x129efae50>, device=cpu, in_keys=['b'], out_keys=['d']) (1): TensorDictModule( module=<function <lambda> at 0x129efae50>, device=cpu, in_keys=['d'], out_keys=['e']) ), device=cpu, in_keys=['b'], out_keys=['d', 'e']) ), device=cpu, in_keys=['b'], out_keys=['d', 'e'])