
SafeSequential

class torchrl.modules.tensordict_module.SafeSequential(*args, **kwargs)[source]

A safe sequence of TensorDictModules.

Similar to nn.Sequential, which passes a tensor through a chain of mappings that each read and write a single tensor, this module reads from and writes to a tensordict by querying each of the input modules. When a TensorDictSequential instance is called with functional modules, it is expected that the parameter lists (and buffers) are concatenated into a single list.
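The key-based chaining can be illustrated with a plain-Python sketch. `DictModule` and `DictSequential` below are hypothetical stand-ins (plain dicts in place of tensordicts), not the torchrl implementation:

```python
# Minimal sketch of in_keys/out_keys chaining over a shared dict.
# DictModule and DictSequential are hypothetical stand-ins, not torchrl classes.

class DictModule:
    def __init__(self, fn, in_keys, out_keys):
        self.fn, self.in_keys, self.out_keys = fn, in_keys, out_keys

    def __call__(self, data):
        outputs = self.fn(*(data[k] for k in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        data.update(dict(zip(self.out_keys, outputs)))
        return data  # results are written back into the same mapping


class DictSequential:
    def __init__(self, *modules):
        self.modules = modules

    def __call__(self, data):
        for module in self.modules:  # each module reads/writes the shared dict
            module(data)
        return data


seq = DictSequential(
    DictModule(lambda x: x * 2, in_keys=["input"], out_keys=["hidden"]),
    DictModule(lambda h: h + 1, in_keys=["hidden"], out_keys=["output"]),
)
td = seq({"input": 3})
print(td)  # {'input': 3, 'hidden': 6, 'output': 7}
```

Each module only declares which keys it consumes and produces; the sequence itself never needs to know the tensor shapes flowing between them.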

Parameters:
  • modules (iterable of TensorDictModule) – ordered sequence of TensorDictModule instances to be run sequentially.

  • partial_tolerant (bool, optional) – if True, the input tensordict can be missing some of the input keys. If so, only the modules that can be executed given the keys that are present will be executed. Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is True AND if the stack does not have the required keys, then SafeSequential will scan through the sub-tensordicts looking for those that have the required keys, if any.
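The skip logic that `partial_tolerant` enables can be sketched in plain Python. `run_partial_tolerant` is a hypothetical helper with dicts standing in for tensordicts, not torchrl code:

```python
# Sketch of partial-tolerant execution: run only the modules whose required
# input keys are present in the data. Hypothetical helper, not torchrl code.

def run_partial_tolerant(modules, data, partial_tolerant=True):
    for in_keys, out_keys, fn in modules:
        missing = [k for k in in_keys if k not in data]
        if missing:
            if not partial_tolerant:
                raise KeyError(f"missing input keys: {missing}")
            continue  # skip modules that cannot run with the available keys
        data.update(dict(zip(out_keys, fn(*(data[k] for k in in_keys)))))
    return data


modules = [
    (["a"], ["b"], lambda a: (a + 1,)),
    (["z"], ["y"], lambda z: (z * 2,)),  # "z" is absent -> module is skipped
    (["b"], ["c"], lambda b: (b * 10,)),
]
out = run_partial_tolerant(modules, {"a": 1})
print(out)  # {'a': 1, 'b': 2, 'c': 20}
```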

TensorDictSequential supports functional, modular and vmap coding.

Examples:

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamExtractor
>>> from torchrl.modules.tensordict_module import SafeProbabilisticModule
>>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
>>> spec1 = CompositeSpec(hidden=UnboundedContinuousTensorSpec(4), loc=None, scale=None)
>>> net1 = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
>>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["loc", "scale"])
>>> td_module1 = SafeProbabilisticModule(
...     module=module1,
...     spec=spec1,
...     in_keys=["loc", "scale"],
...     out_keys=["hidden"],
...     distribution_class=TanhNormal,
...     return_log_prob=True,
... )
>>> spec2 = UnboundedContinuousTensorSpec(8)
>>> module2 = torch.nn.Linear(4, 8)
>>> td_module2 = TensorDictModule(
...    module=module2,
...    spec=spec2,
...    in_keys=["hidden"],
...    out_keys=["output"],
...    )
>>> td_module = SafeSequential(td_module1, td_module2)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     td_module(td)
>>> print(td)
TensorDict(
    fields={
        hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([3, 8]), dtype=torch.float32),
        sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),
        scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # The module spec aggregates all the input specs:
>>> print(td_module.spec)
CompositeSpec(
    hidden: UnboundedContinuousTensorSpec(
        shape=torch.Size([4]), space=None, device=cpu, dtype=torch.float32, domain=continuous),
    loc: None,
    scale: None,
    output: UnboundedContinuousTensorSpec(
        shape=torch.Size([8]), space=None, device=cpu, dtype=torch.float32, domain=continuous))
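The aggregation shown above amounts to merging each module's spec into one composite mapping. A rough sketch of that merge, with plain dicts standing in for CompositeSpec and a hypothetical `aggregate_specs` helper (the exact torchrl merge semantics may differ):

```python
# Rough sketch: merge per-module output specs into one composite dict,
# keeping the first non-None entry for each key. Hypothetical helper.

def aggregate_specs(*specs):
    merged = {}
    for spec in specs:
        for key, value in spec.items():
            if merged.get(key) is None:  # keep first non-None entry per key
                merged[key] = value
    return merged


spec1 = {"hidden": "Unbounded(4)", "loc": None, "scale": None}
spec2 = {"output": "Unbounded(8)"}
print(aggregate_specs(spec1, spec2))
# {'hidden': 'Unbounded(4)', 'loc': None, 'scale': None, 'output': 'Unbounded(8)'}
```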
In the vmap case:
>>> from torch import vmap
>>> params = params.expand(4, *params.shape)
>>> td_vmap = vmap(td_module, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        hidden: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
        input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
        loc: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
        output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32),
        sample_log_prob: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32),
        scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
