• 文档 >
  • 状态化数据加载器
快捷方式

状态化数据加载器

StatefulDataLoader 是 torch.utils.data.DataLoader 的替代品,它提供了 state_dict / load_state_dict 方法来处理中间 epoch 检查点,这些方法分别作用于数据加载器请求的上一个/下一个迭代器。

默认情况下,状态包括已生成的批次数量,并使用它来简单地快速转发采样器(映射式)或数据集(可迭代式)。但是,如果采样器和/或数据集包含 state_dict / load_state_dict 方法,则它将在其自己的 state_dict / load_state_dict 调用期间调用它们。在后台,StatefulDataLoader 处理跨多进程工作进程(但不跨秩)的状态聚合和分发。

class torchdata.stateful_dataloader.StatefulDataLoader(dataset: Dataset[_T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, sampler: Optional[Union[Sampler, Iterable]] = None, batch_sampler: Optional[Union[Sampler[List], Iterable[List]]] = None, num_workers: int = 0, collate_fn: Optional[Callable[[List[T]], Any]] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable[[int], None]] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: str = '', snapshot_every_n_steps: Optional[int] = 1)

这是 torch.utils.data.DataLoader 的替代品,它实现了 state_dict 和 load_state_dict 方法,从而支持中间 epoch 检查点。

所有参数都与 torch.utils.data.DataLoader 相同,增加了一个新的关键字参数:snapshot_every_n_steps

参数:
  • dataset (Dataset) – 要从中加载数据的 dataset。

  • batch_size (int, 可选) – 每个批次要加载多少个样本(默认值:1)。

  • shuffle (bool, 可选) – 设置为 True 以在每个 epoch 重新洗牌数据(默认值:False)。

  • sampler (采样器可迭代对象, 可选) – 定义从数据集中抽取样本的策略。可以是任何实现了 __len__Iterable。如果指定了该参数,则不能指定 shuffle

  • batch_sampler (采样器可迭代对象, 可选) – 与 sampler 类似,但每次返回一批索引。与 batch_sizeshufflesamplerdrop_last 互斥。

  • num_workers (int, 可选) – 用于数据加载的子进程数量。0 表示数据将在主进程中加载。(默认值:0

  • collate_fn (Callable, 可选) – 将样本列表合并成一个 Tensor(s) 的小批量。在使用来自映射风格数据集的批量加载时使用。

  • pin_memory (bool, 可选) – 如果为 True,则数据加载器将在返回张量之前将其复制到设备/CUDA 固定内存中。如果您的数据元素是自定义类型,或者您的 collate_fn 返回一个自定义类型的批次,请参见下面的示例。

  • drop_last (bool, 可选) – 设置为 True 以丢弃最后一个不完整的批次,如果数据集大小不能被批次大小整除。如果为 False 并且数据集的大小不能被批次大小整除,则最后一个批次将更小。(默认值:False

  • timeout (数值, 可选) – 如果为正数,则表示从工作进程收集批次的超时值。应始终为非负数。(默认值:0

  • worker_init_fn (Callable, 可选) – 如果不为 None,则在播种和数据加载之前,将在每个工作进程子进程上调用此函数,并将工作进程 ID([0, num_workers - 1] 中的一个整数)作为输入。(默认值:None

  • multiprocessing_context (strmultiprocessing.context.BaseContext, 可选) – 如果为 None,则将使用操作系统的默认 多进程上下文。(默认值:None

  • generator (torch.Generator, 可选) – 如果不为 None,则 RandomSampler 将使用此 RNG 生成随机索引,并且多处理将使用此 RNG 为工作进程生成 base_seed。(默认值:None

  • prefetch_factor (int, 可选, 关键字参数) – 每个工作进程预先加载的批次数量。2 表示所有工作进程总共将预取 2 * num_workers 个批次。(默认值取决于 num_workers 的设置值。如果 num_workers 的值为 0,则默认为 None。否则,如果 num_workers 的值 > 0,则默认为 2)。

  • persistent_workers (bool, 可选) – 如果为 True,则数据加载器在数据集被使用一次后不会关闭工作进程。这允许保持工作进程的 Dataset 实例处于活动状态。(默认值:False

  • pin_memory_device (str, 可选) – 如果 pin_memoryTrue,则要将内存固定到的设备。

  • snapshot_every_n_steps (int, 可选) – 定义数据加载器工作进程的状态多久传输一次到数据加载器。默认情况下,它设置为 1,即每一步都传输状态。如果状态很大,可以增加此值(理想情况下设置为训练检查点的频率),以减少每一步传输状态的开销。

警告

如果使用 spawn 启动方法,则 worker_init_fn 不能是不可腌制的对象,例如 lambda 函数。有关 PyTorch 中多处理的更多详细信息,请参见 多处理最佳实践

警告

len(dataloader) 启发式方法基于所用采样器的长度。当 datasetIterableDataset 时,它将返回基于 len(dataset) / batch_size 的估计值,并根据 drop_last 进行适当的舍入,而不管多进程加载配置如何。这代表了 PyTorch 可以做出的最佳猜测,因为 PyTorch 信任用户 dataset 代码能够正确处理多进程加载以避免重复数据。

但是,如果分片导致多个工作进程具有不完整的最后一个批次,则此估计值仍然可能不准确,因为 (1) 一个原本完整的批次可以被分成多个批次,并且 (2) 当 drop_last 设置时,可能会丢弃多个批次数量的样本。不幸的是,PyTorch 通常无法检测到此类情况。

有关这两种类型的数据集以及 IterableDataset 如何与 多进程数据加载 交互的更多详细信息,请参见 数据集类型

警告

有关随机种子相关问题,请参阅 可重复性数据加载器工作进程随机种子数据加载随机性 说明。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源