SafeSequential¶
- class torchrl.modules.tensordict_module.SafeSequential(*args, **kwargs)[source]¶
A safe sequence of TensorDictModules.
Similarly to nn.Sequential, which passes a tensor through a chain of mappings that each read and write a single tensor, this module reads from and writes to a tensordict by querying each of its input modules. When a TensorDictSequential instance is called with functional modules, the parameter lists (and buffers) are expected to be concatenated into a single list.- Parameters:
modules (iterable of TensorDictModule) – an ordered sequence of TensorDictModule instances to be run sequentially.
partial_tolerant (bool, optional) – if True, the input tensordict may be missing some of the input keys. In that case, only the modules that can be executed given the keys that are present will be executed. Also, if the input tensordict is a lazy stack of tensordicts, if partial_tolerant is True, and if the stack does not have the required keys, then SafeSequential will scan the sub-tensordicts looking for those that have the required keys, if any.
TensorDictSequential supports functional, modular and vmap coding: .. rubric:: Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamExtractor
>>> from torchrl.modules.tensordict_module import SafeProbabilisticModule
>>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
>>> spec1 = CompositeSpec(hidden=UnboundedContinuousTensorSpec(4), loc=None, scale=None)
>>> net1 = torch.nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
>>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["loc", "scale"])
>>> td_module1 = SafeProbabilisticModule(
...     module=module1,
...     spec=spec1,
...     in_keys=["loc", "scale"],
...     out_keys=["hidden"],
...     distribution_class=TanhNormal,
...     return_log_prob=True,
...     )
>>> spec2 = UnboundedContinuousTensorSpec(8)
>>> module2 = torch.nn.Linear(4, 8)
>>> td_module2 = TensorDictModule(
...     module=module2,
...     spec=spec2,
...     in_keys=["hidden"],
...     out_keys=["output"],
...     )
>>> td_module = SafeSequential(td_module1, td_module2)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     td_module(td)
>>> print(td)
TensorDict(
    fields={
        hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([3, 8]), dtype=torch.float32),
        sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),
        scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # The module spec aggregates all the input specs:
>>> print(td_module.spec)
CompositeSpec(
    hidden: UnboundedContinuousTensorSpec(
        shape=torch.Size([4]),
        space=None,
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    loc: None,
    scale: None,
    output: UnboundedContinuousTensorSpec(
        shape=torch.Size([8]),
        space=None,
        device=cpu,
        dtype=torch.float32,
        domain=continuous))
- In the vmap case:
>>> from torch import vmap
>>> params = params.expand(4, *params.shape)
>>> td_vmap = vmap(td_module, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        hidden: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
        input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
        loc: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32),
        sample_log_prob: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32),
        scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)