快捷方式

CatTensors

class torchrl.envs.transforms.CatTensors(in_keys: 序列[NestedKey] | None = None, out_key: NestedKey = 'observation_vector', dim: int = - 1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)[源]

将多个键连接成一个张量。

当多个键描述单个状态时(例如,“observation_position” 和 “observation_velocity”),这尤其有用。

参数:
  • in_keys (嵌套键序列) – 要连接的键。如果为 None(或未提供),则在首次使用此变换时将从父环境中检索这些键。此行为仅在设置了父环境时有效。

  • out_key (NestedKey) – 结果张量的键。

  • dim (int, 可选) – 进行连接的维度。默认为 -1

关键字参数:
  • del_keys (bool, 可选) – 如果为 True,输入值将在连接后被删除。默认为 True

  • unsqueeze_if_oor (bool, 可选) – 如果为 True,CatTensor 将检查要连接的张量是否存在指定维度。如果不存在,则将沿该维度对张量进行 unsqueeze(扩展维度)。默认为 False

  • sort (bool, 可选) – 如果为 True,键将在变换中进行排序。否则,将使用用户提供的顺序。默认为 True

示例

>>> transform = CatTensors(in_keys=["key1", "key2"])
>>> td = TensorDict({"key1": torch.zeros(1, 1),
...     "key2": torch.ones(1, 1)}, [1])
>>> _ = transform(td)
>>> print(td.get("observation_vector"))
tensor([[0., 1.]])
>>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True)
>>> td = TensorDict({"key1": torch.zeros(1),
...     "key2": torch.ones(1)}, [])
>>> _ = transform(td)
>>> print(td.get("observation_vector").shape)
torch.Size([2, 1])
forward(tensordict: TensorDictBase) TensorDictBase

读取输入的 tensordict,并对选定的键应用变换。

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[源]

变换观察规范,使结果规范与变换映射匹配。

参数:

observation_spec (TensorSpec) – 变换前的规范

返回值:

变换后预期的规范

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源