ActionMask
- class torchrl.envs.transforms.ActionMask(action_key: NestedKey = 'action', mask_key: NestedKey = 'action_mask')
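An adaptive action masker.
This transform reads the mask from the input tensordict after the step is executed, and adapts the mask of the one-hot / categorical action spec accordingly.
Note
This transform will fail when used without an environment.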
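Conceptually, adapting the mask of the action spec means that sampling from the spec only yields actions the mask currently permits. The following is a minimal plain-Python sketch of that idea. It is not TorchRL's implementation, and `sample_masked_action` is a hypothetical helper introduced only for illustration.

```python
import random


def sample_masked_action(mask, rng=random):
    """Return the index of a uniformly sampled action whose mask entry is True.

    Hypothetical helper: it illustrates masked sampling in general,
    not how TorchRL implements it.
    """
    # Collect the indices of the actions the mask allows.
    valid = [i for i, allowed in enumerate(mask) if allowed]
    if not valid:
        raise ValueError("the mask allows no action")
    return rng.choice(valid)


mask = [True, False, True, False]
# Draw many samples; only indices 0 and 2 can ever appear.
samples = {sample_masked_action(mask) for _ in range(100)}
print(sorted(samples))
```

With a boolean mask, `True` marks an action that may be taken and `False` one that is currently unavailable, which matches the convention of the example below.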
- Parameters:
action_key (NestedKey, optional) – the key where the action tensor can be found. Defaults to "action".
mask_key (NestedKey, optional) – the key where the action mask can be found. Defaults to "action_mask".
Examples
>>> import torch
>>> from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec
>>> from torchrl.envs.transforms import ActionMask, TransformedEnv
>>> from torchrl.envs.common import EnvBase
>>> class MaskedEnv(EnvBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...         self.action_spec = DiscreteTensorSpec(4)
...         self.state_spec = CompositeSpec(action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool))
...         self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3))
...         self.reward_spec = UnboundedContinuousTensorSpec(1)
...
...     def _reset(self, tensordict=None):
...         td = self.observation_spec.rand()
...         td.update(torch.ones_like(self.state_spec.rand()))
...         return td
...
...     def _step(self, data):
...         td = self.observation_spec.rand()
...         mask = data.get("action_mask")
...         action = data.get("action")
...         mask = mask.scatter(-1, action.unsqueeze(-1), 0)
...
...         td.set("action_mask", mask)
...         td.set("reward", self.reward_spec.rand())
...         td.set("done", ~mask.any().view(1))
...         return td
...
...     def _set_seed(self, seed):
...         return seed
...
>>> torch.manual_seed(0)
>>> base_env = MaskedEnv()
>>> env = TransformedEnv(base_env, ActionMask())
>>> r = env.rollout(10)
>>> env = TransformedEnv(base_env, ActionMask())
>>> r = env.rollout(10)
>>> r["action_mask"]
tensor([[ True,  True,  True,  True],
        [ True,  True, False,  True],
        [ True,  True, False, False],
        [ True, False, False, False]])
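The episode logic of the example can be traced without TorchRL. The sketch below is a plain-Python approximation under stated assumptions: a uniform random policy over the allowed actions, and a hypothetical function `run_masked_episode` that records the mask seen before each step. It mirrors the environment's rules (reset sets every mask entry to `True`, each taken action is cleared, and the episode ends once no action remains) but is not the library's code.

```python
import random


def run_masked_episode(n_actions=4, max_steps=10, seed=0):
    """Trace the mask seen before each step of an episode like the one above.

    Hypothetical helper for illustration; assumes a uniform random policy.
    """
    rng = random.Random(seed)
    mask = [True] * n_actions            # reset: every action is allowed
    history = []
    for _ in range(max_steps):
        history.append(list(mask))       # mask available to the policy at this step
        valid = [i for i, ok in enumerate(mask) if ok]
        action = rng.choice(valid)       # only unmasked actions can be picked
        mask[action] = False             # the taken action becomes unavailable
        if not any(mask):                # done once no action remains
            break
    return history


history = run_masked_episode()
for row in history:
    print(row)
```

Whichever actions are drawn, the trace has four rows with four, three, two and one allowed actions respectively. This is consistent with the four-row `r["action_mask"]` in the example, even though `rollout(10)` permits up to ten steps.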