• 文档 >
  • 什么是 torchdata.nodes (beta)?
快捷方式

注意

2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2

我们正在重新调整 torchdata 仓库的重点,使其成为 torch.utils.data.DataLoader 的迭代增强。我们不计划继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新审视 pytorch/pytorch 中的 DataPipes 引用。在 torchdata==0.8.0 版本(2024 年 7 月)中,它们将被标记为已弃用,在 0.10.0 版本(2024 年末)中,它们将被删除。建议现有用户固定到 torchdata<=0.9.0 或更旧版本,直到他们能够迁移走。后续版本将不包含 DataPipes 或 DataLoaderV2。如果您有建议或意见,请联系我们(请使用 此 issue 进行反馈)

什么是 torchdata.nodes (beta)?

torchdata.nodes 是可组合迭代器(不是可迭代对象!)的库,可让您将常见的数据加载和预处理操作链接在一起。它遵循流式编程模型,尽管如果需要,仍然可以配置“采样器 + Map-style”。

torchdata.nodes 为标准 torch.utils.data 产品增加了更多灵活性,并引入了多线程并行性以及多进程(torch.utils.data.DataLoader 中唯一支持的方法),以及通过 state_dict/load_state_dict 接口对 epoch 中途检查点的一流支持。

torchdata.nodes 努力包含尽可能多的有用运算符,但它被设计为可扩展的。新的节点需要继承 torchdata.nodes.BaseNode(它本身继承自 typing.Iterator)并实现 next()reset(initial_state)get_state() 操作(值得注意的是,不是 __next__load_state_dictstate_dict

请参阅 开始使用 torchdata.nodes (beta) 以开始使用

为什么选择 torchdata.nodes

我们理解,torch.utils.data 对于许多用例来说已经足够好用。然而,它肯定存在一些粗糙之处

多进程很糟糕

  • 您需要复制存储在 Dataset 中的内存(因为 Python 的写时复制)

  • IPC 通过多进程队列速度很慢,并且可能导致启动时间缓慢

  • 您被迫在工作进程而不是主进程上执行批处理,以减少 IPC 开销,从而增加峰值内存。

  • 借助 GIL 释放函数和自由线程 Python,多线程可能不像以前那样受 GIL 限制。

torchdata.nodes 既支持多线程又支持多进程,因此您可以选择最适合您特定设置的方法。并行性主要在 Mapper 运算符中配置,使您可以灵活地选择并行化的内容、时间和方式。

Map-style 和随机访问无法扩展

当前的 map 数据集方法非常适合内存中的数据集,但是一旦数据集增长超出内存限制,除非您使用特殊的采样器跳过一些步骤,否则真正的随机访问性能将不会很高。

torchdata.nodes 遵循流式数据模型,其中运算符是迭代器,可以组合在一起以定义数据加载和预处理管道。采样器仍然受支持(请参阅上面的示例),并且可以与 Mapper 组合以生成迭代器

多数据集与 torch.utils.data 中的当前实现不太兼容

当您开始尝试组合多个数据集时,当前的采样器(每个数据加载器一个)概念开始崩溃。(对于单个数据集,它们是一个很好的抽象,并将继续受到支持!)

  • 对于多数据集,请考虑以下场景:len(dsA): 10 len(dsB): 20。现在我们想要在这两个数据集之间进行轮询(或均匀采样)以馈送到我们的训练器。仅使用单个采样器,您如何实现该策略?也许是一个发出元组的采样器?如果您想与 RandomSampler 或 DistributedSampler 交换怎么办?sampler.set_epoch 将如何工作?

torchdata.nodes 通过仅处理迭代器来帮助解决和扩展多数据集数据加载,从而强制采样器和数据集一起工作,专注于将较小的原始节点组合成更复杂的数据加载管道。

IterableDataset + 多进程需要额外的数据集分片

数据并行训练需要数据集分片,这是相当合理的。但是,数据加载器工作进程之间的分片呢?对于 Map-style 数据集,工作进程之间工作分配由主进程处理,主进程将采样器索引分发给工作进程。对于 IterableDatasets,每个工作进程都需要弄清楚(通过 torch.utils.data.get_worker_info)它应该返回什么数据。

torchdata.nodes 的性能如何?

我们在 PyTorch Conf 2024 的视频解码基准测试中展示了 torchdata.nodes 早期版本的一些结果,我们在其中表明

  • 当使用多进程时,torchdata.nodes 的性能与 torch.utils.data.DataLoader 相当或更好(请参阅 从 torch.utils.data 迁移到 torchdata.nodes

  • 对于 GIL python,在某些情况下,使用多线程的 torchdata.nodes 比多进程性能更好,但使 GPU 预处理等功能更容易执行,从而可以提高性能

我们运行了一个基准测试,从磁盘加载 Imagenet 数据集,并设法在 CPU 利用率明显低于多进程工作进程的情况下,使用自由线程 Python (3.13t) 饱和主内存带宽(预计在 2025 年初发布博文)。请参阅 examples/nodes/imagenet_benchmark.py

设计选择

没有生成器 BaseNodes

有关更多想法,请参阅 https://github.com/pytorch/data/pull/1362

我们做出的一个艰难选择是不允许在定义新的 BaseNode 实现时使用生成器。但是,我们放弃了它,并转向仅迭代器的基础,原因是一些围绕状态管理的问题

  1. 我们要求在 BaseNode 实现中显式处理状态。生成器在堆栈上隐式存储状态,我们发现我们需要跳过一些步骤并编写非常复杂的代码才能使基本状态与生成器一起工作

  2. 迭代结束状态字典:可迭代对象可能感觉更自然,但是围绕状态管理出现了很多问题。考虑迭代结束状态字典。如果您将此 state_dict 加载到您的可迭代对象中,这应该表示迭代结束还是下一次迭代的开始?

  3. 加载状态:如果您在可迭代对象上调用 load_state_dict(),大多数用户会期望从中请求的下一个迭代器以加载的状态开始。但是,如果在迭代开始之前调用 iter 两次会发生什么?

  4. 多个活动迭代器问题:如果您有一个可迭代对象的实例,但有两个活动迭代器,那么在可迭代对象上调用 state_dict() 意味着什么?在数据加载中,这种情况非常罕见,但是我们仍然需要解决它并做出一些假设。在我们看来,强迫正在实现 BaseNode 的开发人员推理这些场景比不允许生成器和可迭代对象更糟糕。

torchdata.nodes.BaseNode 实现是迭代器。迭代器定义 next()get_state()reset(initial_state | None)。所有重新初始化都应在 reset() 中完成,包括使用特定状态(如果传递了状态)进行初始化。

但是,最终用户习惯于处理可迭代对象,例如,

for epoch in range(5):
  # Most frameworks and users don't expect to call loader.reset()
  for batch in loader:
    ...
  sd = loader.state_dict()
  # Loading sd should not throw StopIteration right away, but instead start at the next epoch

为了处理这个问题,我们将所有假设和特殊的 epoch 结束处理都保留在一个 Loader 类中,该类接受任何 BaseNode 并使其成为可迭代对象,处理 reset() 调用和 epoch 结束 state_dict 加载。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源