• 文档 >
  • torchdata.nodes (beta)
快捷方式

注意

2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2

我们将 torchdata 仓库的重点重新聚焦到 torch.utils.data.DataLoader 的迭代增强上。我们不打算继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新审视 pytorch/pytorch 中对 DataPipes 的引用。在 torchdata==0.8.0 版本 (2024 年 7 月) 中,它们将被标记为已弃用,并在 0.10.0 版本 (2024 年底) 中删除。建议现有用户将版本固定在 torchdata<=0.9.0 或更旧的版本,直到他们能够迁移。后续版本将不包含 DataPipes 或 DataLoader V2。如果您有建议或意见,请随时联系我们(请使用此议题进行反馈)

torchdata.nodes (beta)

class torchdata.nodes.BaseNode(*args, **kwargs)

基类: Iterator[T]

BaseNodes 是在 torchdata.nodes 中创建可组合数据加载 DAG 的基类。

大多数最终用户不会直接遍历 BaseNode 实例,而是将其包装在 torchdata.nodes.Loader 中,后者将 DAG 转换为更熟悉的 Iterable。

node = MyBaseNodeImpl()
loader = Loader(node)
# loader supports state_dict() and load_state_dict()

for epoch in range(5):
    for idx, batch in enumerate(loader):
        ...

# or if using node directly:
node = MyBaseNodeImpl()
for epoch in range(5):
    node.reset()
    for idx, batch in enumerate(loader):
        ...
get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。此方法应仅由 BaseNode 调用。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典

next() T

子类必须实现此方法,而不是 __next__。此方法应仅由 BaseNode 调用。

返回:

T - 序列中的下一个值,如果序列结束则抛出 StopIteration

reset(initial_state: Optional[dict] = None)

将迭代器重置到开头,或重置到 initial_state 中传递的状态。

reset() 是一个放置昂贵初始化代码的好地方,因为它会在调用 next()state_dict() 时被延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。

state_dict() Dict[str, Any]

获取此 BaseNode 的状态字典。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典。

class torchdata.nodes.Batcher(source: BaseNode[T], batch_size: int, drop_last: bool = True)

基类: BaseNode[List[T]]

Batcher 节点将源节点的数据按 batch_size 大小打包成批次。如果源节点耗尽,它将返回当前批次或抛出 StopIteration。如果 drop_last 为 True,则如果最后一个批次小于 batch_size,它将被丢弃。如果 drop_last 为 False,即使最后一个批次小于 batch_size,它也将被返回。

参数:
  • source (BaseNode[T]) – 用于从中批量获取数据的源节点。

  • batch_size (int) – 批次大小。

  • drop_last (bool) – 如果最后一个批次小于 batch_size,是否丢弃它。默认为 True。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。此方法应仅由 BaseNode 调用。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典

next() List[T]

子类必须实现此方法,而不是 __next__。此方法应仅由 BaseNode 调用。

返回:

T - 序列中的下一个值,如果序列结束则抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置到 initial_state 中传递的状态。

reset() 是一个放置昂贵初始化代码的好地方,因为它会在调用 next()state_dict() 时被延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。

class torchdata.nodes.IterableWrapper(iterable: Iterable[T])

基类: BaseNode[T]

一个将任何 Iterable (包括 torch.utils.data.IterableDataset) 转换为 BaseNode 的轻量级包装器。

如果 iterable 实现了 Stateful 协议,它将通过其 state_dict/load_state_dict 方法进行保存和恢复。

参数:

iterable (*Iterable*[*T*]) – 要转换为 BaseNode 的 Iterable。IterableWrapper 会对其调用 iter() 方法。

警告:

请注意定义在 Iterable 上的 state_dict/load_state_dict 与定义在 Iterator 上的区别。此处仅使用 Iterable 的 state_dict/load_state_dict。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。此方法应仅由 BaseNode 调用。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典

next() T

子类必须实现此方法,而不是 __next__。此方法应仅由 BaseNode 调用。

返回:

T - 序列中的下一个值,如果序列结束则抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置到 initial_state 中传递的状态。

reset() 是一个放置昂贵初始化代码的好地方,因为它会在调用 next()state_dict() 时被延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。

class torchdata.nodes.Loader(root: BaseNode[T], restart_on_stop_iteration: bool = True)

基类: Generic[T]

包装根 BaseNode (一个迭代器),并提供一个有状态的 Iterable 接口。

最后返回的迭代器的状态由 state_dict() 方法返回,并可以使用 load_state_dict() 方法加载。

参数:
  • root (BaseNode[T]) – 数据管道的根节点。

  • restart_on_stop_iteration (*bool*) – 当迭代器到达末尾时是否重新启动。默认为 True

load_state_dict(state_dict: Dict[str, Any])

加载一个 state_dict,它将用于初始化从此加载器请求的下一个 iter()。

参数:

state_dict (*Dict*[*str*, *Any*]) – 要加载的 state_dict。应由 state_dict() 调用生成。

state_dict() Dict[str, Any]

返回一个 state_dict,将来可以将其传递给 load_state_dict() 以恢复迭代。

state_dict 将来自最近一次调用 iter() 返回的迭代器。如果尚未创建迭代器,将创建一个新的迭代器并从中返回 state_dict。

torchdata.nodes.MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Sampler[K]) BaseNode[T]

一个将任何 MapDataset 转换为 torchdata.node 的轻量级包装器。如果您需要并行化,请复制此代码并将 Mapper 替换为 ParallelMapper。

参数:
  • map_dataset (*Mapping*[*K*, *T*]) –

    • 对 sampler 的输出应用 map_dataset.__getitem__ 方法。

  • sampler (*Sampler*[*K*]) –

torchdata.nodes.Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) ParallelMapper[T]

返回一个 [`ParallelMapper`](#torchdata.nodes.ParallelMapper) 节点,其 num_workers=0,这将在当前进程/线程中执行 map_fn。

参数:
  • source (BaseNode[X]) – 要映射的源节点。

  • map_fn (*Callable*[[*X*], *T*]) – 应用于源节点中每个项目的函数。

class torchdata.nodes.MultiNodeWeightedSampler(source_nodes: Mapping[str, BaseNode[T]], weights: Dict[str, float], stop_criteria: str = 'CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED', rank: Optional[int] = None, world_size: Optional[int] = None, seed: int = 0)

基类: BaseNode[T]

一个按权重从多个数据集中采样数据的节点。

此节点需要接收一个源节点字典和一个权重字典。源节点和权重的键必须相同。权重用于从源节点中采样。我们使用 torch.multinomial 从源节点中采样,关于如何使用权重进行采样,请参阅 https://pytorch.ac.cn/docs/stable/generated/torch.multinomial.htmlseed 用于初始化随机数生成器。

此节点使用以下键来实现状态

  • DATASET_NODE_STATES_KEY: 每个源节点状态的字典。

  • DATASETS_EXHAUSTED_KEY: 一个布尔值字典,指示每个源节点是否已耗尽。

  • EPOCH_KEY: 一个用于初始化随机数生成器的 epoch 计数器。

  • NUM_YIELDED_KEY: 已产生的项目数量。

  • WEIGHTED_SAMPLER_STATE_KEY: 加权采样器的状态。

我们支持多种停止条件

  • CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: 循环遍历源节点直到所有数据集都耗尽。这是默认行为。

  • FIRST_DATASET_EXHAUSTED: 当第一个数据集耗尽时停止。

  • ALL_DATASETS_EXHAUSTED: 当所有数据集都耗尽时停止。

源节点完全耗尽时,节点将抛出 StopIteration。

参数:
  • source_nodes (*Mapping*[*str*, BaseNode[*T*]]) – 源节点的字典。

  • weights (*Dict*[*str*, *float*]) – 每个源节点的权重字典。

  • stop_criteria (*str*) – 停止条件。默认为 CYCLE_UNTIL_ALL_DATASETS_EXHAUST

  • rank (*int*) – 当前进程的排名。默认为 None,此时将从分布式环境中获取排名。

  • world_size (*int*) – 分布式环境的世界大小。默认为 None,此时将从分布式环境中获取世界大小。

  • seed (*int*) – 随机数生成器的种子。默认为 0。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。此方法应仅由 BaseNode 调用。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典

next() T

子类必须实现此方法,而不是 __next__。此方法应仅由 BaseNode 调用。

返回:

T - 序列中的下一个值,如果序列结束则抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置到 initial_state 中传递的状态。

reset() 是一个放置昂贵初始化代码的好地方,因为它会在调用 next()state_dict() 时被延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。

class torchdata.nodes.ParallelMapper(source: BaseNode[X], map_fn: Callable[[X], T], num_workers: int, in_order: bool = True, method: Literal['thread', 'process'] = 'thread', multiprocessing_context: Optional[str] = None, max_concurrent: Optional[int] = None, snapshot_frequency: int = 1, prebatch: Optional[int] = None)

基类: BaseNode[T]

ParallelMapper 在 num_workers 数量的线程或进程中并行执行 map_fn。对于进程方式,multiprocessing_context 可以是 spawn、forkserver、fork 或 None(选择操作系统默认)。最多 max_concurrent 个项目将被处理或放入迭代器的输出队列中,以限制 CPU 和内存利用率。如果为 None(默认),该值将为 2 * num_workers。

source 只创建一个 iter(),并且最多只有一个线程同时调用其 next()。

如果 in_order 为 True,迭代器将按照项目从 source 迭代器到达的顺序返回它们,即使有其他可用项目,也可能会阻塞。

参数:
  • source (BaseNode[X]) – 要映射的源节点。

  • map_fn (*Callable*[[*X*], *T*]) – 应用于源节点中每个项目的函数。

  • num_workers (int) – 用于并行处理的 worker 数量。

  • in_order (bool) – 是否按照项目从源到达的顺序返回。默认为 True。

  • method (Literal["thread", "process"]) – 用于并行处理的方法。默认为“thread”。

  • multiprocessing_context (Optional[str]) – 用于并行处理的多进程上下文。默认为 None。

  • max_concurrent (Optional[int]) – 同时处理的最大项目数。默认为 None。

  • snapshot_frequency (int) – 对源节点状态进行快照的频率。默认为 1。

  • prebatch (Optional[int]) – 可选地在映射之前对源中的项目执行预批量处理。对于小项目,这可能会以增加峰值内存为代价来提高吞吐量。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。此方法应仅由 BaseNode 调用。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典

next() T

子类必须实现此方法,而不是 __next__。此方法应仅由 BaseNode 调用。

返回:

T - 序列中的下一个值,如果序列结束则抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置到 initial_state 中传递的状态。

reset() 是一个放置昂贵初始化代码的好地方,因为它会在调用 next()state_dict() 时被延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。

class torchdata.nodes.PinMemory(source: BaseNode[T], pin_memory_device: str = '', snapshot_frequency: int = 1)

基类: BaseNode[T]

将底层节点的数据固定到设备上。这由 torch.utils.data._utils.pin_memory._pin_memory_loop 支持。

参数:
  • source (BaseNode[T]) – 要固定数据的源节点。

  • pin_memory_device (str) – 固定数据到的设备。默认为 ""。

  • snapshot_frequency (int) – 对源节点状态进行快照的频率。默认为 1,这意味着每处理一个项目后,都会对源节点的状态进行快照。如果设置为更高的值,则每处理 snapshot_frequency 个项目后,才会对源节点的状态进行快照。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。此方法应仅由 BaseNode 调用。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典

next()

子类必须实现此方法,而不是 __next__。此方法应仅由 BaseNode 调用。

返回:

T - 序列中的下一个值,如果序列结束则抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置到 initial_state 中传递的状态。

reset() 是一个放置昂贵初始化代码的好地方,因为它会在调用 next()state_dict() 时被延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。

class torchdata.nodes.Prefetcher(source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1)

基类: BaseNode[T]

从源节点预取数据并存储在队列中。

参数:
  • source (BaseNode[T]) – 从其预取数据的源节点。

  • prefetch_factor (int) – 提前预取的项目数量。

  • snapshot_frequency (int) – 对源节点状态进行快照的频率。默认为 1,这意味着每处理一个项目后,都会对源节点的状态进行快照。如果设置为更高的值,则每处理 snapshot_frequency 个项目后,才会对源节点的状态进行快照。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。此方法应仅由 BaseNode 调用。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典

next()

子类必须实现此方法,而不是 __next__。此方法应仅由 BaseNode 调用。

返回:

T - 序列中的下一个值,如果序列结束则抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置到 initial_state 中传递的状态。

reset() 是一个放置昂贵初始化代码的好地方,因为它会在调用 next()state_dict() 时被延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。

class torchdata.nodes.SamplerWrapper(sampler: Sampler[T], initial_epoch: int = 0, epoch_updater: Optional[Callable[[int], int]] = None)

基类: BaseNode[T]

将采样器 (sampler) 转换为 BaseNode。这与 IterableWrapper 几乎相同,但它包含一个钩子,用于在采样器支持时调用其 set_epoch 方法。

参数:
  • sampler (Sampler) – 要封装的采样器。

  • initial_epoch (int) – 要在采样器上设置的初始轮次。

  • epoch_updater (Optional[Callable[[int], int]] = None) – 在新迭代开始时更新轮次的回调函数。除了第一次迭代请求外,每次迭代请求开始时都会调用它。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。此方法应仅由 BaseNode 调用。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典

next() T

子类必须实现此方法,而不是 __next__。此方法应仅由 BaseNode 调用。

返回:

T - 序列中的下一个值,如果序列结束则抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置到 initial_state 中传递的状态。

reset() 是一个放置昂贵初始化代码的好地方,因为它会在调用 next()state_dict() 时被延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。

class torchdata.nodes.Stateful(*args, **kwargs)

基类: Protocol

为实现了 state_dict()load_state_dict(state_dict: Dict[str, Any]) 方法的对象定义的协议。

class torchdata.nodes.StopCriteria

基类: object

数据集采样器的停止标准。

  1. CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: 当最后一个未曾见过的数据集耗尽时停止。所有数据集至少被见过一次。在某些情况下,当仍有未耗尽的数据集时,一些数据集可能会被见过不止一次。

  2. ALL_DATASETS_EXHAUSTED: 当所有数据集都耗尽时停止。每个数据集只被见过一次。不会执行循环或重启。

  3. FIRST_DATASET_EXHAUSTED: 当第一个数据集耗尽时停止。

  4. CYCLE_FOREVER: 通过重新初始化每个耗尽的源节点来循环遍历数据集。当训练器需要控制步数而不是轮次时,这很有用。

class torchdata.nodes.Unbatcher(source: BaseNode[Sequence[T]])

基类: BaseNode[T]

Unbatcher 将从 source 拉取的批次展平,并在调用其 next() 时按顺序生成元素。

参数:

source (BaseNode[T]) – 从其拉取批次的源节点。

get_state() Dict[str, Any]

子类必须实现此方法,而不是 state_dict()。此方法应仅由 BaseNode 调用。

返回:

Dict[str, Any] - 可在未来某个时间点传递给 reset() 的状态字典

next() T

子类必须实现此方法,而不是 __next__。此方法应仅由 BaseNode 调用。

返回:

T - 序列中的下一个值,如果序列结束则抛出 StopIteration

reset(initial_state: Optional[Dict[str, Any]] = None)

将迭代器重置到开头,或重置到 initial_state 中传递的状态。

reset() 是一个放置昂贵初始化代码的好地方,因为它会在调用 next()state_dict() 时被延迟调用。子类必须调用 super().reset(initial_state)

参数:

initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。

文档

查阅 PyTorch 完整的开发者文档

查看文档

教程

获取适合初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源