快捷方式

OpenSpielEnv

torchrl.envs.OpenSpielEnv(*args, **kwargs)[source]

使用游戏字符串构建的 Google DeepMind OpenSpiel 环境包装器。

GitHub: https://github.com/google-deepmind/open_spiel

文档: https://openspiel.readthedocs.io/en/latest/index.html

参数:

game_string (str) – 要包装的游戏的名称。必须是 available_envs 的一部分。

关键字参数:
  • device (torch.device, optional) – 如果提供,则为数据要转换到的设备。默认为 None

  • batch_size (torch.Size, optional) – 环境的批大小。默认为 torch.Size([])

  • allow_done_after_reset (bool, optional) – 如果为 True,则允许环境在调用 reset() 后立即变为 done。默认为 False

  • group_map (MarlGroupMapTypeDict[str, List[str]]], optional) – 如何在 tensordict 中对智能体进行分组以进行输入/输出。有关更多信息,请参阅 MarlGroupMapType。默认为 ALL_IN_ONE_GROUP

  • categorical_actions (bool, optional) – 如果为 True,则分类规格将转换为 TorchRL 等效项 (torchrl.data.Categorical),否则将使用 one-hot 编码 (torchrl.data.OneHot)。默认为 False

  • return_state (bool, optional) – 如果为 True,则 “state” 包含在 reset()step() 的输出中。状态可以传递给 reset() 以重置到该状态,而不是重置到初始状态。默认为 False

变量:

available_envs – 可用于构建的环境

示例

>>> from torchrl.envs import OpenSpielEnv
>>> from tensordict import TensorDict
>>> env = OpenSpielEnv("chess", return_state=True)
>>> td = env.reset()
>>> td = env.step(env.full_action_spec.rand())
>>> print(td)
TensorDict(
    fields={
        agents: TensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        next: TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2]),
                    device=None,
                    is_shared=False),
                current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
                674
                , batch_size=torch.Size([]), device=None),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(env.available_envs)
['2048', 'add_noise', 'amazons', 'backgammon', ...]

reset() 可以恢复特定状态,而不是初始状态,只要 return_state=True

>>> from torchrl.envs import OpenSpielEnv
>>> from tensordict import TensorDict
>>> env = OpenSpielEnv("chess", return_state=True)
>>> td = env.reset()
>>> td = env.step(env.full_action_spec.rand())
>>> td_restore = td["next"]
>>> td = env.step(env.full_action_spec.rand())
>>> # Current state is not equal `td_restore`
>>> (td["next"] == td_restore).all()
False
>>> td = env.reset(td_restore)
>>> # After resetting, now the current state is equal to `td_restore`
>>> (td == td_restore).all()
True

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源