Tree¶
- class torchrl.data.Tree(count: 'int | torch.Tensor' = None, wins: 'int | torch.Tensor' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node_data: 'TensorDict | None' = None, subtree: "'Tree'" = None, _parent: 'weakref.ref | List[weakref.ref] | None' = None, specs: 'Composite | None' = None, *, batch_size, device=None, names=None)[source]¶
- property branching_action: torch.Tensor | TensorDictBase | None¶
返回由此特定节点分支出的 Action。
- 返回:
如果节点没有父节点,则返回 Tensor、TensorDict 或 None。
另请参阅
当 Rollout 数据包含单个 Step 时,这将等于
prev_action
。另请参阅
所有与树中给定节点(或 Observation)关联的 Action
.
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
将 TensorDict 保存到磁盘。
此函数是
memmap()
的代理。
- edges() List[Tuple[int, int]] [source]¶
获取树中的 Edge 列表。
每个 Edge 表示为两个节点 ID 的元组:父节点 ID 和子节点 ID。树使用广度优先搜索 (BFS) 遍历,以确保访问所有 Edge。
- 返回:
一个元组列表,其中每个元组包含一个父节点 ID 和一个子节点 ID。
- classmethod fields()¶
返回描述此 Dataclass 字段的元组。
接受 Dataclass 或其实例。元组元素类型为 Field。
- classmethod from_tensordict(tensordict, non_tensordict=None, safe=True)¶
用于实例化新 Tensor 类别对象的 Tensor 类别包装器。
- 参数:
tensordict (TensorDict) – Tensor 类型的字典
non_tensordict (dict) – 包含非 Tensor 和嵌套 Tensor 类别对象的字典
- property full_action_spec¶
树的 Action Spec。
这是 Tree.specs[‘input_spec’, ‘full_action_spec’] 的别名。
- property full_done_spec¶
树的 Done Spec。
这是 Tree.specs[‘output_spec’, ‘full_done_spec’] 的别名。
- property full_observation_spec¶
树的 Observation Spec。
这是 Tree.specs[‘output_spec’, ‘full_observation_spec’] 的别名。
- property full_reward_spec¶
树的 Reward Spec。
这是 Tree.specs[‘output_spec’, ‘full_reward_spec’] 的别名。
- property full_state_spec¶
树的 State Spec。
这是 Tree.specs[‘input_spec’, ‘full_state_spec’] 的别名。
- get(key: NestedKey, *args, **kwargs)¶
获取使用输入 Key 存储的值。
- 参数:
key (str, tuple of str) – 要查询的 Key。如果是 str 元组,则等效于 getattr 的链式调用。
default – 如果在 Tensorclass 中找不到 Key 的默认值。
- 返回:
使用输入 Key 存储的值
- property is_terminal: bool | torch.Tensor¶
如果树没有子节点,则返回 True。
- classmethod load(prefix: str | Path, *args, **kwargs) T ¶
从磁盘加载 TensorDict。
此类方法是
load_memmap()
的代理。
- load_(prefix: str | Path, *args, **kwargs)¶
将 TensorDict 从磁盘加载到当前 TensorDict 中。
此类方法是
load_memmap_()
的代理。
- classmethod load_memmap(prefix: str | Path, device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) T ¶
从磁盘加载 Memory-mapped TensorDict。
- 参数:
prefix (str or Path to folder) – 应从中获取已保存 TensorDict 的文件夹路径。
device (torch.device or equivalent, optional) – 如果提供,数据将异步地转换为该 Device。支持 “meta” Device,在这种情况下数据不会被加载,而是创建一组空的“meta” Tensor。这对于了解总模型大小和结构而无需实际打开任何文件很有用。
non_blocking (bool, optional) – 如果为
True
,在 Device 上加载 Tensor 后不会调用 synchronize。默认为False
。out (TensorDictBase, optional) – 可选的 TensorDict,数据应写入其中。
示例
>>> from tensordict import TensorDict >>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0) >>> td.memmap("./saved_td") >>> td_load = TensorDict.load_memmap("./saved_td") >>> assert (td == td_load).all()
此方法还支持加载嵌套的 TensorDict。
示例
>>> nested = TensorDict.load_memmap("./saved_td/nested") >>> assert nested["e"] == 0
TensorDict 也可以加载到“meta” Device 上,或者作为 Fake Tensor 加载。
示例
>>> import tempfile >>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}}) >>> with tempfile.TemporaryDirectory() as path: ... td.save(path) ... td_load = TensorDict.load_memmap(path, device="meta") ... print("meta:", td_load) ... from torch._subclasses import FakeTensorMode ... with FakeTensorMode(): ... td_load = TensorDict.load_memmap(path) ... print("fake:", td_load) meta: TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False) fake: TensorDict( fields={ a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)¶
尝试将 State_dict 就地加载到目标 Tensorclass 上。
- classmethod make_node(data: TensorDictBase, *, device: torch.device | None = None, batch_size: torch.Size | None = None, specs: Composite | None = None) Tree [source]¶
根据给定数据创建一个新节点。
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T ¶
将所有 Tensor 写入新 TensorDict 中相应的 Memory-mapped Tensor。
- 参数:
prefix (str) – Memory-mapped Tensor 存储所在的目录前缀。目录树结构将模仿 TensorDict 的结构。
copy_existing (bool) – 如果为 False(默认),如果 TensorDict 中的某个条目已经是存储在磁盘上的 Tensor 且具有关联文件,但未按照 Prefix 保存在正确位置,则会引发异常。如果为
True
,任何现有 Tensor 都将被复制到新位置。
- 关键字参数:
num_threads (int, optional) – 用于写入 Memmap Tensor 的线程数量。默认为 0。
return_early (bool, optional) – 如果为
True
且num_threads>0
,该方法将返回 TensorDict 的 Future。share_non_tensor (bool, optional) – 如果为
True
,非 Tensor 数据将在进程间共享,并且在单个节点内任意 worker 上的写入操作(如原地更新或设置)将更新所有其他 worker 上的值。如果非 Tensor 叶子节点的数量很高(例如,共享大型非 Tensor 数据堆栈),这可能导致 OOM 或类似错误。默认为False
。existsok (bool, optional) – 如果为
False
,如果相同路径下已存在 Tensor,则会引发异常。默认为True
。
TensorDict 随后被锁定,这意味着任何非原地写入操作(例如,重命名、设置或移除条目)将抛出异常。一旦 TensorDict 解锁,内存映射属性将变为
False
,因为跨进程身份不再保证。- 返回:
如果
return_early=False
,则返回一个 Tensor 存储在磁盘上的新 tensordict;否则返回一个TensorDictFuture
实例。
注意
以这种方式序列化对于深度嵌套的 tensordict 来说可能很慢,因此不建议在训练循环内调用此方法。
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T ¶
将所有 Tensor 原地写入到相应的内存映射 Tensor。
- 参数:
prefix (str) – Memory-mapped Tensor 存储所在的目录前缀。目录树结构将模仿 TensorDict 的结构。
copy_existing (bool) – 如果为 False(默认),如果 TensorDict 中的某个条目已经是存储在磁盘上的 Tensor 且具有关联文件,但未按照 Prefix 保存在正确位置,则会引发异常。如果为
True
,任何现有 Tensor 都将被复制到新位置。
- 关键字参数:
num_threads (int, optional) – 用于写入 Memmap Tensor 的线程数量。默认为 0。
return_early (bool, optional) – 如果为
True
且num_threads>0
,该方法将返回一个 tensordict 的 future。可以通过使用 future.result() 来查询返回的 tensordict。share_non_tensor (bool, optional) – 如果为
True
,非 Tensor 数据将在进程间共享,并且在单个节点内任意 worker 上的写入操作(如原地更新或设置)将更新所有其他 worker 上的值。如果非 Tensor 叶子节点的数量很高(例如,共享大型非 Tensor 数据堆栈),这可能导致 OOM 或类似错误。默认为False
。existsok (bool, optional) – 如果为
False
,如果相同路径下已存在 Tensor,则会引发异常。默认为True
。
TensorDict 随后被锁定,这意味着任何非原地写入操作(例如,重命名、设置或移除条目)将抛出异常。一旦 TensorDict 解锁,内存映射属性将变为
False
,因为跨进程身份不再保证。- 返回:
如果
return_early=False
,返回自身;否则返回一个TensorDictFuture
实例。
注意
以这种方式序列化对于深度嵌套的 tensordict 来说可能很慢,因此不建议在训练循环内调用此方法。
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
创建一个具有与原始 tensordict 相同形状的无内容的内存映射 tensordict。
- 参数:
prefix (str) – Memory-mapped Tensor 存储所在的目录前缀。目录树结构将模仿 TensorDict 的结构。
copy_existing (bool) – 如果为 False(默认),如果 TensorDict 中的某个条目已经是存储在磁盘上的 Tensor 且具有关联文件,但未按照 Prefix 保存在正确位置,则会引发异常。如果为
True
,任何现有 Tensor 都将被复制到新位置。
- 关键字参数:
num_threads (int, optional) – 用于写入 Memmap Tensor 的线程数量。默认为 0。
return_early (bool, optional) – 如果为
True
且num_threads>0
,该方法将返回 TensorDict 的 Future。share_non_tensor (bool, optional) – 如果为
True
,非 Tensor 数据将在进程间共享,并且在单个节点内任意 worker 上的写入操作(如原地更新或设置)将更新所有其他 worker 上的值。如果非 Tensor 叶子节点的数量很高(例如,共享大型非 Tensor 数据堆栈),这可能导致 OOM 或类似错误。默认为False
。existsok (bool, optional) – 如果为
False
,如果相同路径下已存在 Tensor,则会引发异常。默认为True
。
TensorDict 随后被锁定,这意味着任何非原地写入操作(例如,重命名、设置或移除条目)将抛出异常。一旦 TensorDict 解锁,内存映射属性将变为
False
,因为跨进程身份不再保证。- 返回:
如果
return_early=False
,则返回一个数据存储为内存映射 tensor 的新TensorDict
实例;否则返回一个TensorDictFuture
实例。
注意
这是将一组大型缓冲区写入磁盘的推荐方法,因为
memmap_()
将复制信息,这对于大型内容可能很慢。示例
>>> td = TensorDict({ ... "a": torch.zeros((3, 64, 64), dtype=torch.uint8), ... "b": torch.zeros(1, dtype=torch.int64), ... }, batch_size=[]).expand(1_000_000) # expand does not allocate new memory >>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()¶
如果内存映射 tensordict 具有
saved_path
,则刷新其内容。如果没有与其关联的路径,此方法将引发异常。
- property node_observation: torch.Tensor | TensorDictBase¶
返回与此特定节点关联的观测值。
这是定义节点在分支发生之前的观测值(或观测值集合)。如果节点包含
rollout()
属性,则节点观测值通常与上次执行的操作产生的观测值相同,即node.rollout[..., -1]["next", "observation"]
。如果与树的规格关联的观测值键不止一个,则返回一个
TensorDict
实例。为了更一致的表示,请参阅
node_observations
。
- property node_observations: torch.Tensor | TensorDictBase¶
返回以 TensorDict 格式表示的、与此特定节点关联的观测值。
这是定义节点在分支发生之前的观测值(或观测值集合)。如果节点包含
rollout()
属性,则节点观测值通常与上次执行的操作产生的观测值相同,即node.rollout[..., -1]["next", "observation"]
。如果与树的规格关联的观测值键不止一个,则返回一个
TensorDict
实例。为了更一致的表示,请参阅
node_observations
。
- property num_children: int¶
此节点的子节点数量。
等于
self.subtree
堆栈中的元素数量。
- num_vertices(*, count_repeat: bool = False) int [source]¶
返回 Tree 中唯一顶点的数量。
- 关键字参数:
count_repeat (bool, optional) –
确定是否计算重复顶点。
如果为
False
,则仅计算每个唯一顶点一次。如果为
True
,如果顶点出现在不同的路径中,则多次计算。
默认为
False
。- 返回:
Tree 中唯一顶点的数量。
- 返回类型:
int
- property parent: Tree | None¶
节点的父节点。
如果节点有父节点并且此对象仍然存在于 python 工作空间中,此属性将返回它。
对于重新分支的树,此属性可能返回一个树堆栈,其中堆栈中的每个索引对应于不同的父节点。
注意
parent 属性的内容将匹配,但身份不匹配:tensorclass 对象是使用相同的 tensor(即,指向相同内存位置的 tensor)重建的。
- 返回:
如果父节点数据超出作用域或节点是根节点,则返回包含父节点数据的
Tree
,否则返回None
。
- plot(backend: str = 'plotly', figure: str = 'tree', info: List[str] = None, make_labels: Callable[[Any, ...], Any] | None = None)[source]¶
使用指定的后端和图类型绘制树的可视化图。
- 参数:
backend – 要使用的绘图后端。目前仅支持 'plotly'。
figure – 要绘制的图类型。可以是 'tree' 或 'box'。
info – 要包含在图中的附加信息列表(目前未使用)。
make_labels – 一个可选函数,用于为绘图生成自定义标签。
- 引发:
NotImplementedError – 如果指定了不受支持的后端或图类型。
- property prev_action: torch.Tensor | TensorDictBase | None¶
就在生成此节点的观测值之前执行的操作。
- 返回:
如果节点没有父节点,则返回 Tensor、TensorDict 或 None。
另请参阅
只要 rollout 数据包含单个步骤,这将等于
branching_action
。另请参阅
所有与树中给定节点(或 Observation)关联的 Action
.
- rollout_from_path(path: Tuple[int]) TensorDictBase | None [source]¶
检索沿着树中给定路径的 rollout 数据。
对于路径中的每个节点,rollout 数据沿着最后一个维度(dim=-1)连接。如果沿着路径未找到 rollout 数据,则返回
None
。- 参数:
path – 一个表示树中路径的整数元组。
- 返回:
沿着路径连接后的 rollout 数据,如果未找到数据则返回 None。
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
将 TensorDict 保存到磁盘。
此函数是
memmap()
的代理。
- property selected_actions: torch.Tensor | TensorDictBase | None¶
返回一个包含从此节点分支出去的所有选定操作的 tensor。
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)¶
设置一个新的键值对。
- 参数:
key (str, tuple of str) – 要设置的键的名称。如果是字符串元组,则等同于链式调用 getattr 后跟一个最终的 setattr。
value (Any) – 要存储在 tensorclass 中的值。
inplace (bool, optional) – 如果为
True
,set 将尝试原地更新值。如果为False
或键不存在,值将被简单写入到其目的地。
- 返回:
自身
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any] ¶
返回一个 state_dict 字典,可用于保存和加载 tensorclass 中的数据。
- to_tensordict(*, retain_none: bool | None = None) TensorDict ¶
将 tensorclass 转换为常规的 TensorDict。
复制所有条目。内存映射和共享内存 tensor 被转换为常规 tensor。
- 参数:
retain_none (bool) –
如果为
True
,None
值将被写入 tensordict 中。否则将被丢弃。默认:True
。注意
从 v0.8 开始,默认值将更改为
False
。- 返回:
一个包含与 tensorclass 相同值的新 TensorDict 对象。
- unbind(dim: int)¶
返回一个由索引 tensorclass 实例组成的元组,这些实例沿指定维度解除绑定。
结果 tensorclass 实例将共享初始 tensorclass 实例的存储。
- valid_paths()[source]¶
生成树中的所有有效路径。
有效路径是从根节点开始并在叶子节点结束的子节点索引序列。每个路径表示为一个整数元组,其中每个整数对应于一个子节点的索引。
- 生成:
tuple – 树中的一个有效路径。
- vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') Dict[int | Tuple[int], Tree] [source]¶
返回一个包含 Tree 顶点的映射。
- 关键字参数:
key_type (Literal["id", "hash", "path"], optional) –
指定用于顶点的键的类型。
"id": 使用顶点 ID 作为键。
"hash": 使用顶点的哈希值作为键。
- "path": 使用到顶点的路径作为键。这可能导致字典的长度比使用 "id" 或 "hash" 时更长,因为同一个节点可能是多条轨迹的一部分。
默认为
"hash"
。
(注:原文此处有额外一句 "Defaults to an empty string, which may imply a default behavior.",与前一句矛盾且在此处无上下文,已省略不译。)
- 返回:
一个将键映射到 Tree 顶点的字典。
- 返回类型:
Dict[int | Tuple[int], Tree]
- property visits: int | torch.Tensor¶
返回与此特定节点关联的访问次数。
这是
count
属性的别名。