split_trajectories¶
- torchrl.collectors.utils.split_trajectories(rollout_tensordict: TensorDictBase, *, prefix=None, trajectory_key: tensordict._nestedkey.NestedKey | None = None, done_key: tensordict._nestedkey.NestedKey | None = None, as_nested: bool = False) TensorDictBase [来源]¶
用于轨迹分离的实用函数。
接受一个带有键 traj_ids 的 tensordict,该键指示每个轨迹的 id。
由此构建一个 B x T x … 零填充的 tensordict,其中 B 批次,最大持续时间为 T
- 参数:
rollout_tensordict (TensorDictBase) – 一个沿最后一个维度具有相邻轨迹的 rollout。
- 关键词参数:
prefix (NestedKey, 可选) – 用于读取和写入元数据的前缀,例如
"traj_ids"
(每个轨迹的可选整数 id)和指示哪些数据有效、哪些数据无效的"mask"
条目。如果输入具有"collector"
条目,则默认为"collector"
,否则默认为()
(无前缀)。prefix
保留为遗留功能,最终将被弃用。尽可能首选trajectory_key
或done_key
。trajectory_key (NestedKey, 可选) – 指向轨迹 id 的键。取代
done_key
和prefix
。如果未提供,则默认为(prefix, "traj_ids")
。done_key (NestedKey, 可选) – 指向
"done"
信号的键,如果无法直接恢复轨迹。默认为"done"
。as_nested (bool 或 torch.layout, 可选) –
是否将结果作为嵌套张量返回。默认为
False
。如果提供了torch.layout
,它将用于构造嵌套张量,否则将使用默认布局。注意
使用
split_trajectories(tensordict, as_nested=True).to_padded_tensor(mask=mask_key)
应该得到与as_nested=False
完全相同的结果。由于这是一个实验性功能,并且依赖于 nested_tensors,其 API 将来可能会更改,因此我们将其设为可选功能。使用as_nested=True
时,运行时应该更快。注意
提供布局让用户控制嵌套张量是与
torch.strided
还是torch.jagged
布局一起使用。虽然前者在撰写本文时具有稍多的功能,但后者将是 PyTorch 团队未来的主要关注点,因为它与compile()
具有更好的兼容性。
- 返回:
一个新的 tensordict,其前导维度对应于轨迹。还添加了一个
"mask"
布尔条目,共享trajectory_key
前缀和 tensordict 形状。它指示 tensordict 的有效元素,以及如果找不到trajectory_key
,则指示"traj_ids"
条目。
示例
>>> from tensordict import TensorDict >>> import torch >>> from torchrl.collectors.utils import split_trajectories >>> obs = torch.cat([torch.arange(10), torch.arange(5)]) >>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)]) >>> done = torch.zeros(15, dtype=torch.bool) >>> done[9] = True >>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32), ... torch.ones(5, dtype=torch.int32)]) >>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done, "trajectory": trajectory_id}, batch_size=[15]) >>> data_split = split_trajectories(data, done_key="done") >>> print(data_split) TensorDict( fields={ mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False), obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([2, 10]), device=None, is_shared=False), obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False), traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False), trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)}, batch_size=torch.Size([2, 10]), device=None, is_shared=False) >>> # check that split_trajectories got the trajectories right with the done signal >>> assert (data_split["traj_ids"] == data_split["trajectory"]).all() >>> print(data_split["mask"]) tensor([[ True, True, True, True, True, True, True, True, True, True], [ True, True, True, True, True, False, False, False, False, False]]) >>> data_split = split_trajectories(data, trajectory_key="trajectory") >>> print(data_split) TensorDict( fields={ mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False), obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([2, 10]), device=None, is_shared=False), obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False), trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)}, batch_size=torch.Size([2, 10]), device=None, is_shared=False)