TimeMaxPool¶
- class torchrl.envs.transforms.TimeMaxPool(in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, T: int = 1, reset_key: Optional[NestedKey] = None)[source]¶
在最近 T 次观测中,获取每个位置的最大值。
此转换获取所有 in_keys 张量在最近 T 个时间步长内每个位置的最大值。
- 参数:
in_keys (NestedKey 序列, 可选) – 将应用最大池化的输入键。如果留空,则默认为“observation”。
out_keys (NestedKey 序列, 可选) – 将写入输出的输出键。如果留空,则默认为 in_keys。
T (int, 可选) – 应用最大池化的时间步数。
reset_key (NestedKey, 可选) – 用作部分重置指示器的重置键。必须是唯一的。如果未提供,则默认为父环境的唯一重置键(如果只有一个),否则会引发异常。
示例
>>> from torchrl.envs import GymEnv >>> base_env = GymEnv("Pendulum-v1") >>> env = TransformedEnv(base_env, TimeMaxPool(in_keys=["observation"], T=10)) >>> torch.manual_seed(0) >>> env.set_seed(0) >>> rollout = env.rollout(10) >>> print(rollout["observation"]) # values should be increasing up until the 10th step tensor([[ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0216, 0.0000], [ 0.0000, 0.1149, 0.0000], [ 0.0000, 0.1990, 0.0000], [ 0.0000, 0.2749, 0.0000], [ 0.0000, 0.3281, 0.0000], [-0.9290, 0.3702, -0.8978]])
注意
TimeMaxPool
当前仅支持根级别的done
信号。嵌套的done
,例如在 MARL 设置中发现的那些,目前不受支持。如果需要此功能,请在 TorchRL repo 上提出 issue。- transform_observation_spec(observation_spec: TensorSpec) TensorSpec [source]¶
转换观察 spec,使结果 spec 与转换映射匹配。
- 参数:
observation_spec (TensorSpec) – 转换前的 spec
- 返回:
转换后预期的 spec