• 文档 >
  • 有状态 DataLoader
快捷方式

注意

2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2

我们正在重新聚焦 torchdata repo,使其成为 torch.utils.data.DataLoader 的迭代增强版本。我们不计划继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata repo 中移除。我们还将重新审视 pytorch/pytorch 中的 DataPipes 引用。在 torchdata==0.8.0 版本(2024 年 7 月)中,它们将被标记为已弃用,在 0.10.0 版本(2024 年末)中,它们将被删除。建议现有用户固定到 torchdata<=0.9.0 或更旧版本,直到他们能够迁移。后续版本将不包含 DataPipes 或 DataLoaderV2。如果您有建议或意见,请联系我们(请使用 此 issue 进行反馈)

有状态 DataLoader

StatefulDataLoader 是 torch.utils.data.DataLoader 的直接替代品,它提供 state_dict / load_state_dict 方法来处理 epoch 中途的检查点操作,这些操作作用于从 dataloader 请求的前一个/下一个迭代器(分别是)。

默认情况下,状态包括已生成的批次数,并使用它来简单地快进采样器(map-style)或数据集(iterable-style)。但是,如果采样器和/或数据集包含 state_dict / load_state_dict 方法,那么它将在其自身的 state_dict / load_state_dict 调用期间调用它们。在底层,StatefulDataLoader 处理跨多进程 worker 的状态聚合和分发(但不跨 rank)。

class torchdata.stateful_dataloader.StatefulDataLoader(dataset: Dataset[_T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, sampler: Optional[Union[Sampler, Iterable]] = None, batch_sampler: Optional[Union[Sampler[List], Iterable[List]]] = None, num_workers: int = 0, collate_fn: Optional[Callable[[List[_T]], Any]] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable[[int], None]] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: str = '', snapshot_every_n_steps: Optional[int] = 1)

这是 torch.utils.data.DataLoader 的直接替代品,它实现了 state_dict 和 load_state_dict 方法,从而支持 epoch 中途的检查点操作。

所有参数都与 torch.utils.data.DataLoader 相同,但新增了一个 kwarg 参数:snapshot_every_n_steps

参数:
  • dataset (Dataset) – 从中加载数据的数据集。

  • batch_size (int, optional) – 每批加载多少个样本(默认值:1)。

  • shuffle (bool, optional) – 设置为 True 以在每个 epoch 重新洗牌数据(默认值:False)。

  • sampler (SamplerIterable, optional) – 定义从数据集中抽取样本的策略。可以是任何实现了 __len__Iterable。如果指定,则不得指定 shuffle

  • batch_sampler (SamplerIterable, optional) – 类似于 sampler,但一次返回一批索引。与 batch_sizeshufflesamplerdrop_last 互斥。

  • num_workers (int, optional) – 用于数据加载的子进程数。0 表示数据将在主进程中加载。(默认值:0

  • collate_fn (Callable, optional) – 合并样本列表以形成 Tensor(s) 的小批量。当从 map-style 数据集使用批量加载时使用。

  • pin_memory (bool, optional) – 如果为 True,数据加载器将在返回 Tensor 之前将其复制到设备/CUDA 锁页内存中。如果您的数据元素是自定义类型,或者您的 collate_fn 返回的批次是自定义类型,请参阅下面的示例。

  • drop_last (bool, optional) – 设置为 True 以丢弃最后一个不完整的批次(如果数据集大小不能被批次大小整除)。如果为 False 且数据集大小不能被批次大小整除,则最后一个批次会更小。(默认值:False

  • timeout (numeric, optional) – 如果为正数,则为从 worker 收集批次的超时值。应始终为非负数。(默认值:0

  • worker_init_fn (Callable, optional) – 如果不是 None,则将在每个 worker 子进程中调用此函数,并将 worker ID([0, num_workers - 1] 中的整数)作为输入,在播种之后和数据加载之前。(默认值:None

  • multiprocessing_context (strmultiprocessing.context.BaseContext, optional) – 如果为 None,将使用您操作系统的默认 多进程上下文。(默认值:None

  • generator (torch.Generator, optional) – 如果不是 None,则 RandomSampler 将使用此 RNG 生成随机索引,多进程将使用此 RNG 生成 worker 的 base_seed。(默认值:None

  • prefetch_factor (int, optional, keyword-only arg) – 每个 worker 预先加载的批次数量。2 表示所有 worker 总共将预取 2 * num_workers 个批次。(默认值取决于 num_workers 的设置值。如果 num_workers 的值为 0,则默认值为 None。否则,如果 num_workers > 0 的值为 2)。

  • persistent_workers (bool, optional) – 如果为 True,则数据加载器在数据集被消耗一次后不会关闭 worker 进程。这允许保持 worker 的 Dataset 实例处于活动状态。(默认值:False

  • pin_memory_device (str, optional) – 如果 pin_memoryTrue,则将 pin_memory 固定到的设备。

  • snapshot_every_n_steps (int, optional) – 定义状态从数据加载器 worker 传输到数据加载器的频率。默认情况下,设置为 1,即每步传输状态。如果状态很大,则可以增加此值(理想情况下设置为训练检查点的频率),以减少每步传输状态的开销。

警告

如果使用 spawn 启动方法,则 worker_init_fn 不能是无法 pickle 化的对象,例如 lambda 函数。有关 PyTorch 中多进程的更多详细信息,请参阅 multiprocessing-best-practices

警告

len(dataloader) 启发式方法基于所用采样器的长度。当 datasetIterableDataset 时,它会返回基于 len(dataset) / batch_size 的估计值,并根据 drop_last 进行适当的舍入,而与多进程加载配置无关。这代表了 PyTorch 可以做出的最佳猜测,因为 PyTorch 信任用户 dataset 代码能够正确处理多进程加载以避免重复数据。

但是,如果分片导致多个 worker 具有不完整的最后一个批次,则此估计值仍然可能不准确,因为 (1) 否则完整的批次可能会被分成多个批次,并且 (2) 当 drop_last 设置时,可能会丢弃超过一个批次的样本。不幸的是,PyTorch 通常无法检测到此类情况。

有关这两种类型数据集以及 IterableDataset 如何与 多进程数据加载 交互的更多详细信息,请参阅 数据集类型

警告

有关随机种子相关问题,请参阅 可重复性Dataloader-workers-random-seedData-loading-randomness 注释。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源