快捷方式

Tree

class torchrl.data.Tree(count: 'int | torch.Tensor' = None, wins: 'int | torch.Tensor' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node_data: 'TensorDict | None' = None, subtree: "'Tree'" = None, _parent: 'weakref.ref | List[weakref.ref] | None' = None, specs: 'Composite | None' = None, *, batch_size, device=None, names=None)[source]
property branching_action: torch.Tensor | TensorDictBase | None

返回由此特定节点分支出的 Action。

返回

如果节点没有父节点,则返回 Tensor、TensorDict 或 None。

另请参阅

当 Rollout 数据包含单个 Step 时,这将等于 prev_action

另请参阅

所有与树中给定节点(或 Observation)关联的 Action.

property device: device

获取 Tensor 类别的 Device 类型。

dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

将 TensorDict 保存到磁盘。

此函数是 memmap() 的代理。

edges() List[Tuple[int, int]][source]

获取树中的 Edge 列表。

每个 Edge 表示为两个节点 ID 的元组:父节点 ID 和子节点 ID。树使用广度优先搜索 (BFS) 遍历,以确保访问所有 Edge。

返回

一个元组列表,其中每个元组包含一个父节点 ID 和一个子节点 ID。

classmethod fields()

返回描述此 Dataclass 字段的元组。

接受 Dataclass 或其实例。元组元素类型为 Field。

classmethod from_tensordict(tensordict, non_tensordict=None, safe=True)

用于实例化新 Tensor 类别对象的 Tensor 类别包装器。

参数
  • tensordict (TensorDict) – Tensor 类型的字典

  • non_tensordict (dict) – 包含非 Tensor 和嵌套 Tensor 类别对象的字典

property full_action_spec

树的 Action Spec。

这是 Tree.specs[‘input_spec’, ‘full_action_spec’] 的别名。

property full_done_spec

树的 Done Spec。

这是 Tree.specs[‘output_spec’, ‘full_done_spec’] 的别名。

property full_observation_spec

树的 Observation Spec。

这是 Tree.specs[‘output_spec’, ‘full_observation_spec’] 的别名。

property full_reward_spec

树的 Reward Spec。

这是 Tree.specs[‘output_spec’, ‘full_reward_spec’] 的别名。

property full_state_spec

树的 State Spec。

这是 Tree.specs[‘input_spec’, ‘full_state_spec’] 的别名。

fully_expanded(env: EnvBase) bool[source]

如果子节点数量等于 Environment 基数,则返回 True。

get(key: NestedKey, *args, **kwargs)

获取使用输入 Key 存储的值。

参数
  • key (str, tuple of str) – 要查询的 Key。如果是 str 元组,则等效于 getattr 的链式调用。

  • default – 如果在 Tensorclass 中找不到 Key 的默认值。

返回

使用输入 Key 存储的值

get_vertex_by_hash(hash: int) Tree[source]

遍历树并返回与给定 Hash 相对应的节点。

get_vertex_by_id(id: int) Tree[source]

遍历树并返回与给定 ID 相对应的节点。

property is_terminal: bool | torch.Tensor

如果树没有子节点,则返回 True。

classmethod load(prefix: str | Path, *args, **kwargs) T

从磁盘加载 TensorDict。

此类方法是 load_memmap() 的代理。

load_(prefix: str | Path, *args, **kwargs)

将 TensorDict 从磁盘加载到当前 TensorDict 中。

此类方法是 load_memmap_() 的代理。

classmethod load_memmap(prefix: str | Path, device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) T

从磁盘加载 Memory-mapped TensorDict。

参数
  • prefix (str or Path to folder) – 应从中获取已保存 TensorDict 的文件夹路径。

  • device (torch.device or equivalent, optional) – 如果提供,数据将异步地转换为该 Device。支持 “meta” Device,在这种情况下数据不会被加载,而是创建一组空的“meta” Tensor。这对于了解总模型大小和结构而无需实际打开任何文件很有用。

  • non_blocking (bool, optional) – 如果为 True,在 Device 上加载 Tensor 后不会调用 synchronize。默认为 False

  • out (TensorDictBase, optional) – 可选的 TensorDict,数据应写入其中。

示例

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

此方法还支持加载嵌套的 TensorDict。

示例

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0

TensorDict 也可以加载到“meta” Device 上,或者作为 Fake Tensor 加载。

示例

>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)

尝试将 State_dict 就地加载到目标 Tensorclass 上。

classmethod make_node(data: TensorDictBase, *, device: torch.device | None = None, batch_size: torch.Size | None = None, specs: Composite | None = None) Tree[source]

根据给定数据创建一个新节点。

max_length()[source]

返回树中所有有效路径的最大长度。

路径长度定义为路径中的节点数量。如果树为空,则返回 0。

返回

树中所有有效路径的最大长度。

返回类型

int

memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T

将所有 Tensor 写入新 TensorDict 中相应的 Memory-mapped Tensor。

参数
  • prefix (str) – Memory-mapped Tensor 存储所在的目录前缀。目录树结构将模仿 TensorDict 的结构。

  • copy_existing (bool) – 如果为 False(默认),如果 TensorDict 中的某个条目已经是存储在磁盘上的 Tensor 且具有关联文件,但未按照 Prefix 保存在正确位置,则会引发异常。如果为 True,任何现有 Tensor 都将被复制到新位置。

关键字参数
  • num_threads (int, optional) – 用于写入 Memmap Tensor 的线程数量。默认为 0

  • return_early (bool, optional) – 如果为 Truenum_threads>0,该方法将返回 TensorDict 的 Future。

  • share_non_tensor (bool, optional) – 如果为 True,非 Tensor 数据将在进程间共享,并且在单个节点内任意 worker 上的写入操作(如原地更新或设置)将更新所有其他 worker 上的值。如果非 Tensor 叶子节点的数量很高(例如,共享大型非 Tensor 数据堆栈),这可能导致 OOM 或类似错误。默认为 False

  • existsok (bool, optional) – 如果为 False,如果相同路径下已存在 Tensor,则会引发异常。默认为 True

TensorDict 随后被锁定,这意味着任何非原地写入操作(例如,重命名、设置或移除条目)将抛出异常。一旦 TensorDict 解锁,内存映射属性将变为 False,因为跨进程身份不再保证。

返回

如果 return_early=False,则返回一个 Tensor 存储在磁盘上的新 tensordict;否则返回一个 TensorDictFuture 实例。

注意

以这种方式序列化对于深度嵌套的 tensordict 来说可能很慢,因此不建议在训练循环内调用此方法。

memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T

将所有 Tensor 原地写入到相应的内存映射 Tensor。

参数
  • prefix (str) – Memory-mapped Tensor 存储所在的目录前缀。目录树结构将模仿 TensorDict 的结构。

  • copy_existing (bool) – 如果为 False(默认),如果 TensorDict 中的某个条目已经是存储在磁盘上的 Tensor 且具有关联文件,但未按照 Prefix 保存在正确位置,则会引发异常。如果为 True,任何现有 Tensor 都将被复制到新位置。

关键字参数
  • num_threads (int, optional) – 用于写入 Memmap Tensor 的线程数量。默认为 0

  • return_early (bool, optional) – 如果为 Truenum_threads>0,该方法将返回一个 tensordict 的 future。可以通过使用 future.result() 来查询返回的 tensordict。

  • share_non_tensor (bool, optional) – 如果为 True,非 Tensor 数据将在进程间共享,并且在单个节点内任意 worker 上的写入操作(如原地更新或设置)将更新所有其他 worker 上的值。如果非 Tensor 叶子节点的数量很高(例如,共享大型非 Tensor 数据堆栈),这可能导致 OOM 或类似错误。默认为 False

  • existsok (bool, optional) – 如果为 False,如果相同路径下已存在 Tensor,则会引发异常。默认为 True

TensorDict 随后被锁定,这意味着任何非原地写入操作(例如,重命名、设置或移除条目)将抛出异常。一旦 TensorDict 解锁,内存映射属性将变为 False,因为跨进程身份不再保证。

返回

如果 return_early=False,返回自身;否则返回一个 TensorDictFuture 实例。

注意

以这种方式序列化对于深度嵌套的 tensordict 来说可能很慢,因此不建议在训练循环内调用此方法。

memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

创建一个具有与原始 tensordict 相同形状的无内容的内存映射 tensordict。

参数
  • prefix (str) – Memory-mapped Tensor 存储所在的目录前缀。目录树结构将模仿 TensorDict 的结构。

  • copy_existing (bool) – 如果为 False(默认),如果 TensorDict 中的某个条目已经是存储在磁盘上的 Tensor 且具有关联文件,但未按照 Prefix 保存在正确位置,则会引发异常。如果为 True,任何现有 Tensor 都将被复制到新位置。

关键字参数
  • num_threads (int, optional) – 用于写入 Memmap Tensor 的线程数量。默认为 0

  • return_early (bool, optional) – 如果为 Truenum_threads>0,该方法将返回 TensorDict 的 Future。

  • share_non_tensor (bool, optional) – 如果为 True,非 Tensor 数据将在进程间共享,并且在单个节点内任意 worker 上的写入操作(如原地更新或设置)将更新所有其他 worker 上的值。如果非 Tensor 叶子节点的数量很高(例如,共享大型非 Tensor 数据堆栈),这可能导致 OOM 或类似错误。默认为 False

  • existsok (bool, optional) – 如果为 False,如果相同路径下已存在 Tensor,则会引发异常。默认为 True

TensorDict 随后被锁定,这意味着任何非原地写入操作(例如,重命名、设置或移除条目)将抛出异常。一旦 TensorDict 解锁,内存映射属性将变为 False,因为跨进程身份不再保证。

返回

如果 return_early=False,则返回一个数据存储为内存映射 tensor 的新 TensorDict 实例;否则返回一个 TensorDictFuture 实例。

注意

这是将一组大型缓冲区写入磁盘的推荐方法,因为 memmap_() 将复制信息,这对于大型内容可能很慢。

示例

>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
memmap_refresh_()

如果内存映射 tensordict 具有 saved_path,则刷新其内容。

如果没有与其关联的路径,此方法将引发异常。

property node_observation: torch.Tensor | TensorDictBase

返回与此特定节点关联的观测值。

这是定义节点在分支发生之前的观测值(或观测值集合)。如果节点包含 rollout() 属性,则节点观测值通常与上次执行的操作产生的观测值相同,即 node.rollout[..., -1]["next", "observation"]

如果与树的规格关联的观测值键不止一个,则返回一个 TensorDict 实例。

为了更一致的表示,请参阅 node_observations

property node_observations: torch.Tensor | TensorDictBase

返回以 TensorDict 格式表示的、与此特定节点关联的观测值。

这是定义节点在分支发生之前的观测值(或观测值集合)。如果节点包含 rollout() 属性,则节点观测值通常与上次执行的操作产生的观测值相同,即 node.rollout[..., -1]["next", "observation"]

如果与树的规格关联的观测值键不止一个,则返回一个 TensorDict 实例。

为了更一致的表示,请参阅 node_observations

property num_children: int

此节点的子节点数量。

等于 self.subtree 堆栈中的元素数量。

num_vertices(*, count_repeat: bool = False) int[source]

返回 Tree 中唯一顶点的数量。

关键字参数

count_repeat (bool, optional) –

确定是否计算重复顶点。

  • 如果为 False,则仅计算每个唯一顶点一次。

  • 如果为 True,如果顶点出现在不同的路径中,则多次计算。

默认为 False

返回

Tree 中唯一顶点的数量。

返回类型

int

property parent: Tree | None

节点的父节点。

如果节点有父节点并且此对象仍然存在于 python 工作空间中,此属性将返回它。

对于重新分支的树,此属性可能返回一个树堆栈,其中堆栈中的每个索引对应于不同的父节点。

注意

parent 属性的内容将匹配,但身份不匹配:tensorclass 对象是使用相同的 tensor(即,指向相同内存位置的 tensor)重建的。

返回

如果父节点数据超出作用域或节点是根节点,则返回包含父节点数据的 Tree,否则返回 None

plot(backend: str = 'plotly', figure: str = 'tree', info: List[str] = None, make_labels: Callable[[Any, ...], Any] | None = None)[source]

使用指定的后端和图类型绘制树的可视化图。

参数
  • backend – 要使用的绘图后端。目前仅支持 'plotly'。

  • figure – 要绘制的图类型。可以是 'tree' 或 'box'。

  • info – 要包含在图中的附加信息列表(目前未使用)。

  • make_labels – 一个可选函数,用于为绘图生成自定义标签。

引发:

NotImplementedError – 如果指定了不受支持的后端或图类型。

property prev_action: torch.Tensor | TensorDictBase | None

就在生成此节点的观测值之前执行的操作。

返回

如果节点没有父节点,则返回 Tensor、TensorDict 或 None。

另请参阅

只要 rollout 数据包含单个步骤,这将等于 branching_action

另请参阅

所有与树中给定节点(或 Observation)关联的 Action.

rollout_from_path(path: Tuple[int]) TensorDictBase | None[source]

检索沿着树中给定路径的 rollout 数据。

对于路径中的每个节点,rollout 数据沿着最后一个维度(dim=-1)连接。如果沿着路径未找到 rollout 数据,则返回 None

参数

path – 一个表示树中路径的整数元组。

返回

沿着路径连接后的 rollout 数据,如果未找到数据则返回 None。

save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T

将 TensorDict 保存到磁盘。

此函数是 memmap() 的代理。

property selected_actions: torch.Tensor | TensorDictBase | None

返回一个包含从此节点分支出去的所有选定操作的 tensor。

set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)

设置一个新的键值对。

参数
  • key (str, tuple of str) – 要设置的键的名称。如果是字符串元组,则等同于链式调用 getattr 后跟一个最终的 setattr。

  • value (Any) – 要存储在 tensorclass 中的值。

  • inplace (bool, optional) – 如果为 True,set 将尝试原地更新值。如果为 False 或键不存在,值将被简单写入到其目的地。

返回

自身

state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any]

返回一个 state_dict 字典,可用于保存和加载 tensorclass 中的数据。

to_tensordict(*, retain_none: bool | None = None) TensorDict

将 tensorclass 转换为常规的 TensorDict。

复制所有条目。内存映射和共享内存 tensor 被转换为常规 tensor。

参数

retain_none (bool) –

如果为 TrueNone 值将被写入 tensordict 中。否则将被丢弃。默认: True

注意

从 v0.8 开始,默认值将更改为 False

返回

一个包含与 tensorclass 相同值的新 TensorDict 对象。

unbind(dim: int)

返回一个由索引 tensorclass 实例组成的元组,这些实例沿指定维度解除绑定。

结果 tensorclass 实例将共享初始 tensorclass 实例的存储。

valid_paths()[source]

生成树中的所有有效路径。

有效路径是从根节点开始并在叶子节点结束的子节点索引序列。每个路径表示为一个整数元组,其中每个整数对应于一个子节点的索引。

生成:

tuple – 树中的一个有效路径。

vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') Dict[int | Tuple[int], Tree][source]

返回一个包含 Tree 顶点的映射。

关键字参数

key_type (Literal["id", "hash", "path"], optional) –

指定用于顶点的键的类型。

  • "id": 使用顶点 ID 作为键。

  • "hash": 使用顶点的哈希值作为键。

  • "path": 使用到顶点的路径作为键。这可能导致字典的长度比使用 "id" 或 "hash" 时更长,因为同一个节点可能是多条轨迹的一部分。

    默认为 "hash"

(注:原文此处有额外一句 "Defaults to an empty string, which may imply a default behavior.",与前一句矛盾且在此处无上下文,已省略不译。)

返回

一个将键映射到 Tree 顶点的字典。

返回类型

Dict[int | Tuple[int], Tree]

property visits: int | torch.Tensor

返回与此特定节点关联的访问次数。

这是 count 属性的别名。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源