快捷方式

ActionMask

class torchrl.envs.transforms.ActionMask(action_key: NestedKey = 'action', mask_key: NestedKey = 'action_mask')[源代码]

自适应动作掩码器。

此转换在执行步骤后从输入 tensordict 读取掩码,并调整单热/分类动作规范的掩码。

注意

此转换在没有环境的情况下使用时将失败。

参数:
  • action_key (NestedKey, 可选) – 可以在其中找到动作张量的键。默认为 "action"

  • mask_key (NestedKey, 可选) – 可以在其中找到动作掩码的键。默认为 "action_mask"

示例

>>> import torch
>>> from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec
>>> from torchrl.envs.transforms import ActionMask, TransformedEnv
>>> from torchrl.envs.common import EnvBase
>>> class MaskedEnv(EnvBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...         self.action_spec = DiscreteTensorSpec(4)
...         self.state_spec = CompositeSpec(action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool))
...         self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3))
...         self.reward_spec = UnboundedContinuousTensorSpec(1)
...
...     def _reset(self, tensordict=None):
...         td = self.observation_spec.rand()
...         td.update(torch.ones_like(self.state_spec.rand()))
...         return td
...
...     def _step(self, data):
...         td = self.observation_spec.rand()
...         mask = data.get("action_mask")
...         action = data.get("action")
...         mask = mask.scatter(-1, action.unsqueeze(-1), 0)
...
...         td.set("action_mask", mask)
...         td.set("reward", self.reward_spec.rand())
...         td.set("done", ~mask.any().view(1))
...         return td
...
...     def _set_seed(self, seed):
...         return seed
...
>>> torch.manual_seed(0)
>>> base_env = MaskedEnv()
>>> env = TransformedEnv(base_env, ActionMask())
>>> r = env.rollout(10)
>>> env = TransformedEnv(base_env, ActionMask())
>>> r = env.rollout(10)
>>> r["action_mask"]
tensor([[ True,  True,  True,  True],
        [ True,  True, False,  True],
        [ True,  True, False, False],
        [ True, False, False, False]])
forward(tensordict: TensorDictBase) TensorDictBase[源代码]

读取输入 tensordict,并对选定的键应用转换。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源