注意
2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2
我们正在重新聚焦 torchdata 仓库,使其成为对 torch.utils.data.DataLoader 的迭代增强。我们不计划继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新审视 pytorch/pytorch 中对 DataPipes 的引用。在版本 torchdata==0.8.0(2024 年 7 月)中,它们将被标记为已弃用,并在 0.10.0(2024 年底)中被删除。建议现有用户锁定版本到 torchdata<=0.9.0 或更旧版本,直到他们能够迁移。后续版本将不包含 DataPipes 或 DataLoaderV2。如果您有建议或意见,请联系我们(请使用 此议题 提供反馈)
什么是 torchdata.nodes
(beta)?¶
torchdata.nodes
是一个由可组合的迭代器(不是可迭代对象!)组成的库,它允许您将常见的数据加载和预处理操作串联起来。它遵循流式编程模型,不过如果需要,仍然可以配置“采样器 + Map 风格”的方式。
torchdata.nodes
为标准的 torch.utils.data
提供了更多灵活性,除了多进程(torch.utils.data.DataLoader
中唯一支持的方法)之外,还引入了多线程并行性,并且通过 state_dict/load_state_dict
接口提供了对训练中途检查点的第一类支持。
torchdata.nodes
努力包含尽可能多的有用操作符,但它的设计宗旨是可扩展的。新的节点需要继承 torchdata.nodes.BaseNode
(它本身继承自 typing.Iterator
),并实现 next()
、reset(initial_state)
和 get_state()
操作(值得注意的是,不包括 __next__
、load_state_dict
和 state_dict
)
请参阅 torchdata.nodes (beta) 入门 开始使用
为什么选择 torchdata.nodes
?¶
我们明白,torch.utils.data
对于很多用例来说已经足够。然而,它确实存在一些不足之处
多进程的缺点¶
您需要复制存储在 Dataset 中的内存(由于 Python 的写时复制特性)
IPC(进程间通信)在多进程队列上速度较慢,并可能导致启动时间变慢
您被迫在 worker 进程而不是主进程上执行批处理,以减少 IPC 开销,这会增加峰值内存占用。
随着释放 GIL 的函数和 Free-Threaded Python 的出现,多线程可能不再像以前那样受 GIL 限制。
torchdata.nodes
同时支持多线程和多进程,因此您可以根据您的特定设置选择最合适的方式。并行性主要在 Mapper 操作符中配置,这使您在并行化什么、何时以及如何并行化方面具有灵活性。
Map 风格和随机访问无法扩展¶
当前的 Map 风格数据集方法非常适合能够完全载入内存的数据集,但一旦数据集大小超出内存限制,真正的随机访问性能将不佳,除非您通过特殊的采样器绕过一些限制。
torchdata.nodes
遵循流式数据模型,其中的操作符是 Iterators,可以组合起来定义数据加载和预处理管道。采样器仍然支持(详见 从 torch.utils.data 迁移到 torchdata.nodes),并且可以与 Mapper 结合使用以生成 Iterator
多数据集与 torch.utils.data
中的当前实现结合不佳¶
当您尝试组合多个数据集时,当前的 Sampler(每个 dataloader 一个)概念开始失效。(对于单个数据集,它们是一个很好的抽象,并且将继续得到支持!)
对于多数据集,考虑以下场景:
len(dsA): 10
len(dsB): 20
。现在我们想在这两个数据集之间进行轮询(或均匀采样)以馈送到我们的训练器。只有一个采样器时,如何实现这种策略?也许是一个会发出元组的采样器?如果您想替换成 RandomSampler 或 DistributedSampler 呢?sampler.set_epoch
将如何工作?
torchdata.nodes
通过只处理 Iterators,从而将采样器和数据集整合在一起,专注于将较小的基本节点组合成更复杂的数据加载管道,从而帮助解决和扩展多数据集加载问题。
IterableDataset + 多进程需要额外的数据集分片¶
数据并行训练需要数据集分片,这是相当合理的。但是 dataloader worker 之间的分片呢?对于 Map 风格数据集,worker 之间的工作分配由主进程处理,主进程将采样器索引分发给 worker。对于 IterableDatasets,每个 worker 需要自行确定(通过 torch.utils.data.get_worker_info
)应该返回哪些数据。
torchdata.nodes
的性能如何?¶
我们在 PyTorch Conf 2024 上的一个视频解码基准测试中展示了早期版本的 torchdata.nodes
的一些结果,表明
在使用多进程时,
torchdata.nodes
的性能与torch.utils.data.DataLoader
持平或更优(详见 从 torch.utils.data 迁移到 torchdata.nodes)在使用 GIL Python 时,
torchdata.nodes
的多线程在某些场景下性能优于多进程,并且使得 GPU 预处理等功能更容易实现,这可以提高许多用例的吞吐量。在使用 No-GIL / Free-Threaded Python (3.13t) 时,我们运行了一个从磁盘加载 Imagenet 数据集的基准测试,并在远低于多进程 worker 的 CPU 利用率下成功地使主内存带宽达到饱和(博客文章预计 2025 年初发布)。请参阅 imagenet_benchmark.py 在您自己的硬件上尝试。
设计选择¶
不支持 Generator BaseNodes¶
有关更多思考,请参阅 https://github.com/pytorch/data/pull/1362。
我们做的一个艰难选择是在定义新的 BaseNode 实现时禁止使用 Generators。然而,我们放弃了这个想法,转而采用基于 Iterator 的基础,原因围绕状态管理有几点:
我们在 BaseNode 实现中要求显式的状态处理。Generators 将状态隐式存储在堆栈上,我们发现需要绕很多弯子并编写非常复杂的代码才能使基本状态与 Generators 一起工作
迭代结束时的 state dict:Iterable 对象可能感觉更自然,但在状态管理方面会出现很多问题。考虑迭代结束时的 state dict。如果您将此 state_dict 加载到您的 iterable 中,这应该代表迭代的结束还是下一次迭代的开始?
加载状态:如果您在 iterable 对象上调用 load_state_dict(),大多数用户会期望从它请求的下一个 iterator 以加载的状态开始。但是如果在迭代开始前调用了两次 iter 怎么办?
多个活跃 Iterator 问题:如果您有一个 Iterable 实例,但有两个活跃的 iterator,在 Iterable 上调用 state_dict() 意味着什么?在数据加载中,这种情况非常罕见,但我们仍然需要绕开这个问题并做出许多假设。在我们看来,强迫实现 BaseNodes 的开发者考虑这些场景,比禁止使用 generators 和 Iterables 更糟糕。
torchdata.nodes.BaseNode
实现是 Iterators。Iterators 定义了 next()
、get_state()
和 reset(initial_state | None)
。所有重新初始化都应在 reset() 中完成,包括如果传递了特定状态,则使用该状态进行初始化。
然而,最终用户习惯于处理 Iterable 对象,例如,
for epoch in range(5):
# Most frameworks and users don't expect to call loader.reset()
for batch in loader:
...
sd = loader.state_dict()
# Loading sd should not throw StopIteration right away, but instead start at the next epoch
为了处理这种情况,我们将所有假设和特殊的 epoch 结束处理都放在一个单独的 Loader
类中,该类接受任何 BaseNode 并使其成为 Iterable,负责处理 reset() 调用和 epoch 结束时的 state_dict 加载。