RemoveEmptySpecs
- class torchrl.envs.transforms.RemoveEmptySpecs(in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, in_keys_inv: Optional[Sequence[NestedKey]] = None, out_keys_inv: Optional[Sequence[NestedKey]] = None)
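Removes empty specs, and the matching empty entries in the data, from an environment. A spec is empty when it is a Composite with no leaf entries. As the example below shows, a Composite whose only child is itself empty (such as other) is removed as well, along with other_reward.
Examples: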
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Unbounded, Composite, Categorical
>>> from torchrl.envs import EnvBase, TransformedEnv, RemoveEmptySpecs
>>> from torchrl.envs.utils import check_env_specs
>>>
>>>
>>> class DummyEnv(EnvBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...         # "other" holds only an empty Composite, so it is empty too
...         self.observation_spec = Composite(
...             observation=Unbounded((*self.batch_size, 3)),
...             other=Composite(
...                 another_other=Composite(shape=self.batch_size),
...                 shape=self.batch_size,
...             ),
...             shape=self.batch_size,
...         )
...         self.action_spec = Unbounded((*self.batch_size, 3))
...         self.done_spec = Categorical(
...             2, (*self.batch_size, 1), dtype=torch.bool
...         )
...         self.full_done_spec["truncated"] = self.full_done_spec[
...             "terminated"].clone()
...         # "other_reward" is an empty Composite next to the real reward
...         self.reward_spec = Composite(
...             reward=Unbounded((*self.batch_size, 1)),
...             other_reward=Composite(shape=self.batch_size),
...             shape=self.batch_size
...         )
...
...     def _reset(self, tensordict):
...         return self.observation_spec.rand().update(self.full_done_spec.zero())
...
...     def _step(self, tensordict):
...         return TensorDict(
...             {},
...             batch_size=[]
...         ).update(self.observation_spec.rand()).update(
...             self.full_done_spec.zero()
...         ).update(self.full_reward_spec.rand())
...
...     def _set_seed(self, seed):
...         return seed + 1
>>>
>>>
>>> base_env = DummyEnv()
>>> print(base_env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                other: TensorDict(
                    fields={
                        another_other: TensorDict(
                            fields={
                            },
                            batch_size=torch.Size([2]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                other_reward: TensorDict(
                    fields={
                    },
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(base_env)
>>> env = TransformedEnv(base_env, RemoveEmptySpecs())
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(env)
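The pruning rule visible in the example can be sketched in plain Python. This is a minimal illustration, not torchrl's implementation: it assumes a Composite spec (or a tensordict) can be modelled as a nested dict in which any non-dict value stands in for a leaf spec or tensor, and prune_empty is a hypothetical helper.

```python
from typing import Any, Dict


def prune_empty(tree: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``tree`` without empty nested dicts.

    Hypothetical sketch: nested dicts stand in for Composite specs or
    sub-tensordicts; any non-dict value stands in for a leaf.
    """
    pruned: Dict[str, Any] = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            sub = prune_empty(value)  # prune children first (bottom-up)
            if sub:  # keep the entry only if something survived below it
                pruned[key] = sub
        else:
            pruned[key] = value  # leaves are always kept
    return pruned


# Mirrors the layout of DummyEnv's specs; tuples stand in for leaf shapes.
spec = {
    "observation": (3,),
    "other": {"another_other": {}},
    "reward": (1,),
    "other_reward": {},
}
print(prune_empty(spec))
# {'observation': (3,), 'reward': (1,)}
```

Pruning bottom-up is what makes other disappear: once another_other is dropped, other has no entries left and is dropped as well, which matches the second rollout above.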
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict and applies the transform to the selected keys.
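The generic read-then-transform pattern can be sketched with plain dicts. apply_to_keys is a hypothetical helper that loosely mirrors how a transform pairs in_keys with out_keys; it is not torchrl code, and note that the example above constructs RemoveEmptySpecs() without passing any keys.

```python
from typing import Any, Callable, Dict, Sequence


def apply_to_keys(
    data: Dict[str, Any],
    fn: Callable[[Any], Any],
    in_keys: Sequence[str],
    out_keys: Sequence[str],
) -> Dict[str, Any]:
    """Hypothetical sketch: write fn(data[in_key]) to out_key in a copy of data."""
    if len(in_keys) != len(out_keys):
        raise ValueError("in_keys and out_keys must have the same length")
    result = dict(data)  # shallow copy, so the input mapping is left untouched
    for in_key, out_key in zip(in_keys, out_keys):
        if in_key in data:  # this sketch silently skips absent keys
            result[out_key] = fn(data[in_key])
    return result


step = {"observation": 2.0, "reward": 1.0}
print(apply_to_keys(step, lambda x: x * 10, ["observation"], ["observation_scaled"]))
# {'observation': 2.0, 'reward': 1.0, 'observation_scaled': 20.0}
```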
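- transform_input_spec(input_spec: TensorSpec) → TensorSpec
Transforms the input spec so that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – the spec before the transform
- Returns:
the expected spec after the transform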
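Why a spec must follow the transform mapping can be shown with a toy check. Everything here is hypothetical: specs and data are plain nested dicts, rename stands in for an arbitrary transform, and same_structure compares only key layouts, not shapes or dtypes.

```python
from typing import Any, Dict


def same_structure(spec: Dict[str, Any], data: Dict[str, Any]) -> bool:
    """True if spec and data have identical (nested) key layouts."""
    if spec.keys() != data.keys():
        return False
    for key, sub_spec in spec.items():
        # A nested dict on one side must be matched by a nested dict on the other.
        if isinstance(sub_spec, dict) != isinstance(data[key], dict):
            return False
        if isinstance(sub_spec, dict) and not same_structure(sub_spec, data[key]):
            return False
    return True


def rename(tree: Dict[str, Any], old: str, new: str) -> Dict[str, Any]:
    """Hypothetical transform: rename a top-level key (works on specs or data)."""
    return {(new if key == old else key): value for key, value in tree.items()}


spec = {"action": "shape (3,)", "state": {"pos": "shape (2,)"}}
data = {"action": [0.0, 0.0, 0.0], "state": {"pos": [0.0, 0.0]}}

new_data = rename(data, "action", "act")
print(same_structure(spec, new_data))                          # False
print(same_structure(rename(spec, "action", "act"), new_data))  # True
```

Transforming the data while leaving the spec alone breaks the match; applying the same mapping to the spec restores it, which is the job transform_input_spec() is described as doing.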
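- transform_output_spec(output_spec: Composite) → Composite
Transforms the output spec so that the resulting spec matches the transform mapping.
This method should generally be left untouched; changes should instead be implemented in transform_observation_spec(), transform_reward_spec() and transform_done_spec().
- Parameters:
output_spec (Composite) – the spec before the transform
- Returns:
the expected spec after the transform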
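The delegation pattern, where the output-spec method stays generic and subclasses override one hook per part, can be sketched as a plain-Python class. It is a hypothetical stand-in, not torchrl's Transform: specs are nested dicts, and the three sub-spec keys are assumed to follow torchrl's output-spec layout (full_observation_spec, full_reward_spec, full_done_spec).

```python
from typing import Any, Dict

Spec = Dict[str, Any]  # plain-dict stand-in for a Composite spec


class SpecTransformSketch:
    """Hypothetical base class: subclasses override the per-part hooks."""

    def transform_observation_spec(self, spec: Spec) -> Spec:
        return spec  # identity by default

    def transform_reward_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_done_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_output_spec(self, output_spec: Spec) -> Spec:
        # Not meant to be overridden: route each sub-spec to its hook.
        result = dict(output_spec)
        result["full_observation_spec"] = self.transform_observation_spec(
            output_spec["full_observation_spec"]
        )
        result["full_reward_spec"] = self.transform_reward_spec(
            output_spec["full_reward_spec"]
        )
        result["full_done_spec"] = self.transform_done_spec(
            output_spec["full_done_spec"]
        )
        return result


class DropEmptyRewards(SpecTransformSketch):
    # Only the reward hook changes; transform_output_spec is inherited as-is.
    def transform_reward_spec(self, spec: Spec) -> Spec:
        return {key: value for key, value in spec.items() if value != {}}


output_spec = {
    "full_observation_spec": {"observation": (3,)},
    "full_reward_spec": {"reward": (1,), "other_reward": {}},
    "full_done_spec": {"done": (1,), "terminated": (1,)},
}
print(DropEmptyRewards().transform_output_spec(output_spec)["full_reward_spec"])
# {'reward': (1,)}
```

Keeping the routing in one generic method means a subclass only states what changes in a single part of the spec.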