MCTSForest
- class torchrl.data.MCTSForest(*, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, max_size: int | None = None, done_keys: List[NestedKey] | None = None, reward_keys: List[NestedKey] = None, observation_keys: List[NestedKey] = None, action_keys: List[NestedKey] = None, excluded_keys: List[NestedKey] = None, consolidated: bool | None = None)[source]
A collection of MCTS trees.
Warning
This class is under active development. The API is subject to frequent change.
This class is aimed at storing rollouts in a storage and producing trees based on a given root from that dataset.
- Keyword Arguments:
data_map (TensorDictMap, optional) – the storage used to store the data (observations, rewards, states etc). If not provided, it is lazily initialized through from_tensordict_pair(), using the list of observation_keys and action_keys as in_keys.
node_map (TensorDictMap, optional) – a map from the observation space to the index space. Internally, the node map is used to gather all possible branches coming out of a given node. For example, if an observation has two associated actions and outcomes in the data map, then node_map will return a data structure containing the two indices in data_map that correspond to these two outcomes. If not provided, it is lazily initialized through from_tensordict_pair(), using the list of observation_keys as in_keys and the QueryModule as out_keys.
max_size (int, optional) – the size of the maps. If not provided, defaults to data_map.max_size if this can be found, then node_map.max_size. If none of these is provided, defaults to 1000.
done_keys (list of NestedKey, optional) – the done keys of the environment. If not provided, defaults to ("done", "terminated", "truncated"). The get_keys_from_env() method can be used to determine the keys automatically.
action_keys (list of NestedKey, optional) – the action keys of the environment. If not provided, defaults to ("action",). The get_keys_from_env() method can be used to determine the keys automatically.
reward_keys (list of NestedKey, optional) – the reward keys of the environment. If not provided, defaults to ("reward",). The get_keys_from_env() method can be used to determine the keys automatically.
observation_keys (list of NestedKey, optional) – the observation keys of the environment. If not provided, defaults to ("observation",). The get_keys_from_env() method can be used to determine the keys automatically.
excluded_keys (list of NestedKey, optional) – a list of keys to exclude from the data storage.
consolidated (bool, optional) – if True, the data_map storage will be consolidated on disk. Defaults to False. A short construction sketch follows this list.
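The snippet below is a minimal construction sketch, not taken from the original docstring; it only illustrates how the keyword arguments above can be combined. The ("stats",) entry passed to excluded_keys is a hypothetical key used purely for illustration.
>>> from torchrl.data import MCTSForest
>>> forest = MCTSForest(
...     observation_keys=["observation"],
...     action_keys=["action"],
...     reward_keys=["reward"],
...     done_keys=["done", "terminated", "truncated"],
...     excluded_keys=[("stats",)],  # hypothetical key to keep out of the data storage
...     consolidated=True,           # consolidate the data_map storage on disk
...     max_size=10_000,
... )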
Examples
>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
...     UnsqueezeTransform(
...         in_keys=obs_keys,
...         out_keys=[("unsqueeze", key) for key in obs_keys],
...         dim=-1
...     )
... )
>>> env = env.append_transform(
...     CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Let's make another final rollout from an intermediate step in the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> # branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> # a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
...     rollout1,
...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
...     intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0)),
...     intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
...     intersection=True,
... )
True
- property action_keys: List[NestedKey]
Action Keys.
Returns the keys used to retrieve actions from the environment's input. The default action key is "action".
- Returns:
A list of strings or tuples representing the action keys.
- property done_keys: List[NestedKey]
Done Keys.
Returns the keys used to indicate that an episode has ended. The default done keys are "done", "terminated", and "truncated". These keys can be used in the environment's output to indicate the end of an episode.
- Returns:
A list of strings representing the done keys.
- extend(rollout, *, return_node: bool = False)[source]
Add a rollout to the forest.
Nodes are added to a tree only at the points where rollouts diverge from each other and at the endpoints of rollouts.
If no existing tree matches the first steps of the rollout, a new tree is added. Only one node is created, for the final step.
If a matching tree exists, the rollout is added to that tree. If the rollout diverges from all other rollouts in the tree at some step, a new node is created just before the step where it diverges, and a leaf node is created for the final step of the rollout. If all steps of the rollout match a previously added rollout, nothing changes. If the rollout matches up to a leaf node of a tree but continues beyond it, that node is extended to the end of the rollout and no new nodes are created.
- Parameters:
rollout (TensorDict) – The rollout to add to the forest.
return_node (bool, optional) – If True, the method returns the added node. Defaults to False.
- Returns:
- The node that was added to the forest. Only returned when return_node is True.
- Return type:
Tree
Examples
>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> import torch
>>> forest = MCTSForest()
>>> r0 = TensorDict({
...     'action': torch.tensor([1, 2, 3, 4, 5]),
...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
...     'observation': torch.tensor([  0, 123, 392, 989, 809])
... }, [5])
>>> r1 = TensorDict({
...     'action': torch.tensor([1, 2, 6, 7]),
...     'next': {'observation': torch.tensor([123, 392, 235,  38])},
...     'observation': torch.tensor([  0, 123, 392, 235])
... }, [4])
>>> td_root = r0[0].exclude("next")
>>> forest.extend(r0)
>>> forest.extend(r1)
>>> tree = forest.get_tree(td_root)
>>> print(tree)
Tree(
    count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
    index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
    node_data=TensorDict(
        fields={
            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([]),
        device=cpu,
        is_shared=False),
    node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    rollout=TensorDict(
        fields={
            action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=cpu,
                is_shared=False),
            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([2]),
        device=cpu,
        is_shared=False),
    subtree=Tree(
        _parent=NonTensorStack(
            [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
            batch_size=torch.Size([2]),
            device=None),
        count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
        hash=NonTensorStack(
            [4341220243998689835, 6745467818783115365],
            batch_size=torch.Size([2]),
            device=None),
        node_data=LazyStackedTensorDict(
            fields={
                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        node_id=NonTensorStack(
            [1, 2],
            batch_size=torch.Size([2]),
            device=None),
        rollout=LazyStackedTensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
                next: LazyStackedTensorDict(
                    fields={
                        observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, -1]),
                    device=cpu,
                    is_shared=False,
                    stack_dim=0),
                observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, -1]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        index=None,
        subtree=None,
        specs=None,
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False),
    wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
    hash=None,
    _parent=None,
    specs=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- get_keys_from_env(env: EnvBase)[source]
Writes the missing done, action and reward keys to the Forest, given an environment.
Existing keys are not overwritten.
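A minimal usage sketch (assuming a Gym backend is installed; the environment id below is arbitrary):
>>> from torchrl.envs import GymEnv
>>> from torchrl.data import MCTSForest
>>> env = GymEnv("CartPole-v1")  # any EnvBase instance works
>>> forest = MCTSForest()
>>> forest.get_keys_from_env(env)
>>> # The forest now carries the env's done, action and reward keys
>>> print(forest.done_keys, forest.action_keys, forest.reward_keys)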
- property observation_keys: List[NestedKey]
Observation Keys.
Returns the keys used to retrieve observations from the environment's output. The default observation key is "observation".
- Returns:
A list of strings or tuples representing the observation keys.
- property reward_keys: List[NestedKey]
Reward Keys.
Returns the keys used to retrieve rewards from the environment's output. The default reward key is "reward".
- Returns:
A list of strings or tuples representing the reward keys.
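As a small illustrative sketch (not part of the original docstring), the key properties can be read back after construction; the nested ("agent", "reward") key below is only an example value:
>>> from torchrl.data import MCTSForest
>>> forest = MCTSForest(reward_keys=[("agent", "reward")])  # example nested key
>>> print(forest.reward_keys)       # the keys provided at construction time
>>> print(forest.observation_keys)  # falls back to the default described above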