引言
PyTorch 2.0 (PT2) 提供了一种编译执行模式,它重写Python字节码以提取PyTorch操作序列,并将其转换为Graph IR。然后,通过可定制的后端进行即时编译,在不干预用户的情况下提高训练性能。通常,生产模型可能需要经过多个优化/下沉阶段才能达到性能目标。因此,拥有编译模式是可取的,因为它可以将提高模型性能的工作与直接修改PyTorch模型实现的工作分开。因此,编译模式变得更加重要,它使PyTorch用户无需修改PyTorch代码实现即可增强模型性能。此功能对于优化复杂模型,包括大型和生产就绪模型,尤其有价值。
在我们之前的博客文章中,我们概述了如何使用启发式模型转换规则来优化复杂的生产模型。虽然这些规则为一些试点模型带来了显著的性能提升,但它们缺乏普适性;它们在不同模型之间甚至在单个模型的不同部分中表现不一致。

图1:PT1图模式 vs PT2编译模式。
在这篇博客文章中,我们提出了一种更通用的模型转换解决方案,作为PT2编译器的插件,如图1所示,它更通用、性能更好、用户友好,无需手动操作即可为模型训练和推理带来性能改进。如图2所示,通过将以前用户定义的转换合并到编译器中,我们简化了生产堆栈。这些更改为更广泛的PyTorch模型带来了优势,而不仅仅是Meta模型,并且这些更改已包含在PT2中,可供所有PyTorch模型使用。

图2:使用PT2编译模式的简化堆栈。
指导原则:原子规则
传统上,人们可能会使用预定义的启发式规则将模型子图替换为另一个性能更好的子图,以减少启动开销,最小化内存带宽,并充分利用SM。然而,这种方法扩展性不好,因为很难制定一套完美适合所有模型的规则。
与其纠结于笨重复杂的规则,我们实际上可以把它们分解成更小、更容易理解的部分——我们称之为“原子规则”。这些微小的效率利器旨在转换单个操作符,以进行一步融合/转换。这使得它们易于处理和应用,为优化模型提供了直接的途径。因此,有了这些原子规则,优化任何模型以获得顶级性能变得轻而易举!
我们将通过一些简单的例子来演示我们如何使用一系列原子规则来取代复杂的启发式规则。
案例1:以访问嵌入表开始的计算链的水平融合
水平融合意味着将并行操作符融合为一个,以减少要启动的内核数量并提高性能。在我们之前的博客(第3.2节)中,我们描述了在嵌入包之后融合layernorm和激活函数的模型转换,如图所示。然而,这种方法有局限性:
- 它只适用于嵌入后的layernorm和激活函数。
- 它仅限于具有特定架构规则的模型,导致我们生产堆栈中的各种问题,包括参数更改和推理中断。
为了改进,我们可以使用图3所示的三个原子规则来替换复杂的启发式规则:
- 水平融合跟随相同分割节点的layernorm。
- 然后,水平融合跟随相同分割节点的tanh函数。
- 最后,融合垂直分割-连接节点。
这些原子规则为模型简化和优化提供了一种简洁流畅的方式。

图3:之前,我们通过替换子图一次性优化模型。现在,使用原子规则,我们逐步优化,涵盖了更多情况。
案例2:融合水平MLP
MLP(多层感知器)是深度神经网络的基本组成部分,通常由线性、归一化和激活函数组成。在复杂模型中,通常需要融合许多水平MLP。传统方法找到并用融合模块替换并行MLP,如图4所示,但这并非总是直接的。有些模型可能没有归一化,或者它们可能使用不同的激活函数,这使得应用一刀切的规则变得困难。
这就是我们的原子规则派上用场的地方。这些简化的规则一次只针对单个操作符,使过程更容易、更易于管理。我们使用以下原子规则进行水平MLP融合:
- 融合水平线性操作符
- 融合水平layernorm。
- 融合水平激活函数。

图4:融合MLP的伪代码。传统优化需要手动Python代码更改。
这些规则的妙处在于它们不局限于一种情况。它们可以广泛应用。由于PyTorch模型是用torch操作符构建的,因此关注更少的操作符可以简化过程。与编写特定的、大型的模式替换规则相比,这种方法不仅更易于管理,而且更具通用性,从而更容易高效地优化各种模型。
编译时图搜索
我们的原则是使用链式原子规则来取代启发式规则。虽然这种方法涵盖了更广泛的情况,但它确实需要更长的图搜索和模式匹配时间。下一个问题是:如何在高效执行编译时图搜索的同时最大限度地缩短编译时间?
我们设计了一个两步贪心算法,如图5所示。此过程的第一步是识别目标节点,我们遵循某些规则,例如,识别所有具有相同输入形状的线性操作。一旦识别出,我们使用广度优先搜索(BFS)策略将这些节点分成不同的集合,以便集合内的节点没有数据依赖关系。这些集合中的节点是独立的,可以水平融合。

图5:使用图IR进行模型转换的过程。
采用我们的方法,对于我们最大的内部模型之一,搜索时间大约为60秒,这对于即时任务来说是可以接受的。
总结
在我们对内部排名模型进行的测试中,我们观察到在torch.compile带来的性能提升之上,五个模型在训练性能方面有大约5%到15%的改进。我们已经在PT2编译器堆栈中启用了此优化,并将其默认启用,当用户选择Inductor作为后端时(配置)。我们期望我们的通用转换方法能够惠及Meta之外的模型,并期待通过此编译器级转换框架进行更多讨论和改进。
致谢
非常感谢Mark Saroufim、Gregory Chanan、Adnan Aziz和Rocky Liu的详细且富有洞察力的评论。