注意
2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2
我们正在重新聚焦 torchdata 仓库,使其成为 torch.utils.data.DataLoader 的迭代增强。我们不计划继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新审视 pytorch/pytorch 中的 DataPipes 引用。在 torchdata==0.8.0 版本(2024 年 7 月)中,它们将被标记为已弃用,在 0.10.0 版本(2024 年末)中,它们将被删除。建议现有用户锁定到 torchdata<=0.9.0 或更旧版本,直到他们能够迁移。后续版本将不包含 DataPipes 或 DataLoaderV2。如果您有建议或意见,请联系我们(请使用 此 issue 进行反馈)
torchdata.nodes
(beta)¶
- class torchdata.nodes.BaseNode(*args, **kwargs)¶
基类:
Iterator
[T
]BaseNode 是在
torchdata.nodes
中创建可组合数据加载 DAG 的基类。大多数最终用户不会直接迭代 BaseNode 实例,而是将其包装在
torchdata.nodes.Loader
中,后者将 DAG 转换为更熟悉的 Iterable。node = MyBaseNodeImpl() loader = Loader(node) # loader supports state_dict() and load_state_dict() for epoch in range(5): for idx, batch in enumerate(loader): ... # or if using node directly: node = MyBaseNodeImpl() for epoch in range(5): node.reset() for idx, batch in enumerate(loader): ...
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()
- next() T ¶
子类必须实现此方法,而不是
__next
。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
- reset(initial_state: Optional[dict] = None)¶
将迭代器重置到开头,或重置为 initial_state 传入的状态。
Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用
super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。
- state_dict() Dict[str, Any] ¶
获取此 BaseNode 的 state_dict。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()。
- class torchdata.nodes.Batcher(source: BaseNode[T], batch_size: int, drop_last: bool = True)¶
基类:
BaseNode
[List
[T
]]Batcher 节点将来自源节点的数据分批处理成 batch_size 大小的批次。如果源节点已耗尽,它将返回批次或引发 StopIteration。如果 drop_last 为 True,则如果最后一个批次小于 batch_size,则会丢弃最后一个批次。如果 drop_last 为 False,则即使最后一个批次小于 batch_size,也会返回最后一个批次。
- 参数:
source (BaseNode[T]) – 要从中批量处理数据的源节点。
batch_size (int) – 批次的大小。
drop_last (bool) – 是否在最后一个批次小于 batch_size 时丢弃最后一个批次。默认为 True。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()
- next() List[T] ¶
子类必须实现此方法,而不是
__next
。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置为 initial_state 传入的状态。
Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用
super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。
- class torchdata.nodes.IterableWrapper(iterable: Iterable[T])¶
基类:
BaseNode
[T
]将任何 Iterable(包括 torch.utils.data.IterableDataset)转换为 BaseNode 的轻薄包装器。
如果 iterable 实现了有状态协议,它将使用其 state_dict/load_state_dict 方法保存和恢复其状态。
- 参数:
iterable (Iterable[T]) – 要转换为 BaseNode 的 Iterable。IterableWrapper 对其调用 iter()。
- 警告:
注意在 Iterable 上定义的 state_dict/load_state_dict 与 Iterator 之间的区别。仅使用 Iterable 的 state_dict/load_state_dict。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()
- next() T ¶
子类必须实现此方法,而不是
__next
。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置为 initial_state 传入的状态。
Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用
super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。
- class torchdata.nodes.Loader(root: BaseNode[T], restart_on_stop_iteration: bool = True)¶
基类:
Generic
[T
]包装根 BaseNode(一个迭代器)并提供有状态的可迭代接口。
最后返回的迭代器的状态由 state_dict() 方法返回,并且可以使用 load_state_dict() 方法加载。
- 参数:
root (BaseNode[T]) – 数据管道的根节点。
restart_on_stop_iteration (bool) – 是否在迭代器到达末尾时重新启动迭代器。默认为 True
- load_state_dict(state_dict: Dict[str, Any])¶
加载一个 state_dict,它将用于初始化从此加载器请求的下一个 iter()。
- 参数:
state_dict (Dict[str, Any]) – 要加载的 state_dict。应从调用 state_dict() 生成。
- state_dict() Dict[str, Any] ¶
返回一个 state_dict,将来可以将其传递给 load_state_dict() 以恢复迭代。
state_dict 将来自最近一次调用 iter() 返回的迭代器。如果尚未创建迭代器,则将创建一个新迭代器并从中返回 state_dict。
- torchdata.nodes.MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Sampler[K]) BaseNode[T] ¶
将任何 MapDataset 转换为 torchdata.node 的轻薄包装器。如果需要并行性,请复制此代码并将 Mapper 替换为 ParallelMapper。
- 参数:
map_dataset (Mapping[K, T]) –
将 map_dataset.__getitem__ 应用于 sampler 的输出。
sampler (Sampler[K]) –
- torchdata.nodes.Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) ParallelMapper[T] ¶
返回一个
ParallelMapper
节点,其 num_workers=0,将在当前进程/线程中执行 map_fn。- 参数:
source (BaseNode[X]) – 要在其上进行映射的源节点。
map_fn (Callable[[X], T]) – 要应用于来自源节点的每个项目的函数。
- class torchdata.nodes.MultiNodeWeightedSampler(source_nodes: Mapping[str, BaseNode[T]], weights: Dict[str, float], stop_criteria: str = 'CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED', rank: Optional[int] = None, world_size: Optional[int] = None, seed: int = 0)¶
基类:
BaseNode
[T
]一个从多个数据集按权重采样的节点。
此节点希望接收源节点字典和权重字典。源节点和权重的键必须相同。权重用于从源节点采样。我们使用 torch.multinomial 从源节点采样,请参阅 https://pytorch.ac.cn/docs/stable/generated/torch.multinomial.html 了解如何使用权重进行采样。seed 用于初始化随机数生成器。
该节点使用以下键实现状态: - DATASET_NODE_STATES_KEY:每个源节点的状态字典。 - DATASETS_EXHAUSTED_KEY:一个布尔字典,指示每个源节点是否已耗尽。 - EPOCH_KEY:用于初始化随机数生成器的 epoch 计数器。 - NUM_YIELDED_KEY:产生的项目数。 - WEIGHTED_SAMPLER_STATE_KEY:加权采样器的状态。
我们支持多种停止标准: - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED:循环遍历源节点,直到所有数据集都耗尽。这是默认行为。 - FIRST_DATASET_EXHAUSTED:当第一个数据集耗尽时停止。 - ALL_DATASETS_EXHAUSTED:当所有数据集都耗尽时停止。
当源节点完全耗尽时,该节点将引发 StopIteration。
- 参数:
source_nodes (Mapping[str, BaseNode[T]]) – 源节点字典。
weights (Dict[str, float]) – 每个源节点的权重字典。
stop_criteria (str) – 停止标准。默认为 CYCLE_UNTIL_ALL_DATASETS_EXHAUST
rank (int) – 当前进程的排名。默认为 None,在这种情况下,将从分布式环境中获取排名。
world_size (int) – 分布式环境的世界大小。默认为 None,在这种情况下,将从分布式环境中获取世界大小。
seed (int) – 随机数生成器的种子。默认为 0。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()
- next() T ¶
子类必须实现此方法,而不是
__next
。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置为 initial_state 传入的状态。
Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用
super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。
- class torchdata.nodes.ParallelMapper(source: BaseNode[X], map_fn: Callable[[X], T], num_workers: int, in_order: bool = True, method: Literal['thread', 'process'] = 'thread', multiprocessing_context: Optional[str] = None, max_concurrent: Optional[int] = None, snapshot_frequency: int = 1)¶
基类:
BaseNode
[T
]ParallelMapper 在 num_workers 线程或进程中并行执行 map_fn。对于进程,multiprocessing_context 可以是 spawn、forkserver、fork 或 None(选择操作系统默认值)。最多 max_concurrent 项将被处理或在迭代器的输出队列中,以限制 CPU 和内存利用率。如果为 None(默认值),则该值将为 2 * num_workers。
最多从源创建一个 iter(),并且最多一个线程将同时在其上调用 next()。
如果 in_order 为 true,则迭代器将按照项目从源迭代器到达的顺序返回项目,即使其他项目可用也可能阻塞。
- 参数:
source (BaseNode[X]) – 要在其上进行映射的源节点。
map_fn (Callable[[X], T]) – 要应用于来自源节点的每个项目的函数。
num_workers (int) – 用于并行处理的工作进程数。
in_order (bool) – 是否按照项目到达的顺序返回项目。默认为 True。
method (Literal["thread", "process"]) – 用于并行处理的方法。默认为 “thread”。
multiprocessing_context (Optional[str]) – 用于并行处理的多进程上下文。默认为 None。
max_concurrent (Optional[int]) – 一次处理的最大项目数。默认为 None。
snapshot_frequency (int) – 对源节点状态进行快照的频率。默认为 1。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()
- next()¶
子类必须实现此方法,而不是
__next
。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置为 initial_state 传入的状态。
Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用
super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。
- class torchdata.nodes.PinMemory(source: BaseNode[T], pin_memory_device: str = '', snapshot_frequency: int = 1)¶
基类:
BaseNode
[T
]将底层节点的数据固定到设备上。这由 torch.utils.data._utils.pin_memory._pin_memory_loop 支持。
- 参数:
source (BaseNode[T]) – 要从中固定数据的源节点。
pin_memory_device (str) – 要将数据固定到的设备。默认值为 “”。
snapshot_frequency (int) – 对源节点状态进行快照的频率。默认值为 1,这意味着源节点的状态将在每个项目后进行快照。如果设置为更高的值,则源节点的状态将在每 snapshot_frequency 个项目后进行快照。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()
- next()¶
子类必须实现此方法,而不是
__next
。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置为 initial_state 传入的状态。
Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用
super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。
- class torchdata.nodes.Prefetcher(source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1)¶
基类:
BaseNode
[T
]从源节点预取数据并将其存储在队列中。
- 参数:
source (BaseNode[T]) – 要从中预取数据的源节点。
prefetch_factor (int) – 提前预取的项目数。
snapshot_frequency (int) – 对源节点状态进行快照的频率。默认值为 1,这意味着源节点的状态将在每个项目后进行快照。如果设置为更高的值,则源节点的状态将在每 snapshot_frequency 个项目后进行快照。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()
- next()¶
子类必须实现此方法,而不是
__next
。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置为 initial_state 传入的状态。
Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用
super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。
- class torchdata.nodes.SamplerWrapper(sampler: Sampler[T], initial_epoch: int = 0, epoch_updater: Optional[Callable[[int], int]] = None)¶
基类:
BaseNode
[T
]将采样器转换为 BaseNode。这与 IterableWrapper 几乎相同,除了它包含一个钩子来在采样器上调用 set_epoch(如果它支持)。
- 参数:
sampler (Sampler) – 要包装的采样器。
initial_epoch (int) – 要在采样器上设置的初始 epoch
epoch_updater (Optional[Callable[[int], int]] = None) – 在新的迭代开始时更新 epoch 的回调。它在每个迭代器请求的开始时被调用,除了第一个。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是 state_dict()。应该只由 BaseNode 调用。 :return: Dict[str, Any] - 一个状态字典,可以在将来的某个时候传递给 reset()
- next() T ¶
子类必须实现此方法,而不是
__next
。应该只由 BaseNode 调用。 :return: T - 序列中的下一个值,或抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置为 initial_state 传入的状态。
Reset 是放置昂贵初始化的好地方,因为它将在调用 next() 或 state_dict() 时延迟调用。子类必须调用
super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的 state dict。如果为 None,则重置为开头。
- class torchdata.nodes.Stateful(*args, **kwargs)¶
基类:
Protocol
用于实现
state_dict()
和load_state_dict(state_dict: Dict[str, Any])
的对象的协议
- class torchdata.nodes.StopCriteria¶
基类:
object
数据集采样器的停止标准。
CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: 一旦最后一个未见数据集耗尽就停止。所有数据集至少被看到一次。在某些情况下,当仍然存在未耗尽的数据集时,某些数据集可能会被看到多次。
ALL_DATASETS_EXHAUSTED: 一旦所有数据集都耗尽就停止。每个数据集只被看到一次。不会执行环绕或重启。
FIRST_DATASET_EXHAUSTED: 当第一个数据集耗尽时停止。