引言
PyTorch 2.0 (PT2) 提供了一种编译执行模式,该模式重写 Python 字节码以提取 PyTorch 操作序列,并将其转换为 Graph IR。然后,Graph IR 通过可定制的后端进行即时 (JIT) 编译,从而在无需用户干预的情况下提高训练性能。通常,生产模型可能需要经过多个阶段的优化/降低,以达到性能目标。因此,拥有编译模式是可取的,因为它可以将提高模型性能的工作与直接修改 PyTorch 模型实现的工作分开。因此,编译模式变得更加重要,它使得 PyTorch 用户能够在不修改 PyTorch 代码实现的情况下提升模型性能。此功能对于优化复杂模型,包括大规模和生产就绪的模型,尤其有价值。
在我们之前的博客文章中,我们概述了如何采用启发式模型转换规则来优化复杂的生产模型。虽然这些规则为一些试点模型带来了显著的性能提升,但它们缺乏普遍适应性;它们在不同模型上表现不一致,有时甚至在同一个模型的不同部分表现也不一致。
图 1:PT1 Graph 模式对比 PT2 Compile 模式。
在这篇博客文章中,我们提出了一种更通用的模型转换解决方案,它作为 PT2 编译器的插件(如图 1 所示),具有更强的通用性、更高的性能和更好的用户友好性,可在无需人工干预的情况下提高模型训练和推理性能。如图 2 所示,通过将先前用户定义的转换集成到编译器中,我们简化了生产栈。这些改变为更广泛的 PyTorch 模型带来了优势,不仅仅局限于 Meta 模型,并且这些改进已经集成到 PT2 中,可供所有 PyTorch 模型使用。
图 2:使用 PT2 Compile 模式的简化栈。
指导原则:原子规则
传统上,人们可能会使用预定义的启发式规则来将模型的子图替换为另一个性能更高的子图,以减少启动开销、最小化内存带宽并充分利用 SM(流式多处理器)。然而,这种方法的可扩展性不佳,因为很难设计出一套完美适用于所有模型的规则。
与其苦苦应对庞大复杂的规则,我们不如将其分解成更小、更容易理解的部分——我们称之为“原子规则”。这些微小而高效的规则针对单个运算符的转换,以执行融合/转换的一个步骤。这使得它们易于处理和应用,提供了一条优化模型的直接途径。因此,掌握了这些原子规则,将任何模型优化至顶级性能变得轻而易举!
我们将通过一些简单的示例来演示如何使用原子规则链来替换复杂的启发式规则。
案例 1:以访问嵌入表开始的计算链的横向融合
横向融合意味着将并行运算符融合为一个,从而减少启动的 kernel 数量并提高性能。在我们之前的博客(第 3.2 节)中,我们描述了在 embedding bags 之后融合 layernorm 和激活函数的模型转换,如图所示。然而,这种方法存在局限性
- 它仅适用于 embedding 之后的 layernorm 和激活函数。
- 它仅限于具有特定架构规则的模型,导致我们的生产栈出现各种问题,包括参数更改和推理中断。
为了改进,我们可以使用如图 3 所示的三条原子规则来替换复杂的启发式规则
- 横向融合跟随相同 split 节点的 layernorm。
- 然后,横向融合跟随相同 split 节点的 tanh 函数。
- 最后,融合垂直的 split-cat 节点。
这些原子规则提供了一种清晰简便的模型简化和优化方法。
图 3:之前,我们通过替换子图一步到位地优化模型。现在,有了原子规则,我们可以分步优化,覆盖更多情况。
案例 2:融合横向 MLP
MLP(多层感知器)是深度神经网络的基本组成部分,通常由线性、归一化和激活函数组成。在复杂模型中,常常需要融合多个横向 MLP。传统方法如图 4 所示,通过找到并用融合模块替换并行 MLP,但这并非总是直接的。有些模型可能没有归一化,或者使用不同的激活函数,这使得应用通用的规则变得困难。
这就是我们的原子规则发挥作用的地方。这些简化的规则一次只针对一个运算符,使过程更简单、更易于管理。我们使用以下原子规则进行横向 MLP 融合
- 融合横向线性运算符
- 融合横向 layernorm。
- 融合横向激活函数。
图 4:融合 MLP 的伪代码。传统优化需要手动修改 Python 代码。
这些规则的妙处在于它们不局限于某一种情况,可以广泛应用。由于 PyTorch 模型是使用 torch 运算符构建的,所以专注于更小的一组运算符可以简化流程。与编写特定的、大型模式替换规则相比,这种方法不仅更易于管理,而且更具通用性,使得能够更有效地优化各种模型。
编译时图搜索
我们的原则是使用链式原子规则来替换启发式规则。虽然这种方法覆盖了更广泛的情况,但它确实需要更长的图搜索和模式匹配时间。接下来的问题是:如何在高效执行编译时图搜索的同时最大限度地减少编译时间?
我们设计了一种两步贪婪算法,如图 5 所示。此过程的第一步是识别目标节点,我们遵循某些规则进行识别,例如,识别所有具有相同输入形状的线性操作。识别后,我们使用广度优先搜索 (BFS) 策略将这些节点分成不同的集合,以便集合内的节点没有数据依赖关系。每个集合内的节点都是独立的,可以横向融合。
图 5:使用 Graph IR 的模型转换过程。
使用我们的方法,对于我们最大的内部模型之一,搜索时间大约为 60 秒,这对于即时任务来说是可以接受的。
总结
在我们对内部排序模型进行的测试中,我们观察到在 torch.compile 带来的性能提升基础上,五个模型的训练性能平均提高了 5% 到 15%。我们已在 PT2 编译器栈中启用了此优化,并在用户选择 Inductor 作为后端时将其设置为默认选项 (配置)。我们期望我们通用的转换方法能够使 Meta 模型以外的模型受益,并期待通过这个编译器级别的转换框架进行更多讨论和改进。
致谢
非常感谢 Mark Saroufim、Gregory Chanan、Adnan Aziz 和 Rocky Liu 提供的详细且富有洞察力的评审意见。