状态化数据加载器¶
StatefulDataLoader 是 torch.utils.data.DataLoader 的替代品,它提供了 state_dict
/ load_state_dict
方法来处理中间 epoch 检查点,这些方法分别作用于数据加载器请求的上一个/下一个迭代器。
默认情况下,状态包括已生成的批次数量,并使用它来简单地快速转发采样器(映射式)或数据集(可迭代式)。但是,如果采样器和/或数据集包含 state_dict
/ load_state_dict
方法,则它将在其自己的 state_dict
/ load_state_dict
调用期间调用它们。在后台,StatefulDataLoader
处理跨多进程工作进程(但不跨秩)的状态聚合和分发。
- class torchdata.stateful_dataloader.StatefulDataLoader(dataset: Dataset[_T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, sampler: Optional[Union[Sampler, Iterable]] = None, batch_sampler: Optional[Union[Sampler[List], Iterable[List]]] = None, num_workers: int = 0, collate_fn: Optional[Callable[[List[T]], Any]] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable[[int], None]] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: str = '', snapshot_every_n_steps: Optional[int] = 1)¶
这是
torch.utils.data.DataLoader
的替代品,它实现了 state_dict 和 load_state_dict 方法,从而支持中间 epoch 检查点。所有参数都与
torch.utils.data.DataLoader
相同,增加了一个新的关键字参数:snapshot_every_n_steps
。- 参数:
dataset (Dataset) – 要从中加载数据的 dataset。
batch_size (int, 可选) – 每个批次要加载多少个样本(默认值:
1
)。shuffle (bool, 可选) – 设置为
True
以在每个 epoch 重新洗牌数据(默认值:False
)。sampler (采样器 或 可迭代对象, 可选) – 定义从数据集中抽取样本的策略。可以是任何实现了
__len__
的Iterable
。如果指定了该参数,则不能指定shuffle
。batch_sampler (采样器 或 可迭代对象, 可选) – 与
sampler
类似,但每次返回一批索引。与batch_size
、shuffle
、sampler
和drop_last
互斥。num_workers (int, 可选) – 用于数据加载的子进程数量。
0
表示数据将在主进程中加载。(默认值:0
)collate_fn (Callable, 可选) – 将样本列表合并成一个 Tensor(s) 的小批量。在使用来自映射风格数据集的批量加载时使用。
pin_memory (bool, 可选) – 如果为
True
,则数据加载器将在返回张量之前将其复制到设备/CUDA 固定内存中。如果您的数据元素是自定义类型,或者您的collate_fn
返回一个自定义类型的批次,请参见下面的示例。drop_last (bool, 可选) – 设置为
True
以丢弃最后一个不完整的批次,如果数据集大小不能被批次大小整除。如果为False
并且数据集的大小不能被批次大小整除,则最后一个批次将更小。(默认值:False
)timeout (数值, 可选) – 如果为正数,则表示从工作进程收集批次的超时值。应始终为非负数。(默认值:
0
)worker_init_fn (Callable, 可选) – 如果不为
None
,则在播种和数据加载之前,将在每个工作进程子进程上调用此函数,并将工作进程 ID([0, num_workers - 1]
中的一个整数)作为输入。(默认值:None
)multiprocessing_context (str 或 multiprocessing.context.BaseContext, 可选) – 如果为
None
,则将使用操作系统的默认 多进程上下文。(默认值:None
)generator (torch.Generator, 可选) – 如果不为
None
,则 RandomSampler 将使用此 RNG 生成随机索引,并且多处理将使用此 RNG 为工作进程生成base_seed
。(默认值:None
)prefetch_factor (int, 可选, 关键字参数) – 每个工作进程预先加载的批次数量。
2
表示所有工作进程总共将预取 2 * num_workers 个批次。(默认值取决于 num_workers 的设置值。如果 num_workers 的值为 0,则默认为None
。否则,如果 num_workers 的值 > 0,则默认为2
)。persistent_workers (bool, 可选) – 如果为
True
,则数据加载器在数据集被使用一次后不会关闭工作进程。这允许保持工作进程的 Dataset 实例处于活动状态。(默认值:False
)pin_memory_device (str, 可选) – 如果
pin_memory
为True
,则要将内存固定到的设备。snapshot_every_n_steps (int, 可选) – 定义数据加载器工作进程的状态多久传输一次到数据加载器。默认情况下,它设置为
1
,即每一步都传输状态。如果状态很大,可以增加此值(理想情况下设置为训练检查点的频率),以减少每一步传输状态的开销。
警告
如果使用
spawn
启动方法,则worker_init_fn
不能是不可腌制的对象,例如 lambda 函数。有关 PyTorch 中多处理的更多详细信息,请参见 多处理最佳实践。警告
len(dataloader)
启发式方法基于所用采样器的长度。当dataset
是IterableDataset
时,它将返回基于len(dataset) / batch_size
的估计值,并根据drop_last
进行适当的舍入,而不管多进程加载配置如何。这代表了 PyTorch 可以做出的最佳猜测,因为 PyTorch 信任用户dataset
代码能够正确处理多进程加载以避免重复数据。但是,如果分片导致多个工作进程具有不完整的最后一个批次,则此估计值仍然可能不准确,因为 (1) 一个原本完整的批次可以被分成多个批次,并且 (2) 当
drop_last
设置时,可能会丢弃多个批次数量的样本。不幸的是,PyTorch 通常无法检测到此类情况。有关这两种类型的数据集以及
IterableDataset
如何与 多进程数据加载 交互的更多详细信息,请参见 数据集类型。警告
有关随机种子相关问题,请参阅 可重复性、数据加载器工作进程随机种子 和 数据加载随机性 说明。