• 文档 >
  • torchdata.nodes (beta)
快捷方式

注意

2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2

我们正在重新聚焦 torchdata 仓库,使其成为 torch.utils.data.DataLoader 的迭代增强。我们不计划继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新审视 pytorch/pytorch 中的 DataPipes 引用。在 torchdata==0.8.0 版本(2024 年 7 月)中,它们将被标记为已弃用,在 0.10.0 版本(2024 年末)中,它们将被删除。建议现有用户锁定到 torchdata<=0.9.0 或更旧版本,直到他们能够迁移。后续版本将不包含 DataPipes 或 DataLoaderV2。如果您有建议或意见,请联系我们(请使用 此 issue 进行反馈)

torchdata.nodes (beta)

class torchdata.nodes.BaseNode(*args, **kwargs)

基类: Iterator[T]

BaseNode 是在 torchdata.nodes 中创建可组合数据加载 DAG 的基类。

大多数最终用户不会直接迭代 BaseNode 实例,而是将其包装在 torchdata.nodes.Loader 中,后者将 DAG 转换为更熟悉的 Iterable。

node = MyBaseNodeImpl()
loader = Loader(node)
# loader supports state_dict() and load_state_dict()

for epoch in range(5):
    for idx, batch in enumerate(loader):
        ...

# or if using node directly:
node = MyBaseNodeImpl()
for epoch in range(5):
    node.reset()
    for idx, batch in enumerate(loader):
        ...
get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()

next() T

子类必须实现此方法,而不是 __next。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration

reset(initial_state: Optional[dict] = None)

将迭代器重置到开头,或重置为 initial_state 传入的状态。

Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。

state_dict() Dict[str, Any]

获取此 BaseNode 的 state_dict。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()。

class torchdata.nodes.Batcher(source: BaseNode[T], batch_size: int, drop_last: bool = True)

基类: BaseNode[List[T]]

Batcher 节点将来自源节点的数据分批处理成 batch_size 大小的批次。如果源节点已耗尽,它将返回批次或引发 StopIteration。如果 drop_last 为 True,则如果最后一个批次小于 batch_size,则会丢弃最后一个批次。如果 drop_last 为 False,则即使最后一个批次小于 batch_size,也会返回最后一个批次。

参数:
  • source (BaseNode[T]) – 要从中批量处理数据的源节点。

  • batch_size (int) – 批次的大小。

  • drop_last (bool) – 是否在最后一个批次小于 batch_size 时丢弃最后一个批次。默认为 True。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()

next() List[T]

子类必须实现此方法,而不是 __next。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置为 initial_state 传入的状态。

Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。

class torchdata.nodes.IterableWrapper(iterable: Iterable[T])

基类: BaseNode[T]

将任何 Iterable(包括 torch.utils.data.IterableDataset)转换为 BaseNode 的轻薄包装器。

如果 iterable 实现了有状态协议,它将使用其 state_dict/load_state_dict 方法保存和恢复其状态。

参数:

iterable (Iterable[T]) – 要转换为 BaseNode 的 Iterable。IterableWrapper 对其调用 iter()。

警告:

注意在 Iterable 上定义的 state_dict/load_state_dict 与 Iterator 之间的区别。仅使用 Iterable 的 state_dict/load_state_dict。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()

next() T

子类必须实现此方法,而不是 __next。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置为 initial_state 传入的状态。

Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。

class torchdata.nodes.Loader(root: BaseNode[T], restart_on_stop_iteration: bool = True)

基类: Generic[T]

包装根 BaseNode(一个迭代器)并提供有状态的可迭代接口。

最后返回的迭代器的状态由 state_dict() 方法返回,并且可以使用 load_state_dict() 方法加载。

参数:
  • root (BaseNode[T]) – 数据管道的根节点。

  • restart_on_stop_iteration (bool) – 是否在迭代器到达末尾时重新启动迭代器。默认为 True

load_state_dict(state_dict: Dict[str, Any])

加载一个 state_dict,它将用于初始化从此加载器请求的下一个 iter()。

参数:

state_dict (Dict[str, Any]) – 要加载的 state_dict。应从调用 state_dict() 生成。

state_dict() Dict[str, Any]

返回一个 state_dict,将来可以将其传递给 load_state_dict() 以恢复迭代。

state_dict 将来自最近一次调用 iter() 返回的迭代器。如果尚未创建迭代器,则将创建一个新迭代器并从中返回 state_dict。

torchdata.nodes.MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Sampler[K]) BaseNode[T]

将任何 MapDataset 转换为 torchdata.node 的轻薄包装器。如果需要并行性,请复制此代码并将 Mapper 替换为 ParallelMapper。

参数:
  • map_dataset (Mapping[K, T]) –

    • 将 map_dataset.__getitem__ 应用于 sampler 的输出。

  • sampler (Sampler[K]) –

torchdata.nodes.Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) ParallelMapper[T]

返回一个 ParallelMapper 节点,其 num_workers=0,将在当前进程/线程中执行 map_fn。

参数:
  • source (BaseNode[X]) – 要在其上进行映射的源节点。

  • map_fn (Callable[[X], T]) – 要应用于来自源节点的每个项目的函数。

class torchdata.nodes.MultiNodeWeightedSampler(source_nodes: Mapping[str, BaseNode[T]], weights: Dict[str, float], stop_criteria: str = 'CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED', rank: Optional[int] = None, world_size: Optional[int] = None, seed: int = 0)

基类: BaseNode[T]

一个从多个数据集按权重采样的节点。

此节点希望接收源节点字典和权重字典。源节点和权重的键必须相同。权重用于从源节点采样。我们使用 torch.multinomial 从源节点采样,请参阅 https://pytorch.ac.cn/docs/stable/generated/torch.multinomial.html 了解如何使用权重进行采样。seed 用于初始化随机数生成器。

该节点使用以下键实现状态: - DATASET_NODE_STATES_KEY:每个源节点的状态字典。 - DATASETS_EXHAUSTED_KEY:一个布尔字典,指示每个源节点是否已耗尽。 - EPOCH_KEY:用于初始化随机数生成器的 epoch 计数器。 - NUM_YIELDED_KEY:产生的项目数。 - WEIGHTED_SAMPLER_STATE_KEY:加权采样器的状态。

我们支持多种停止标准: - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED:循环遍历源节点,直到所有数据集都耗尽。这是默认行为。 - FIRST_DATASET_EXHAUSTED:当第一个数据集耗尽时停止。 - ALL_DATASETS_EXHAUSTED:当所有数据集都耗尽时停止。

当源节点完全耗尽时,该节点将引发 StopIteration。

参数:
  • source_nodes (Mapping[str, BaseNode[T]]) – 源节点字典。

  • weights (Dict[str, float]) – 每个源节点的权重字典。

  • stop_criteria (str) – 停止标准。默认为 CYCLE_UNTIL_ALL_DATASETS_EXHAUST

  • rank (int) – 当前进程的排名。默认为 None,在这种情况下,将从分布式环境中获取排名。

  • world_size (int) – 分布式环境的世界大小。默认为 None,在这种情况下,将从分布式环境中获取世界大小。

  • seed (int) – 随机数生成器的种子。默认为 0。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()

next() T

子类必须实现此方法,而不是 __next。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置为 initial_state 传入的状态。

Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。

class torchdata.nodes.ParallelMapper(source: BaseNode[X], map_fn: Callable[[X], T], num_workers: int, in_order: bool = True, method: Literal['thread', 'process'] = 'thread', multiprocessing_context: Optional[str] = None, max_concurrent: Optional[int] = None, snapshot_frequency: int = 1)

基类: BaseNode[T]

ParallelMapper 在 num_workers 线程或进程中并行执行 map_fn。对于进程,multiprocessing_context 可以是 spawn、forkserver、fork 或 None(选择操作系统默认值)。最多 max_concurrent 项将被处理或在迭代器的输出队列中,以限制 CPU 和内存利用率。如果为 None(默认值),则该值将为 2 * num_workers。

最多从源创建一个 iter(),并且最多一个线程将同时在其上调用 next()。

如果 in_order 为 true,则迭代器将按照项目从源迭代器到达的顺序返回项目,即使其他项目可用也可能阻塞。

参数:
  • source (BaseNode[X]) – 要在其上进行映射的源节点。

  • map_fn (Callable[[X], T]) – 要应用于来自源节点的每个项目的函数。

  • num_workers (int) – 用于并行处理的工作进程数。

  • in_order (bool) – 是否按照项目到达的顺序返回项目。默认为 True。

  • method (Literal["thread", "process"]) – 用于并行处理的方法。默认为 “thread”。

  • multiprocessing_context (Optional[str]) – 用于并行处理的多进程上下文。默认为 None。

  • max_concurrent (Optional[int]) – 一次处理的最大项目数。默认为 None。

  • snapshot_frequency (int) – 对源节点状态进行快照的频率。默认为 1。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()

next()

子类必须实现此方法,而不是 __next。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置为 initial_state 传入的状态。

Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。

class torchdata.nodes.PinMemory(source: BaseNode[T], pin_memory_device: str = '', snapshot_frequency: int = 1)

基类: BaseNode[T]

将底层节点的数据固定到设备上。这由 torch.utils.data._utils.pin_memory._pin_memory_loop 支持。

参数:
  • source (BaseNode[T]) – 要从中固定数据的源节点。

  • pin_memory_device (str) – 要将数据固定到的设备。默认值为 “”。

  • snapshot_frequency (int) – 对源节点状态进行快照的频率。默认值为 1,这意味着源节点的状态将在每个项目后进行快照。如果设置为更高的值,则源节点的状态将在每 snapshot_frequency 个项目后进行快照。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()

next()

子类必须实现此方法,而不是 __next。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置为 initial_state 传入的状态。

Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。

class torchdata.nodes.Prefetcher(source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1)

基类: BaseNode[T]

从源节点预取数据并将其存储在队列中。

参数:
  • source (BaseNode[T]) – 要从中预取数据的源节点。

  • prefetch_factor (int) – 提前预取的项目数。

  • snapshot_frequency (int) – 对源节点状态进行快照的频率。默认值为 1,这意味着源节点的状态将在每个项目后进行快照。如果设置为更高的值,则源节点的状态将在每 snapshot_frequency 个项目后进行快照。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()

next()

子类必须实现此方法,而不是 __next。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置为 initial_state 传入的状态。

Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。

class torchdata.nodes.SamplerWrapper(sampler: Sampler[T], initial_epoch: int = 0, epoch_updater: Optional[Callable[[int], int]] = None)

基类: BaseNode[T]

将采样器转换为 BaseNode。这与 IterableWrapper 几乎相同,除了它包含一个钩子来在采样器上调用 set_epoch(如果它支持)。

参数:
  • sampler (Sampler) – 要包装的采样器。

  • initial_epoch (int) – 要在采样器上设置的初始 epoch

  • epoch_updater (Optional[Callable[[int], int]] = None) – 在新的迭代开始时更新 epoch 的回调。它在每个迭代器请求的开始时被调用,除了第一个。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()

next() T

子类必须实现此方法,而不是 __next。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置为 initial_state 传入的状态。

Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。

class torchdata.nodes.Stateful(*args, **kwargs)

基类: Protocol

用于实现 state_dict()load_state_dict(state_dict: Dict[str, Any]) 的对象的协议

class torchdata.nodes.StopCriteria

基类: object

数据集采样器的停止标准。

  1. CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: 一旦最后一个未见数据集耗尽就停止。所有数据集至少被看到一次。在某些情况下,当仍然存在未耗尽的数据集时,某些数据集可能会被看到多次。

  2. ALL_DATASETS_EXHAUSTED: 一旦所有数据集都耗尽就停止。每个数据集只被看到一次。不会执行环绕或重启。

  3. FIRST_DATASET_EXHAUSTED: 当第一个数据集耗尽时停止。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源