博客

PyTorch 1.11、TorchData 和 functorch 现已推出

作者: 2022年3月10日2024年11月15日暂无评论

我们很高兴地宣布 PyTorch 1.11 正式发布(发布说明)。自 1.10 版本以来,共有 434 位贡献者提交了 3,300 多次代码变更。除 1.11 版本外,我们还发布了 TorchData 和 functorch 的测试(beta)版本。

总结

  • TorchData 是一个全新的库,提供了通用的模块化数据加载原语,用于轻松构建灵活且高性能的数据流水线。在 GitHub 上查看
  • functorch 是一个为 PyTorch 添加可组合函数变换的库,现已推出测试版本。在 GitHub 上查看
  • 分布式数据并行 (DDP) 静态图优化功能现已进入稳定版。

TorchData 介绍

我们很高兴为您带来 TorchData 的测试版本。这是一个包含通用模块化数据加载原语的库,用于轻松构建灵活且高性能的数据流水线。根据社区反馈,我们发现现有的 DataLoader 集成了太多功能,难以扩展。此外,不同的用例往往需要重复编写相同的数据加载实用程序。我们的目标是通过“DataPipes”(可组合的数据加载构建块,分为可迭代式和映射式)实现可组合的数据加载,这些构建块可以与 PyTorch 的 DataLoader 无缝配合使用。

一个 DataPipe 接收对 Python 数据结构的某种访问函数(IterDataPipe 使用 __iter__MapDataPipe 使用 __getitem__),并返回应用了简单变换后的新访问函数。您可以将多个 DataPipe 串联起来,形成一个执行所有必要数据变换的数据流水线。

我们实现了 50 多个提供不同核心功能的 DataPipe,例如打开文件、解析文本、转换样本、缓存、打乱和分批。对于希望连接云提供商(如 Google Drive 或 AWS S3)的用户,fsspec 和 iopath DataPipe 将允许您实现此功能。文档中提供了每个 IterDataPipe 和 MapDataPipe 的详细说明和使用示例。

在本版本中,一些 PyTorch 领域库已将其数据集迁移至使用 DataPipe。在 TorchText 中,该库提供的常用数据集已使用 DataPipe 实现,其 SST-2 二元文本分类教程章节展示了如何使用 DataPipe 为您的模型预处理数据。此外,TorchVision(在每日构建版本中可用)TorchRec 中也有其他使用 DataPipe 实现数据集的原型。

TorchData 的文档现已上线。其中包含一个教程,涵盖了如何使用 DataPipe、如何将其与 DataLoader 结合使用以及如何实现自定义 DataPipe。有关 DataLoader 的常见问题解答和未来计划,请参阅我们项目的 README 文件

functorch 介绍

我们很高兴地宣布 functorch 的首个测试版本。受 Google JAX 的启发,functorch 是一个为 PyTorch 添加可组合函数变换的库。它旨在提供与 PyTorch 模块和 PyTorch 自动求导功能协同工作的可组合 vmap(向量化)和自动微分变换,并具有良好的 eager-mode(即时模式)性能。

可组合的函数变换有助于解决许多目前在 PyTorch 中处理起来比较棘手的用例:

  • 计算每个样本的梯度(或每个样本的其他量)
  • 在一台机器上运行模型集成
  • 在 MAML 的内循环中高效地对任务进行批处理
  • 高效计算雅可比矩阵(Jacobians)和黑塞矩阵(Hessians)及其批处理版本

组合使用 vmap(向量化)、vjp(反向模式自动微分)和 jvp(前向模式自动微分)变换,使我们能够轻松实现上述功能,而无需为每种功能设计单独的库。

有关更多详细信息,请参阅我们的文档教程以及安装说明

分布式训练

(稳定版) DDP 静态图

DDP 静态图假设您的模型在每次迭代中都使用相同的一组已使用/未使用的参数,因此它可以在第一次迭代后确定性地获知哪些钩子(hooks)会被触发、钩子触发的次数以及梯度计算的就绪顺序。静态图在第一次迭代中缓存这些状态,从而支持以前版本中 DDP 无法支持的功能,例如,无论是否存在未使用的参数,都能支持在相同参数上进行多个激活检查点(activation checkpoints)。静态图功能还在存在未使用参数时应用性能优化,例如,避免在每次迭代中遍历图以搜索未使用的参数,并启用动态桶排序(dynamic bucketing order)。DDP 静态图中的这些优化为某些推荐模型带来了 10% 的 QPS 提升。

要启用静态图,只需在 DDP API 中设置 static_graph=True,如下所示:

ddp_model = DistributedDataParallel(model, static_graph=True)

有关更多详细信息,请参阅我们的文档教程

感谢阅读,如果您对这些更新感兴趣并希望加入 PyTorch 社区,我们鼓励您加入 讨论论坛提交 GitHub issue。要获取 PyTorch 的最新消息,请在 TwitterMediumYouTubeLinkedIn 上关注我们。

干杯!

PyTorch 团队