CatFrames¶
- class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, padding='same', padding_value=0, as_inverse=False, reset_key: NestedKey | None = None, done_key: NestedKey | None = None)[source]¶
将连续的观测帧连接成一个张量。
例如,这可以解释观察到的特征的运动/速度。在“使用深度强化学习玩 Atari”中提出(https://arxiv.org/pdf/1312.5602.pdf)。
当在转换后的环境中使用时,
CatFrames
是一个有状态的类,可以通过调用reset()
方法将其重置为其本机状态。此方法接受包含"_reset"
条目的 tensordict,该条目指示要重置哪个缓冲区。- 参数:
N (int) – 要连接的观测数量。
dim (int) – 连接观测的维度。应该是负数,以确保它与不同 batch_size 的环境兼容。
in_keys (嵌套键序列, 可选) – 指向要连接的帧的键。默认为 [“pixels”]。
out_keys (嵌套键序列, 可选) – 指向要写入输出的键。默认为 in_keys 的值。
padding (str, 可选) – 填充方法。
"same"
或"constant"
之一。默认为"same"
,即使用第一个值进行填充。padding_value (float, 可选) – 如果
padding="constant"
,则要使用的填充值。默认为 0。as_inverse (bool, 可选) – 如果为
True
,则将变换应用为逆变换。默认为False
。reset_key (NestedKey, 可选) – 用作部分重置指示器的重置键。必须是唯一的。如果未提供,则默认为父环境的唯一重置键(如果只有一个),否则会引发异常。
done_key (NestedKey, 可选) – 用作部分完成指示器的完成键。必须是唯一的。如果未提供,则默认为
"done"
。
示例
>>> from torchrl.envs.libs.gym import GymEnv >>> env = TransformedEnv(GymEnv('Pendulum-v1'), ... Compose( ... UnsqueezeTransform(-1, in_keys=["observation"]), ... CatFrames(N=4, dim=-1, in_keys=["observation"]), ... ) ... ) >>> print(env.rollout(3))
CatFrames
变换也可以离线使用,以在不同的规模上重现在线帧连接的效果(或为了限制内存消耗)。以下示例给出了完整的图景,以及torchrl.data.ReplayBuffer
的用法示例
>>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.envs import UnsqueezeTransform, CatFrames >>> from torchrl.collectors import SyncDataCollector >>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension >>> env = TransformedEnv( ... GymEnv("CartPole-v1", from_pixels=True), ... Compose( ... ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]), ... Resize(in_keys=["pixels_trsf"], w=64, h=64), ... GrayScale(in_keys=["pixels_trsf"]), ... UnsqueezeTransform(-4, in_keys=["pixels_trsf"]), ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]), ... ) ... ) >>> # we design a collector >>> collector = SyncDataCollector( ... env, ... RandomPolicy(env.action_spec), ... frames_per_batch=10, ... total_frames=1000, ... ) >>> for data in collector: ... print(data) ... break >>> # now let's create a transform for the replay buffer. We don't need to unsqueeze the data here. >>> # however, we need to point to both the pixel entry at the root and at the next levels: >>> t = Compose( ... ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64), ... GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... ) >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) >>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) >>> rb.add(data_exclude) >>> s = rb.sample(1) # the buffer has only one element >>> # let's check that our sample is the same as the batch collected during inference >>> assert (data.exclude("collector")==s.squeeze(0).exclude("index", "collector")).all()
注意
CatFrames
目前仅支持根部的"done"
信号。嵌套的done
(例如在 MARL 设置中发现的)目前不支持。如果需要此功能,请在 TorchRL 存储库中提出问题。- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [source]¶
变换观察规范,使其与变换映射匹配。
- 参数:
observation_spec (TensorSpec) – 变换前的规范
- 返回:
变换后的预期规范