• 文档 >
  • 什么是 torchdata.nodes (beta)?
快捷方式

注意

2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2

我们正在重新聚焦 torchdata 仓库,使其成为对 torch.utils.data.DataLoader 的迭代增强。我们不计划继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新审视 pytorch/pytorch 中对 DataPipes 的引用。在版本 torchdata==0.8.0(2024 年 7 月)中,它们将被标记为已弃用,并在 0.10.0(2024 年底)中被删除。建议现有用户锁定版本到 torchdata<=0.9.0 或更旧版本,直到他们能够迁移。后续版本将不包含 DataPipes 或 DataLoaderV2。如果您有建议或意见,请联系我们(请使用 此议题 提供反馈)

什么是 torchdata.nodes (beta)?

torchdata.nodes 是一个由可组合的迭代器(不是可迭代对象!)组成的库,它允许您将常见的数据加载和预处理操作串联起来。它遵循流式编程模型,不过如果需要,仍然可以配置“采样器 + Map 风格”的方式。

torchdata.nodes 为标准的 torch.utils.data 提供了更多灵活性,除了多进程(torch.utils.data.DataLoader 中唯一支持的方法)之外,还引入了多线程并行性,并且通过 state_dict/load_state_dict 接口提供了对训练中途检查点的第一类支持。

torchdata.nodes 努力包含尽可能多的有用操作符,但它的设计宗旨是可扩展的。新的节点需要继承 torchdata.nodes.BaseNode(它本身继承自 typing.Iterator),并实现 next()reset(initial_state)get_state() 操作(值得注意的是,不包括 __next__load_state_dictstate_dict

请参阅 torchdata.nodes (beta) 入门 开始使用

为什么选择 torchdata.nodes

我们明白,torch.utils.data 对于很多用例来说已经足够。然而,它确实存在一些不足之处

多进程的缺点

  • 您需要复制存储在 Dataset 中的内存(由于 Python 的写时复制特性)

  • IPC(进程间通信)在多进程队列上速度较慢,并可能导致启动时间变慢

  • 您被迫在 worker 进程而不是主进程上执行批处理,以减少 IPC 开销,这会增加峰值内存占用。

  • 随着释放 GIL 的函数和 Free-Threaded Python 的出现,多线程可能不再像以前那样受 GIL 限制。

torchdata.nodes 同时支持多线程和多进程,因此您可以根据您的特定设置选择最合适的方式。并行性主要在 Mapper 操作符中配置,这使您在并行化什么、何时以及如何并行化方面具有灵活性。

Map 风格和随机访问无法扩展

当前的 Map 风格数据集方法非常适合能够完全载入内存的数据集,但一旦数据集大小超出内存限制,真正的随机访问性能将不佳,除非您通过特殊的采样器绕过一些限制。

torchdata.nodes 遵循流式数据模型,其中的操作符是 Iterators,可以组合起来定义数据加载和预处理管道。采样器仍然支持(详见 从 torch.utils.data 迁移到 torchdata.nodes),并且可以与 Mapper 结合使用以生成 Iterator

多数据集与 torch.utils.data 中的当前实现结合不佳

当您尝试组合多个数据集时,当前的 Sampler(每个 dataloader 一个)概念开始失效。(对于单个数据集,它们是一个很好的抽象,并且将继续得到支持!)

  • 对于多数据集,考虑以下场景:len(dsA): 10 len(dsB): 20。现在我们想在这两个数据集之间进行轮询(或均匀采样)以馈送到我们的训练器。只有一个采样器时,如何实现这种策略?也许是一个会发出元组的采样器?如果您想替换成 RandomSampler 或 DistributedSampler 呢?sampler.set_epoch 将如何工作?

torchdata.nodes 通过只处理 Iterators,从而将采样器和数据集整合在一起,专注于将较小的基本节点组合成更复杂的数据加载管道,从而帮助解决和扩展多数据集加载问题。

IterableDataset + 多进程需要额外的数据集分片

数据并行训练需要数据集分片,这是相当合理的。但是 dataloader worker 之间的分片呢?对于 Map 风格数据集,worker 之间的工作分配由主进程处理,主进程将采样器索引分发给 worker。对于 IterableDatasets,每个 worker 需要自行确定(通过 torch.utils.data.get_worker_info)应该返回哪些数据。

torchdata.nodes 的性能如何?

我们在 PyTorch Conf 2024 上的一个视频解码基准测试中展示了早期版本的 torchdata.nodes 的一些结果,表明

  • 在使用多进程时,torchdata.nodes 的性能与 torch.utils.data.DataLoader 持平或更优(详见 从 torch.utils.data 迁移到 torchdata.nodes

  • 在使用 GIL Python 时,torchdata.nodes 的多线程在某些场景下性能优于多进程,并且使得 GPU 预处理等功能更容易实现,这可以提高许多用例的吞吐量。

  • 在使用 No-GIL / Free-Threaded Python (3.13t) 时,我们运行了一个从磁盘加载 Imagenet 数据集的基准测试,并在远低于多进程 worker 的 CPU 利用率下成功地使主内存带宽达到饱和(博客文章预计 2025 年初发布)。请参阅 imagenet_benchmark.py 在您自己的硬件上尝试。

设计选择

不支持 Generator BaseNodes

有关更多思考,请参阅 https://github.com/pytorch/data/pull/1362

我们做的一个艰难选择是在定义新的 BaseNode 实现时禁止使用 Generators。然而,我们放弃了这个想法,转而采用基于 Iterator 的基础,原因围绕状态管理有几点:

  1. 我们在 BaseNode 实现中要求显式的状态处理。Generators 将状态隐式存储在堆栈上,我们发现需要绕很多弯子并编写非常复杂的代码才能使基本状态与 Generators 一起工作

  2. 迭代结束时的 state dict:Iterable 对象可能感觉更自然,但在状态管理方面会出现很多问题。考虑迭代结束时的 state dict。如果您将此 state_dict 加载到您的 iterable 中,这应该代表迭代的结束还是下一次迭代的开始?

  3. 加载状态:如果您在 iterable 对象上调用 load_state_dict(),大多数用户会期望从它请求的下一个 iterator 以加载的状态开始。但是如果在迭代开始前调用了两次 iter 怎么办?

  4. 多个活跃 Iterator 问题:如果您有一个 Iterable 实例,但有两个活跃的 iterator,在 Iterable 上调用 state_dict() 意味着什么?在数据加载中,这种情况非常罕见,但我们仍然需要绕开这个问题并做出许多假设。在我们看来,强迫实现 BaseNodes 的开发者考虑这些场景,比禁止使用 generators 和 Iterables 更糟糕。

torchdata.nodes.BaseNode 实现是 Iterators。Iterators 定义了 next()get_state()reset(initial_state | None)。所有重新初始化都应在 reset() 中完成,包括如果传递了特定状态,则使用该状态进行初始化。

然而,最终用户习惯于处理 Iterable 对象,例如,

for epoch in range(5):
  # Most frameworks and users don't expect to call loader.reset()
  for batch in loader:
    ...
  sd = loader.state_dict()
  # Loading sd should not throw StopIteration right away, but instead start at the next epoch

为了处理这种情况,我们将所有假设和特殊的 epoch 结束处理都放在一个单独的 Loader 类中,该类接受任何 BaseNode 并使其成为 Iterable,负责处理 reset() 调用和 epoch 结束时的 state_dict 加载。

文档

查阅全面的 PyTorch 开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源