注意
2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2
我们将 torchdata 仓库的重点重新聚焦到 torch.utils.data.DataLoader 的迭代增强上。我们不打算继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新审视 pytorch/pytorch 中对 DataPipes 的引用。在 torchdata==0.8.0 版本 (2024 年 7 月) 中,它们将被标记为已弃用,并在 0.10.0 版本 (2024 年底) 中删除。建议现有用户将版本固定在 torchdata<=0.9.0 或更旧的版本,直到他们能够迁移。后续版本将不包含 DataPipes 或 DataLoader V2。如果您有建议或意见,请随时联系我们(请使用此议题进行反馈)
torchdata.nodes
(beta)¶
- class torchdata.nodes.BaseNode(*args, **kwargs)¶
基类:
Iterator
[T
]BaseNodes 是在
torchdata.nodes
中创建可组合数据加载 DAG 的基类。大多数最终用户不会直接遍历 BaseNode 实例,而是将其包装在
torchdata.nodes.Loader
中,后者将 DAG 转换为更熟悉的 Iterable。node = MyBaseNodeImpl() loader = Loader(node) # loader supports state_dict() and load_state_dict() for epoch in range(5): for idx, batch in enumerate(loader): ... # or if using node directly: node = MyBaseNodeImpl() for epoch in range(5): node.reset() for idx, batch in enumerate(loader): ...
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是
state_dict()
。此方法应仅由 BaseNode 调用。- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典
- next() T ¶
子类必须实现此方法,而不是
__next__
。此方法应仅由 BaseNode 调用。- 返回:
T - 序列中的下一个值,如果序列结束则抛出 StopIteration
- reset(initial_state: Optional[dict] = None)¶
将迭代器重置到开头,或重置到 initial_state 中传递的状态。
reset()
是一个放置昂贵初始化代码的好地方,因为它会在调用next()
或state_dict()
时被延迟调用。子类必须调用super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。
- state_dict() Dict[str, Any] ¶
获取此 BaseNode 的状态字典。
- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典。
- class torchdata.nodes.Batcher(source: BaseNode[T], batch_size: int, drop_last: bool = True)¶
基类:
BaseNode
[List
[T
]]Batcher 节点将源节点的数据按 batch_size 大小打包成批次。如果源节点耗尽,它将返回当前批次或抛出 StopIteration。如果 drop_last 为 True,则如果最后一个批次小于 batch_size,它将被丢弃。如果 drop_last 为 False,即使最后一个批次小于 batch_size,它也将被返回。
- 参数:
source (BaseNode[T]) – 用于从中批量获取数据的源节点。
batch_size (int) – 批次大小。
drop_last (bool) – 如果最后一个批次小于 batch_size,是否丢弃它。默认为 True。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是
state_dict()
。此方法应仅由 BaseNode 调用。- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典
- next() List[T] ¶
子类必须实现此方法,而不是
__next__
。此方法应仅由 BaseNode 调用。- 返回:
T - 序列中的下一个值,如果序列结束则抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置到 initial_state 中传递的状态。
reset()
是一个放置昂贵初始化代码的好地方,因为它会在调用next()
或state_dict()
时被延迟调用。子类必须调用super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。
- class torchdata.nodes.IterableWrapper(iterable: Iterable[T])¶
基类:
BaseNode
[T
]一个将任何 Iterable (包括 torch.utils.data.IterableDataset) 转换为 BaseNode 的轻量级包装器。
如果 iterable 实现了 Stateful 协议,它将通过其 state_dict/load_state_dict 方法进行保存和恢复。
- 参数:
iterable (*Iterable*[*T*]) – 要转换为 BaseNode 的 Iterable。IterableWrapper 会对其调用 iter() 方法。
- 警告:
请注意定义在 Iterable 上的 state_dict/load_state_dict 与定义在 Iterator 上的区别。此处仅使用 Iterable 的 state_dict/load_state_dict。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是
state_dict()
。此方法应仅由 BaseNode 调用。- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典
- next() T ¶
子类必须实现此方法,而不是
__next__
。此方法应仅由 BaseNode 调用。- 返回:
T - 序列中的下一个值,如果序列结束则抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置到 initial_state 中传递的状态。
reset()
是一个放置昂贵初始化代码的好地方,因为它会在调用next()
或state_dict()
时被延迟调用。子类必须调用super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。
- class torchdata.nodes.Loader(root: BaseNode[T], restart_on_stop_iteration: bool = True)¶
基类:
Generic
[T
]包装根 BaseNode (一个迭代器),并提供一个有状态的 Iterable 接口。
最后返回的迭代器的状态由 state_dict() 方法返回,并可以使用 load_state_dict() 方法加载。
- 参数:
root (BaseNode[T]) – 数据管道的根节点。
restart_on_stop_iteration (*bool*) – 当迭代器到达末尾时是否重新启动。默认为 True
- load_state_dict(state_dict: Dict[str, Any])¶
加载一个 state_dict,它将用于初始化从此加载器请求的下一个 iter()。
- 参数:
state_dict (*Dict*[*str*, *Any*]) – 要加载的 state_dict。应由 state_dict() 调用生成。
- state_dict() Dict[str, Any] ¶
返回一个 state_dict,将来可以将其传递给 load_state_dict() 以恢复迭代。
state_dict 将来自最近一次调用 iter() 返回的迭代器。如果尚未创建迭代器,将创建一个新的迭代器并从中返回 state_dict。
- torchdata.nodes.MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Sampler[K]) BaseNode[T] ¶
一个将任何 MapDataset 转换为 torchdata.node 的轻量级包装器。如果您需要并行化,请复制此代码并将 Mapper 替换为 ParallelMapper。
- 参数:
map_dataset (*Mapping*[*K*, *T*]) –
对 sampler 的输出应用 map_dataset.__getitem__ 方法。
sampler (*Sampler*[*K*]) –
- torchdata.nodes.Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) ParallelMapper[T] ¶
返回一个 [`ParallelMapper`](#torchdata.nodes.ParallelMapper) 节点,其 num_workers=0,这将在当前进程/线程中执行 map_fn。
- 参数:
source (BaseNode[X]) – 要映射的源节点。
map_fn (*Callable*[[*X*], *T*]) – 应用于源节点中每个项目的函数。
- class torchdata.nodes.MultiNodeWeightedSampler(source_nodes: Mapping[str, BaseNode[T]], weights: Dict[str, float], stop_criteria: str = 'CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED', rank: Optional[int] = None, world_size: Optional[int] = None, seed: int = 0)¶
基类:
BaseNode
[T
]一个按权重从多个数据集中采样数据的节点。
此节点需要接收一个源节点字典和一个权重字典。源节点和权重的键必须相同。权重用于从源节点中采样。我们使用 torch.multinomial 从源节点中采样,关于如何使用权重进行采样,请参阅 https://pytorch.ac.cn/docs/stable/generated/torch.multinomial.html。seed 用于初始化随机数生成器。
此节点使用以下键来实现状态
DATASET_NODE_STATES_KEY: 每个源节点状态的字典。
DATASETS_EXHAUSTED_KEY: 一个布尔值字典,指示每个源节点是否已耗尽。
EPOCH_KEY: 一个用于初始化随机数生成器的 epoch 计数器。
NUM_YIELDED_KEY: 已产生的项目数量。
WEIGHTED_SAMPLER_STATE_KEY: 加权采样器的状态。
我们支持多种停止条件
CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: 循环遍历源节点直到所有数据集都耗尽。这是默认行为。
FIRST_DATASET_EXHAUSTED: 当第一个数据集耗尽时停止。
ALL_DATASETS_EXHAUSTED: 当所有数据集都耗尽时停止。
源节点完全耗尽时,节点将抛出 StopIteration。
- 参数:
source_nodes (*Mapping*[*str*, BaseNode[*T*]]) – 源节点的字典。
weights (*Dict*[*str*, *float*]) – 每个源节点的权重字典。
stop_criteria (*str*) – 停止条件。默认为 CYCLE_UNTIL_ALL_DATASETS_EXHAUST
rank (*int*) – 当前进程的排名。默认为 None,此时将从分布式环境中获取排名。
world_size (*int*) – 分布式环境的世界大小。默认为 None,此时将从分布式环境中获取世界大小。
seed (*int*) – 随机数生成器的种子。默认为 0。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是
state_dict()
。此方法应仅由 BaseNode 调用。- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典
- next() T ¶
子类必须实现此方法,而不是
__next__
。此方法应仅由 BaseNode 调用。- 返回:
T - 序列中的下一个值,如果序列结束则抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置到 initial_state 中传递的状态。
reset()
是一个放置昂贵初始化代码的好地方,因为它会在调用next()
或state_dict()
时被延迟调用。子类必须调用super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。
- class torchdata.nodes.ParallelMapper(source: BaseNode[X], map_fn: Callable[[X], T], num_workers: int, in_order: bool = True, method: Literal['thread', 'process'] = 'thread', multiprocessing_context: Optional[str] = None, max_concurrent: Optional[int] = None, snapshot_frequency: int = 1, prebatch: Optional[int] = None)¶
基类:
BaseNode
[T
]ParallelMapper 在 num_workers 数量的线程或进程中并行执行 map_fn。对于进程方式,multiprocessing_context 可以是 spawn、forkserver、fork 或 None(选择操作系统默认)。最多 max_concurrent 个项目将被处理或放入迭代器的输出队列中,以限制 CPU 和内存利用率。如果为 None(默认),该值将为 2 * num_workers。
source 只创建一个 iter(),并且最多只有一个线程同时调用其 next()。
如果 in_order 为 True,迭代器将按照项目从 source 迭代器到达的顺序返回它们,即使有其他可用项目,也可能会阻塞。
- 参数:
source (BaseNode[X]) – 要映射的源节点。
map_fn (*Callable*[[*X*], *T*]) – 应用于源节点中每个项目的函数。
num_workers (int) – 用于并行处理的 worker 数量。
in_order (bool) – 是否按照项目从源到达的顺序返回。默认为 True。
method (Literal["thread", "process"]) – 用于并行处理的方法。默认为“thread”。
multiprocessing_context (Optional[str]) – 用于并行处理的多进程上下文。默认为 None。
max_concurrent (Optional[int]) – 同时处理的最大项目数。默认为 None。
snapshot_frequency (int) – 对源节点状态进行快照的频率。默认为 1。
prebatch (Optional[int]) – 可选地在映射之前对源中的项目执行预批量处理。对于小项目,这可能会以增加峰值内存为代价来提高吞吐量。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是
state_dict()
。此方法应仅由 BaseNode 调用。- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典
- next() T ¶
子类必须实现此方法,而不是
__next__
。此方法应仅由 BaseNode 调用。- 返回:
T - 序列中的下一个值,如果序列结束则抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置到 initial_state 中传递的状态。
reset()
是一个放置昂贵初始化代码的好地方,因为它会在调用next()
或state_dict()
时被延迟调用。子类必须调用super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。
- class torchdata.nodes.PinMemory(source: BaseNode[T], pin_memory_device: str = '', snapshot_frequency: int = 1)¶
基类:
BaseNode
[T
]将底层节点的数据固定到设备上。这由 torch.utils.data._utils.pin_memory._pin_memory_loop 支持。
- 参数:
source (BaseNode[T]) – 要固定数据的源节点。
pin_memory_device (str) – 固定数据到的设备。默认为 ""。
snapshot_frequency (int) – 对源节点状态进行快照的频率。默认为 1,这意味着每处理一个项目后,都会对源节点的状态进行快照。如果设置为更高的值,则每处理 snapshot_frequency 个项目后,才会对源节点的状态进行快照。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是
state_dict()
。此方法应仅由 BaseNode 调用。- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典
- next()¶
子类必须实现此方法,而不是
__next__
。此方法应仅由 BaseNode 调用。- 返回:
T - 序列中的下一个值,如果序列结束则抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置到 initial_state 中传递的状态。
reset()
是一个放置昂贵初始化代码的好地方,因为它会在调用next()
或state_dict()
时被延迟调用。子类必须调用super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。
- class torchdata.nodes.Prefetcher(source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1)¶
基类:
BaseNode
[T
]从源节点预取数据并存储在队列中。
- 参数:
source (BaseNode[T]) – 从其预取数据的源节点。
prefetch_factor (int) – 提前预取的项目数量。
snapshot_frequency (int) – 对源节点状态进行快照的频率。默认为 1,这意味着每处理一个项目后,都会对源节点的状态进行快照。如果设置为更高的值,则每处理 snapshot_frequency 个项目后,才会对源节点的状态进行快照。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是
state_dict()
。此方法应仅由 BaseNode 调用。- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典
- next()¶
子类必须实现此方法,而不是
__next__
。此方法应仅由 BaseNode 调用。- 返回:
T - 序列中的下一个值,如果序列结束则抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置到 initial_state 中传递的状态。
reset()
是一个放置昂贵初始化代码的好地方,因为它会在调用next()
或state_dict()
时被延迟调用。子类必须调用super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。
- class torchdata.nodes.SamplerWrapper(sampler: Sampler[T], initial_epoch: int = 0, epoch_updater: Optional[Callable[[int], int]] = None)¶
基类:
BaseNode
[T
]将采样器 (sampler) 转换为 BaseNode。这与 IterableWrapper 几乎相同,但它包含一个钩子,用于在采样器支持时调用其 set_epoch 方法。
- 参数:
sampler (Sampler) – 要封装的采样器。
initial_epoch (int) – 要在采样器上设置的初始轮次。
epoch_updater (Optional[Callable[[int], int]] = None) – 在新迭代开始时更新轮次的回调函数。除了第一次迭代请求外,每次迭代请求开始时都会调用它。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是
state_dict()
。此方法应仅由 BaseNode 调用。- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典
- next() T ¶
子类必须实现此方法,而不是
__next__
。此方法应仅由 BaseNode 调用。- 返回:
T - 序列中的下一个值,如果序列结束则抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置到 initial_state 中传递的状态。
reset()
是一个放置昂贵初始化代码的好地方,因为它会在调用next()
或state_dict()
时被延迟调用。子类必须调用super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。
- class torchdata.nodes.Stateful(*args, **kwargs)¶
基类:
Protocol
为实现了
state_dict()
和load_state_dict(state_dict: Dict[str, Any])
方法的对象定义的协议。
- class torchdata.nodes.StopCriteria¶
基类:
object
数据集采样器的停止标准。
CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: 当最后一个未曾见过的数据集耗尽时停止。所有数据集至少被见过一次。在某些情况下,当仍有未耗尽的数据集时,一些数据集可能会被见过不止一次。
ALL_DATASETS_EXHAUSTED: 当所有数据集都耗尽时停止。每个数据集只被见过一次。不会执行循环或重启。
FIRST_DATASET_EXHAUSTED: 当第一个数据集耗尽时停止。
CYCLE_FOREVER: 通过重新初始化每个耗尽的源节点来循环遍历数据集。当训练器需要控制步数而不是轮次时,这很有用。
- class torchdata.nodes.Unbatcher(source: BaseNode[Sequence[T]])¶
基类:
BaseNode
[T
]Unbatcher 将从 source 拉取的批次展平,并在调用其 next() 时按顺序生成元素。
- 参数:
source (BaseNode[T]) – 从其拉取批次的源节点。
- get_state() Dict[str, Any] ¶
子类必须实现此方法,而不是
state_dict()
。此方法应仅由 BaseNode 调用。- 返回:
Dict[str, Any] - 可在未来某个时间点传递给
reset()
的状态字典
- next() T ¶
子类必须实现此方法,而不是
__next__
。此方法应仅由 BaseNode 调用。- 返回:
T - 序列中的下一个值,如果序列结束则抛出 StopIteration
- reset(initial_state: Optional[Dict[str, Any]] = None)¶
将迭代器重置到开头,或重置到 initial_state 中传递的状态。
reset()
是一个放置昂贵初始化代码的好地方,因为它会在调用next()
或state_dict()
时被延迟调用。子类必须调用super().reset(initial_state)
。- 参数:
initial_state – Optional[dict] - 要传递给节点的状态字典。如果为 None,则重置到开头。