快捷方式

split_trajectories

torchrl.collectors.utils.split_trajectories(rollout_tensordict: TensorDictBase, *, prefix=None, trajectory_key: NestedKey | None = None, done_key: NestedKey | None = None, as_nested: bool = False) TensorDictBase[source]

一个用于轨迹分离的工具函数。

接收一个 tensordict,其中包含一个 traj_ids 键,该键指示每个轨迹的 ID。

在此基础上,构建一个 B x T x … 零填充的 tensordict,其中 B 为批次大小,T 为最大持续时间。

参数:

rollout_tensordict (TensorDictBase) – 沿最后一个维度包含相邻轨迹的 rollout。

关键字参数:
  • prefix (NestedKey, optional) – 用于读取和写入元数据的前缀,例如 "traj_ids"(每个轨迹的可选整数 ID)以及指示哪些数据有效、哪些无效的 "mask" 条目。如果输入包含 "collector" 条目,则默认为 "collector",否则为 ()(无前缀)。 prefix 作为遗留功能保留,最终将被弃用。尽可能优先使用 trajectory_keydone_key

  • trajectory_key (NestedKey, optional) – 指向轨迹 ID 的键。覆盖 done_keyprefix。如果未提供,则默认为 (prefix, "traj_ids")

  • done_key (NestedKey, optional) – 指向 "done" 信号的键,如果无法直接恢复轨迹。默认为 "done"

  • as_nested (bool or torch.layout, optional) –

    是否将结果作为嵌套张量返回。默认为 False。如果提供了 torch.layout,将使用它来构建嵌套张量,否则将使用默认布局。

    注意

    使用 split_trajectories(tensordict, as_nested=True).to_padded_tensor(mask=mask_key) 应该会得到与 as_nested=False 完全相同的结果。由于这是一个实验性功能,并且依赖于 nested_tensors,其 API 未来可能会更改,因此我们将其设为可选功能。当 as_nested=True 时,运行时应该更快。

    注意

    提供布局允许用户控制嵌套张量是使用 torch.strided 布局还是 torch.jagged 布局。尽管在撰写本文时前者具有稍微更多的功能,但后者因其与 compile() 更好的兼容性,未来将成为 PyTorch 团队的主要关注点。

返回值:

一个新的 tensordict,其前导维度对应于轨迹。同时添加了一个 "mask" 布尔值条目,它共享 trajectory_key 前缀和 tensordict 形状,并指示 tensordict 中的有效元素。如果找不到 trajectory_key,则还会添加一个 "traj_ids" 条目。

示例

>>> from tensordict import TensorDict
>>> import torch
>>> from torchrl.collectors.utils import split_trajectories
>>> obs = torch.cat([torch.arange(10), torch.arange(5)])
>>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)])
>>> done = torch.zeros(15, dtype=torch.bool)
>>> done[9] = True
>>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32),
...     torch.ones(5, dtype=torch.int32)])
>>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done, "trajectory": trajectory_id}, batch_size=[15])
>>> data_split = split_trajectories(data, done_key="done")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
>>> # check that split_trajectories got the trajectories right with the done signal
>>> assert (data_split["traj_ids"] == data_split["trajectory"]).all()
>>> print(data_split["mask"])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False, False, False, False, False]])
>>> data_split = split_trajectories(data, trajectory_key="trajectory")
>>> print(data_split)
TensorDict(
    fields={
        mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源