作者:CK Luk, Daohang Shi, Yuzhen Huang, Jackie (Jiaqi) Xu, Jade Nie, Zhou Wang, Lu Fang, Flavio Sales Truzzi, Devashish Shankar, Dima Ivashchenko, Chunzhi Yang, Nicolas Macchioni, David Berard, Yu Guo, Xiaodong Wang, Bert Maher, Yanbo Liang, Edward Yang, Brian Hirsh, Michael Voznesensky, Animesh Jain, Michael Anderson

1. 引言

PyTorch 2.0 (简称 PT2) 可以使用名为 torch.compile 的编译器显著提高 AI 模型的训练和推理性能,同时与 PyTorch 1.x 100% 向后兼容。已有报告阐述了 PT2 如何提高常见 基准测试 (例如 huggingface 的 diffusers) 的性能。在本博客中,我们将讨论我们在 Meta 将 PT2 应用于 生产 AI 模型中的经验。

2. 背景

2.1 为什么自动化性能优化对于生产至关重要?

性能对于生产尤其重要——例如,即使常用模型的训练时间减少 5%,也能显著节省 GPU 成本和数据中心 电力。另一个重要指标是 开发效率,它衡量将模型投入生产所需投入的工程师月数。通常,这项启动工作的很大一部分花在 手动 性能调优上,例如重写 GPU 内核以提高训练速度。通过提供 自动化 性能优化,PT2 可以同时提高成本和开发效率。

2.2 PT2 如何提升性能

作为编译器,PT2 可以查看从模型捕获的训练图中的 多个 操作(与 PT1.x 不同,PT1.x 一次只执行一个操作)。因此,PT2 可以利用多种性能优化机会,包括

  • 将多个操作融合到单个 GPU 内核中
    • 运行 GPU 程序时的典型性能开销是启动小型 GPU 内核的 CPU 开销。通过将多个操作融合到单个 GPU 内核中,PT2 可以显著减少 CPU 上的内核启动开销。例如,考虑图 1(a) 中的 PyTorch 程序。当它在 PT1 中于 GPU 上执行时,它有三个 GPU 内核(两个用于两个 sin() 操作,一个用于加法操作)。使用 PT2 时,只生成一个内核,它融合了所有这三个操作。
    • 融合某些操作后,图中的某些操作可能会变成“死代码”(dead code),因此可以被优化掉。这可以节省 GPU 上的计算资源和内存带宽。例如,在图 1(b) 中,一个重复的 sin() 操作可以被优化掉。
    • 此外,融合还可以减少 GPU 设备内存读写(通过组合逐点内核)并帮助提高硬件利用率。

Fig.1  How PT2 improves performance with fusion and dead-code elimination.

图 1: PT2 如何通过融合和死代码消除提升性能。

  • 减少使用低精度数据类型的类型转换开销
    • PyTorch 1.x 支持 自动混合精度 (AMP)。虽然 AMP 可以减少操作的计算时间,但它在操作之前和之后会引入类型转换开销。PT2 可以通过优化掉不必要的类型转换代码来提高 AMP 性能,显著减少其开销。例如,图 2(a) 在执行矩阵乘法之前将三个 32 位输入张量 (a32, b32, c32) 转换为 bf16。然而,在这个例子中,a32 和 c32 实际上是同一个张量 (a_float32)。因此,没有必要对 a_float32 进行两次转换,如图 2(b) 中由 torch.compile 生成的代码所示。请注意,虽然本例和前一个例子都优化掉了冗余计算,但它们是不同的:本例中的类型转换代码通过 torch.autocast 是 隐式的,而前一个例子中 torch.sin(x).cuda() 在用户代码中是 显式的

Fig.2  How PT2 reduces type conversion overhead when using AMP.

图 2: PT2 如何在使用 AMP 时减少类型转换开销。

  • 重用 GPU 上的缓冲区
    • 通过全局视图,torch.compile 中的调度器可以重用 GPU 上的缓冲区,从而减少内存分配时间和内存消耗。图 3 显示了调用为图 2(a) 中的程序生成的 Triton 内核的驱动程序。我们可以看到 buf1 被重用为 buf4

Fig.3  Reuse of buffers.

图 3: 缓冲区重用。

  • 自动调优
    • PT2 提供了选项,可以在矩阵乘法操作、逐点操作和归约操作上启用自动调优(通过 Triton)。可调参数包括块大小、阶段数和 warp 数量。通过自动调优,可以通过实验找到一个操作性能最佳的实现。

3. 生产环境考量

在本节中,我们将描述在将 PT2 应用于生产环境时的一些重要考量。

3.1 确保使用 torch.compile 后模型质量不下降

将 torch.compile 应用于模型会导致数值变化,原因在于 (1) 在融合等各种优化过程中浮点操作的重新排序,以及 (2) 如果启用 AMP,会使用 bf16 等低精度数据类型。因此,不能期望与 PT 1.x 实现 100% 的比特级兼容性。然而,我们仍然需要确保在应用 torch.compile 后模型质量(以某种数值评分衡量)得到保留。通常,每个生产模型都有自己可接受的评分范围(例如,百分比变化必须在 0.01% 以内)。

如果 torch.compile 导致模型质量下降,我们需要进行深入调试。

调试与 torch.compile 相关的数值问题的一个有用技术是,除了“inductor”之外,尝试使用不同的后端应用 torch.compile,特别是“eager”和“aot_eager”

  • 如果数值问题在使用“eager”后端时发生,则 torch.compile 构建的前向图可能不正确;
  • 如果数值问题在使用“eager”时未发生,但在使用“aot_eager”时发生,则 torch.compile 构建的后向图可能不正确;
  • 如果数值问题在使用“eager”或“aot_eager”时均未发生,但在使用“inductor”时发生,则 inductor 内部的代码生成可能不正确。

3.2 生产环境中的自动调优

默认情况下,torch.inductor 中的自动调优是在模型执行期间 在线 完成的。对于某些生产模型,我们发现自动调优时间可能需要数小时,这对于生产环境是不可接受的。因此,我们增加了 离线自动调优 功能,其工作原理如图 4 所示。模型首次运行时,所有需要调优的操作的详细信息(例如,输入张量形状、数据类型等)将被记录到数据库中。然后,对这些操作进行调优过程,连夜搜索每个操作性能最佳的实现;搜索结果会更新到持久性缓存中(实现为 torch.inductor 的一个源文件)。下次再次运行模型时,每个操作的调优实现将在缓存中找到并选择执行。

Fig.4  The offline autotuning used in production.

图 4: 生产环境中使用的离线自动调优。

3.3 对 torch.compile 的性能分析支持

正如我们之前在这篇博客中讨论的那样,性能分析器对于调试生产模型的性能至关重要。我们增强了性能分析器,使其在时间轴上显示与 torch.compile 相关的事件。其中最有用的功能是标记模型的哪些部分正在运行编译后的代码,这样我们就可以快速验证模型中应该被 torch.compile 编译的部分是否真的被编译了。例如,图 5 中的跟踪记录有两个编译区域(标记为“CompiledFunction”)。其他有用的事件包括编译所花费的时间以及访问编译器代码缓存所花费的时间。

Fig.5  A trace with two compiled regions.

图 5: 具有两个编译区域的跟踪记录。

3.4 控制即时编译时间

torch.compile 使用即时编译(just-in-time compilation)。编译发生在第一批数据被训练时。在我们的生产环境中,对训练作业达到第一批数据所需的时间有一个上限,即 首批时间 (Time-To-First-Batch, TTFB)。我们需要确保启用 torch.compile 不会将 TTFB 增加超过上限。这可能具有挑战性,因为生产模型很大,并且 torch.compile 可能需要大量的编译时间。我们启用了 并行编译 以控制编译时间(这由 torch/_inductor/config.py 中的全局变量 compile_threads 控制,在 OSS Linux 上该变量已设置为 CPU 核心数)。模型被分解为一个或多个计算图;每个图被分解为多个 Triton 内核。如果启用并行编译,同一图中的所有 Triton 内核可以同时编译(然而,不同图中的内核仍然是串行编译的)。图 6 说明了并行编译如何提供帮助。

Fig.6  Using parallel compilation in production.

图 6: 在生产环境中使用并行编译。

4. 结果

在本节中,我们使用三个生产模型来评估 PT2。首先,我们将展示使用不同优化配置时 PT2 带来的训练时间加速效果。其次,我们将展示并行编译对编译时间的重要性。

4.1 使用 torch.compile 带来的训练时间加速

图 7 报告了使用 PT2 带来的训练时间加速效果。对于每个模型,我们展示了四种情况:(i) 未编译,使用 bf16;(ii) 编译,使用 fp32;(iii) 编译,使用 bf16;(iv) 编译,使用 bf16 并自动调优。y 轴表示相对于基准线(即未编译,使用 fp32)的加速比。请注意,未编译,使用 bf16 实际上比未编译,使用 fp32 要慢,这是由于类型转换开销所致。相比之下,编译并使用 bf16 通过减少大部分开销,实现了更大的加速。总体而言,考虑到这些模型已经经过大量手动优化,我们很高兴看到 torch.compile 仍然可以提供 1.14-1.24 倍的加速。

Fig.7 Training-time speedup with torch.compile (note: the baseline, no-compile/fp32, is  omitted in this figure).

图 7: 使用 torch.compile 带来的训练时间加速(注意:图中省略了基准线,即未编译/fp32)。

4.2 使用并行编译缩短编译时间

图 8 显示了使用和不使用并行编译时的编译时间。虽然串行编译时间仍有改进空间,但并行编译已将 TTFB 的编译开销降低到可接受的水平。模型 B 和 C 从并行编译中受益更多,原因在于它们每个图具有更多不同的 Triton 内核。

Fig.8 PT2 compilation time.

图 8: PT2 编译时间。

5. 总结与展望

在本博客中,我们展示了 PT2 可以显著加速大型复杂生产级 AI 模型的训练,且编译时间合理。在下一篇博客中,我们将讨论 PT2 如何执行通用的图转换。

6. 致谢

非常感谢 Mark Saroufim, Adnan Aziz, 和 Gregory Chanan 提供的详细且富有洞见的审阅意见。