• 文档 >
  • 有状态 DataLoader
快捷方式

注意

2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2

我们正将 torchdata 仓库重新聚焦于对 torch.utils.data.DataLoader 的迭代增强。我们不计划继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新审视 pytorch/pytorch 中对 DataPipes 的引用。在 torchdata==0.8.0 版本(2024 年 7 月)中,它们将被标记为已弃用,并在 0.10.0 版本(2024 年末)中被删除。建议现有用户在迁移之前,将版本固定到 torchdata<=0.9.0 或更旧的版本。后续版本将不再包含 DataPipes 或 DataLoaderV2。如果您有任何建议或意见,请联系我们(请使用此问题进行反馈)

有状态 DataLoader

StatefulDataLoader 是 torch.utils.data.DataLoader 的直接替代品,它提供了 state_dict / load_state_dict 方法,用于处理训练中途检查点,这些方法分别作用于从 dataloader 请求的上一个/下一个迭代器。

默认情况下,状态包含已产生的批次数,并以此来简单地快进采样器 (map 风格) 或数据集 (iterable 风格)。但是,如果采样器和/或数据集包含 state_dict / load_state_dict 方法,那么它将在自身的 state_dict / load_state_dict 调用期间调用它们。在底层,StatefulDataLoader 处理多进程工作进程之间的状态聚合和分发(但不处理跨 rank 的状态)。

class torchdata.stateful_dataloader.StatefulDataLoader(dataset: Dataset[_T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, sampler: Optional[Union[Sampler, Iterable]] = None, batch_sampler: Optional[Union[Sampler[List], Iterable[List]]] = None, num_workers: int = 0, collate_fn: Optional[Callable[[List[_T]], Any]] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable[[int], None]] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: str = '', in_order: bool = True, snapshot_every_n_steps: Optional[int] = 1)

这是 torch.utils.data.DataLoader 的直接替代品,它实现了 state_dict 和 load_state_dict 方法,从而支持训练中途检查点。

所有参数与 torch.utils.data.DataLoader 相同,新增了一个关键字参数:snapshot_every_n_steps

参数:
  • dataset (Dataset) – 用于加载数据的 dataset。

  • batch_size (int, optional) – 每批加载的样本数(默认值:1)。

  • shuffle (bool, optional) – 设置为 True 则每个 epoch 重新打乱数据(默认值:False)。

  • sampler (Sampler or Iterable, optional) – 定义从数据集中抽取样本的策略。可以是任何实现了 __len__Iterable。如果指定此参数,则不得指定 shuffle

  • batch_sampler (Sampler or Iterable, optional) – 类似于 sampler,但一次返回一批索引。与 batch_sizeshufflesamplerdrop_last 互斥。

  • num_workers (int, optional) – 用于数据加载的子进程数。0 表示数据将在主进程中加载。(默认值:0

  • collate_fn (Callable, optional) – 将样本列表合并以形成 Tensor(s) 的 mini-batch。在使用 map 风格数据集进行批量加载时使用。

  • pin_memory (bool, optional) – 如果为 True,则数据加载器在返回 Tensor 之前会将其复制到设备/CUDA 锁定内存中。如果您的数据元素是自定义类型,或者您的 collate_fn 返回的是自定义类型的 batch,请参见下面的示例。

  • drop_last (bool, optional) – 设置为 True 则丢弃最后一个不完整的批次(如果数据集大小不能被批次大小整除)。如果为 False 且数据集大小不能被批次大小整除,则最后一个批次会较小。(默认值:False

  • timeout (numeric, optional) – 如果为正数,则表示从工作进程收集批次的超时值。应始终为非负数。(默认值:0

  • worker_init_fn (Callable, optional) – 如果不为 None,则将在每个工作子进程上调用此函数,输入为工作进程 ID(一个介于 [0, num_workers - 1] 之间的整数),调用时间在 seeding 之后、数据加载之前。(默认值:None

  • multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – 如果为 None,将使用操作系统默认的多进程上下文。(默认值:None

  • generator (torch.Generator, optional) – 如果不为 None,则 RandomSampler 将使用此 RNG 生成随机索引,多进程将使用此 RNG 为工作进程生成 base_seed。(默认值:None

  • prefetch_factor (int, optional, keyword-only arg) – 每个工作进程预加载的批次数。 2 表示所有工作进程总共预加载 2 * num_workers 个批次。(默认值取决于 num_workers 的设置值。如果 num_workers=0,默认值为 None。否则,如果 num_workers > 0,默认值为 2)。

  • persistent_workers (bool, optional) – 如果为 True,数据加载器在数据集被消费一次后不会关闭工作进程。这使得工作进程的 Dataset 实例能够保持活动。(默认值:False

  • pin_memory_device (str, optional) – 如果 pin_memoryTrue,则将 pin_memory 到此设备。

  • in_order (bool, optional) – 如果为 False,数据加载器将不强制批次以先入先出的顺序返回。仅当 num_workers > 0 时有效。(默认值:True

  • snapshot_every_n_steps (int, optional) – 定义状态从数据加载器工作进程传输到数据加载器的频率。默认设置为 1,即每一步传输状态。如果状态很大,可以增加此值(理想情况下设置为训练检查点的频率),以减少每一步传输状态的开销。

警告

如果使用 spawn 启动方法,则 worker_init_fn 不能是不可 pickle 的对象,例如 lambda 函数。有关 PyTorch 中多进程的更多详细信息,请参见多进程最佳实践

警告

len(dataloader) 的启发式方法基于所用采样器的长度。当 dataset 是一个 IterableDataset 时,它返回一个基于 len(dataset) / batch_size 的估计值,并根据 drop_last 进行适当的四舍五入,与多进程加载配置无关。这代表了 PyTorch 能做出的最佳猜测,因为 PyTorch 信任用户的 dataset 代码能够正确处理多进程加载以避免数据重复。

然而,如果分片导致多个工作进程的最后一个批次不完整,这个估计值仍然可能不准确,因为 (1) 一个原本完整的批次可能被分成多个,并且 (2) 当设置了 drop_last 时,可能会丢弃不止一个批次的样本。遗憾的是,PyTorch 通常无法检测到此类情况。

有关这两种数据集类型以及 IterableDataset 如何与多进程数据加载交互的更多详细信息,请参见数据集类型

警告

有关随机种子相关问题,请参见可重现性Dataloader 工作进程随机种子数据加载随机性等注意事项。

警告

in_order 设置为 False 可能会损害可重现性,并在数据不平衡的情况下导致提供给训练器的数据分布倾斜。

警告

in_order 设置为 False 目前对状态管理没有保证。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并解答您的疑问

查看资源