快捷方式

CatFrames

class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, padding='same', padding_value=0, as_inverse=False, reset_key: NestedKey | None = None, done_key: NestedKey | None = None)[source]

将连续的观测帧连接成一个张量。

例如,这可以解释观察到的特征的运动/速度。在“使用深度强化学习玩 Atari”中提出(https://arxiv.org/pdf/1312.5602.pdf)。

当在转换后的环境中使用时,CatFrames 是一个有状态的类,可以通过调用 reset() 方法将其重置为其本机状态。此方法接受包含 "_reset" 条目的 tensordict,该条目指示要重置哪个缓冲区。

参数:
  • N (int) – 要连接的观测数量。

  • dim (int) – 连接观测的维度。应该是负数,以确保它与不同 batch_size 的环境兼容。

  • in_keys (嵌套键序列, 可选) – 指向要连接的帧的键。默认为 [“pixels”]。

  • out_keys (嵌套键序列, 可选) – 指向要写入输出的键。默认为 in_keys 的值。

  • padding (str, 可选) – 填充方法。 "same""constant" 之一。默认为 "same",即使用第一个值进行填充。

  • padding_value (float, 可选) – 如果 padding="constant",则要使用的填充值。默认为 0。

  • as_inverse (bool, 可选) – 如果为 True,则将变换应用为逆变换。默认为 False

  • reset_key (NestedKey, 可选) – 用作部分重置指示器的重置键。必须是唯一的。如果未提供,则默认为父环境的唯一重置键(如果只有一个),否则会引发异常。

  • done_key (NestedKey, 可选) – 用作部分完成指示器的完成键。必须是唯一的。如果未提供,则默认为 "done"

示例

>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv('Pendulum-v1'),
...     Compose(
...         UnsqueezeTransform(-1, in_keys=["observation"]),
...         CatFrames(N=4, dim=-1, in_keys=["observation"]),
...     )
... )
>>> print(env.rollout(3))

CatFrames 变换也可以离线使用,以在不同的规模上重现在线帧连接的效果(或为了限制内存消耗)。以下示例给出了完整的图景,以及 torchrl.data.ReplayBuffer 的用法

示例

>>> from torchrl.envs.utils import RandomPolicy        >>> from torchrl.envs import UnsqueezeTransform, CatFrames
>>> from torchrl.collectors import SyncDataCollector
>>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension
>>> env = TransformedEnv(
...     GymEnv("CartPole-v1", from_pixels=True),
...     Compose(
...         ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
...         Resize(in_keys=["pixels_trsf"], w=64, h=64),
...         GrayScale(in_keys=["pixels_trsf"]),
...         UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
...         CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
...     )
... )
>>> # we design a collector
>>> collector = SyncDataCollector(
...     env,
...     RandomPolicy(env.action_spec),
...     frames_per_batch=10,
...     total_frames=1000,
... )
>>> for data in collector:
...     print(data)
...     break
>>> # now let's create a transform for the replay buffer. We don't need to unsqueeze the data here.
>>> # however, we need to point to both the pixel entry at the root and at the next levels:
>>> t = Compose(
...         ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...         Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
...         GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
...         CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
... )
>>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16)
>>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
>>> rb.add(data_exclude)
>>> s = rb.sample(1) # the buffer has only one element
>>> # let's check that our sample is the same as the batch collected during inference
>>> assert (data.exclude("collector")==s.squeeze(0).exclude("index", "collector")).all()

注意

CatFrames 目前仅支持根部的 "done" 信号。嵌套的 done(例如在 MARL 设置中发现的)目前不支持。如果需要此功能,请在 TorchRL 存储库中提出问题。

forward(tensordict: TensorDictBase) TensorDictBase[source]

读取输入 tensordict,并针对选定的键应用变换。

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]

变换观察规范,使其与变换映射匹配。

参数:

observation_spec (TensorSpec) – 变换前的规范

返回:

变换后的预期规范

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

为初学者和高级开发人员提供深入的教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源