博客

PyTorch 1.11、TorchData 和 functorch 现已推出

作者 2022年3月10日2024年11月15日暂无评论

我们非常高兴地宣布 PyTorch 1.11 正式发布(发布说明)。自 1.10 版本以来,该版本包含了 434 位贡献者提交的 3,300 多次代码提交。随 1.11 一起发布的还有 TorchData 和 functorch 的测试版。

总结

  • TorchData 是一个全新的库,提供通用的模块化数据加载原语,旨在轻松构建灵活且高性能的数据流水线。在 GitHub 上查看
  • functorch 是一个为 PyTorch 增加可组合函数变换的库,现已推出测试版。在 GitHub 上查看
  • 分布式数据并行(DDP)静态图优化现已进入稳定版。

引入 TorchData

我们很高兴地推出 TorchData 的测试版。这是一个包含通用模块化数据加载原语的库,旨在轻松构建灵活且高性能的数据流水线。根据社区反馈,现有的 DataLoader 捆绑了太多功能,且难以扩展。此外,不同的应用场景往往需要重复编写相同的数据加载工具。我们的目标是通过名为“DataPipes”的 Iterable(可迭代)式和 Map(映射)式构建块,实现可组合的数据加载,这些构建块可以与 PyTorch 的 DataLoader 无缝配合使用。

一个 DataPipe 接收对 Python 数据结构的某种访问函数(IterDataPipe 使用 __iter__MapDataPipe 使用 __getitem__),并返回一个应用了简单变换后的新访问函数。你可以将多个 DataPipe 串联起来,形成一个执行所有必要数据变换的数据流水线。

我们已经实现了 50 多个提供不同核心功能的 DataPipe,例如打开文件、解析文本、转换样本、缓存、打乱和批处理。对于希望连接到云服务提供商(如 Google Drive 或 AWS S3)的用户,fsspec 和 iopath DataPipes 将满足你的需求。文档中提供了每个 IterDataPipe 和 MapDataPipe 的详细说明和使用示例。

在此版本中,一些 PyTorch 域库已迁移其数据集以使用 DataPipes。在 TorchText 中,该库提供的常用数据集已使用 DataPipes 实现,其 SST-2 二元文本分类教程的一节演示了如何使用 DataPipes 为模型预处理数据。此外,TorchVision(在每日构建版本中可用)TorchRec 中也有其他使用 DataPipes 的原型数据集实现。

TorchData 的文档现已上线。它包含一篇教程,涵盖了如何使用 DataPipes、如何将其与 DataLoader 结合使用以及如何实现自定义 DataPipe。与 DataLoader 相关的常见问题解答和未来计划详见我们项目的 README 文件

引入 functorch

我们很高兴地宣布 functorch 的首个测试版发布。深受 Google JAX 的启发,functorch 是一个为 PyTorch 添加可组合函数变换的库。它旨在提供可组合的 vmap(向量化)和自动微分变换,能够与 PyTorch 模块和 PyTorch autograd 配合使用,并具备良好的即时模式(eager-mode)性能。

可组合函数变换可以帮助解决目前在 PyTorch 中处理起来比较棘手的多种用例:

  • 计算每个样本的梯度(或每个样本的其他量)
  • 在一台机器上运行模型集成
  • 高效地对 MAML 内循环中的任务进行批处理
  • 高效地计算雅可比矩阵(Jacobians)和海森矩阵(Hessians)及其批处理版本

组合使用 vmap(向量化)、vjp(反向模式自动微分)和 jvp(前向模式自动微分)变换,使我们能够轻松地表达上述操作,而无需为每种操作设计单独的库。

有关更多详细信息,请参阅我们的文档教程安装说明

分布式训练

(稳定版)DDP 静态图

DDP 静态图假设你的模型在每次迭代中都使用相同的一组已使用/未使用的参数,因此它可以在第一次迭代后确定性地知道诸如哪些钩子(hooks)将触发、钩子触发的次数以及梯度计算就绪的顺序等状态。静态图在第一次迭代中缓存这些状态,从而能够支持 DDP 在先前版本中无法支持的功能,例如,无论是否存在未使用的参数,都支持在同一参数上进行多次激活检查点设置。静态图功能还在存在未使用参数时应用性能优化,例如,它避免了在每次迭代中遍历图来搜索未使用参数,并支持动态桶排序(bucketing order)。这些在 DDP 静态图中的优化为某些推荐模型带来了 10% 的 QPS(每秒查询率)提升。

要启用静态图,只需在 DDP API 中设置 static_graph=True,如下所示:

ddp_model = DistributedDataParallel(model, static_graph=True)

有关更多详细信息,请参阅我们的文档教程

感谢阅读,如果您对这些更新感兴趣并希望加入 PyTorch 社区,我们鼓励您加入 讨论论坛提交 GitHub issue。要获取 PyTorch 的最新消息,请在 TwitterMediumYouTubeLinkedIn 上关注我们。

干杯!

PyTorch 团队