CatTensors¶
- class torchrl.envs.transforms.CatTensors(in_keys: Optional[Sequence[NestedKey]] = None, out_key: NestedKey = 'observation_vector', dim: int = - 1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)[source]¶
将多个键连接成一个张量。
当多个键描述单个状态时(例如“observation_position”和“observation_velocity”),此功能特别有用
- 参数:
in_keys (NestedKey 序列) – 要连接的键。如果为 None(或未提供),则将在首次使用变换时从父环境检索键。此行为仅在设置了父环境时有效。
out_key (NestedKey) – 结果张量的键。
dim (int, 可选) – 连接将沿其发生的维度。默认为
-1
。
- 关键字参数:
del_keys (bool, 可选) – 如果为
True
,则输入值将在连接后删除。默认为True
。unsqueeze_if_oor (bool, 可选) – 如果为
True
,CatTensor 将检查指示的维度是否存在于要连接的张量中。如果不存在,则张量将沿该维度取消挤压。默认为False
。sort (bool, 可选) – 如果为
True
,则键将在变换中排序。否则,用户提供的顺序将优先。默认为True
。
示例
>>> transform = CatTensors(in_keys=["key1", "key2"]) >>> td = TensorDict({"key1": torch.zeros(1, 1), ... "key2": torch.ones(1, 1)}, [1]) >>> _ = transform(td) >>> print(td.get("observation_vector")) tensor([[0., 1.]]) >>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True) >>> td = TensorDict({"key1": torch.zeros(1), ... "key2": torch.ones(1)}, []) >>> _ = transform(td) >>> print(td.get("observation_vector").shape) torch.Size([2, 1])
- forward(tensordict: TensorDictBase) TensorDictBase ¶
读取输入 tensordict,并针对选定的键应用变换。
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [source]¶
变换观察规范,使结果规范与变换映射匹配。
- 参数:
observation_spec (TensorSpec) – 变换前的规范
- 返回:
变换后预期的规范