RemoveEmptySpecs¶
- class torchrl.envs.transforms.RemoveEmptySpecs(in_keys: Sequence[NestedKey] = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)[source]¶
Removes empty specs and content from the environment.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, \
...     DiscreteTensorSpec
>>> from torchrl.envs import EnvBase, TransformedEnv, RemoveEmptySpecs
>>> from torchrl.envs.utils import check_env_specs
>>>
>>> class DummyEnv(EnvBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...         self.observation_spec = CompositeSpec(
...             observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)),
...             other=CompositeSpec(
...                 another_other=CompositeSpec(shape=self.batch_size),
...                 shape=self.batch_size,
...             ),
...             shape=self.batch_size,
...         )
...         self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3))
...         self.done_spec = DiscreteTensorSpec(
...             2, (*self.batch_size, 1), dtype=torch.bool
...         )
...         self.full_done_spec["truncated"] = self.full_done_spec[
...             "terminated"].clone()
...         self.reward_spec = CompositeSpec(
...             reward=UnboundedContinuousTensorSpec((*self.batch_size, 1)),
...             other_reward=CompositeSpec(shape=self.batch_size),
...             shape=self.batch_size
...         )
...
...     def _reset(self, tensordict):
...         return self.observation_spec.rand().update(self.full_done_spec.zero())
...
...     def _step(self, tensordict):
...         return TensorDict(
...             {},
...             batch_size=[]
...         ).update(self.observation_spec.rand()).update(
...             self.full_done_spec.zero()
...         ).update(self.full_reward_spec.rand())
...
...     def _set_seed(self, seed):
...         return seed + 1
>>>
>>> base_env = DummyEnv()
>>> print(base_env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                other: TensorDict(
                    fields={
                        another_other: TensorDict(
                            fields={
                            },
                            batch_size=torch.Size([2]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                other_reward: TensorDict(
                    fields={
                    },
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(base_env)
>>> env = TransformedEnv(base_env, RemoveEmptySpecs())
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(env)
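The pruning visible in the rollout above (the empty `other`, `another_other` and `other_reward` entries disappear after the transform) can be sketched in plain Python as a recursive pass that drops any nested mapping left with no leaf entries. The sketch below uses ordinary dicts in place of TensorDict/CompositeSpec; the function name is illustrative, not the library's implementation.

```python
def prune_empty(tree: dict) -> dict:
    """Recursively drop sub-dicts that contain no leaf values.

    Mirrors the effect RemoveEmptySpecs has on nested specs/tensordicts:
    a sub-tree vanishes when, after pruning its children, nothing remains.
    """
    pruned = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            child = prune_empty(value)
            if child:  # keep only non-empty sub-trees
                pruned[key] = child
        else:
            pruned[key] = value
    return pruned


rollout_step = {
    "observation": [0.1, 0.2, 0.3],
    "other": {"another_other": {}},  # empty once its child is pruned
    "other_reward": {},              # empty from the start
    "reward": [0.0],
}
print(prune_empty(rollout_step))
# {'observation': [0.1, 0.2, 0.3], 'reward': [0.0]}
```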
- forward(tensordict: TensorDictBase) TensorDictBase ¶
Reads the input tensordict, and for the selected keys, applies the transform.
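As a rough illustration of the `in_keys`/`out_keys` contract that `forward` follows, the sketch below applies a function to selected keys of a plain dict standing in for a tensordict; the function and key names are made up for the example.

```python
def forward_like(data: dict, fn, in_keys, out_keys) -> dict:
    """Sketch of a generic forward pass: read each in_key, apply the
    transform function, write the result under the matching out_key."""
    out = dict(data)
    for in_key, out_key in zip(in_keys, out_keys):
        out[out_key] = fn(data[in_key])
    return out


result = forward_like(
    {"observation": 2.0},
    fn=lambda x: x * 10.0,
    in_keys=["observation"],
    out_keys=["observation_scaled"],
)
print(result)
# {'observation': 2.0, 'observation_scaled': 20.0}
```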
- transform_input_spec(input_spec: TensorSpec) TensorSpec [source]¶
Transforms the input spec such that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – the spec before the transform
- Returns:
the expected spec after the transform
- transform_output_spec(output_spec: CompositeSpec) CompositeSpec [source]¶
Transforms the output spec such that the resulting spec matches the transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec() instead.
- Parameters:
output_spec (TensorSpec) – the spec before the transform
- Returns:
the expected spec after the transform
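The division of labour described above can be sketched as follows: a generic transform_output_spec that delegates each group of the output spec to its dedicated hook. Plain dicts stand in for CompositeSpec, the `full_*_spec` keys and hook names mirror the documented API, and the hook bodies are placeholders, not the library's implementation.

```python
class SketchTransform:
    """Sketch: transform_output_spec stays generic and delegates each
    group of the output spec to its dedicated hook."""

    def transform_observation_spec(self, spec: dict) -> dict:
        return spec  # placeholder: a real transform would rewrite this

    def transform_reward_spec(self, spec: dict) -> dict:
        return spec  # placeholder

    def transform_full_done_spec(self, spec: dict) -> dict:
        return spec  # placeholder

    def transform_output_spec(self, output_spec: dict) -> dict:
        # Generic dispatcher: each spec group is rewritten by its own hook,
        # so subclasses only override the hooks, never this method.
        out = dict(output_spec)
        out["full_observation_spec"] = self.transform_observation_spec(
            out["full_observation_spec"]
        )
        out["full_reward_spec"] = self.transform_reward_spec(out["full_reward_spec"])
        out["full_done_spec"] = self.transform_full_done_spec(out["full_done_spec"])
        return out


t = SketchTransform()
spec = {"full_observation_spec": {}, "full_reward_spec": {}, "full_done_spec": {}}
print(t.transform_output_spec(spec))
```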