1. 引言
PyTorch支持两种执行模式[1]:即时模式(eager mode)和图模式(graph mode)。在即时模式下,模型中的运算符在遇到时立即执行。相反,在图模式下,运算符首先被合成为一个图,然后作为一个整体进行编译和执行。即时模式更易于使用,更适合机器学习研究人员,因此是默认的执行模式。另一方面,图模式通常提供更高的性能,因此在生产环境中被大量使用。
具体来说,图模式支持算子融合[2],其中一个算子与另一个算子合并,以减少/局部化内存读取以及总内核启动开销。融合可以是水平的——将单个操作(例如BatchNorm)独立应用于多个操作数,并将这些操作数合并到一个数组中;也可以是垂直的——将一个内核与消耗第一个内核输出的另一个内核合并(例如,卷积后跟ReLU)。
Torch.FX[3, 4](简称FX)是PyTorch软件包中公开可用的工具包,支持图模式执行。它特别具备以下功能:(1)从PyTorch程序中捕获图,以及(2)允许开发人员对捕获的图进行转换。它在Meta内部用于优化生产模型的训练吞吐量。通过引入Meta开发的一些基于FX的优化,我们演示了使用图转换来优化PyTorch生产性能的方法。
2. 背景
嵌入表在推荐系统中无处不在。第3节将讨论三个优化嵌入表访问的FX转换。本节将提供关于FX(第2.1节)和嵌入表(第2.2节)的一些背景信息。
2.1 FX
图1是一个简单的例子,来源于[3],它演示了如何使用FX转换PyTorch程序。它包含三个步骤:(1)从程序中捕获图,(2)修改图(在这个例子中,所有使用RELU的地方都被GELU替换),以及(3)从修改后的图中生成一个新的程序。

图1:一个FX示例,将PyTorch模块中所有RELU的使用替换为GELU。
FX API[4]提供了更多用于检查和转换PyTorch程序图的功能。
2.2 嵌入表

图2:稀疏特征的嵌入表示例,批大小为1
在推荐系统中,稀疏特征(例如,用户ID、故事ID)通过嵌入表表示。嵌入表E是一个HxD矩阵,其中H是哈希大小,D是嵌入维度。E的每一行是一个浮点向量。特征哈希[5]用于将稀疏特征映射到E的索引列表,例如[S1,S2, …, Sk],其中0<=Si
为了充分利用GPU,稀疏特征通常以批次方式处理。批次中的每个实体都有自己的索引列表。如果一个批次有B个实体,朴素的表示方式是B个索引列表。一种更紧凑的表示方式是将B个索引列表合并成一个单一的索引列表,并添加一个索引长度列表(批次中每个实体一个长度)。例如,如果一个批次有3个实体,它们的索引列表如下:
- 实体1:索引 = [10, 20]
- 实体2:索引 = [5, 9, 77, 81]
- 实体3:索引 = [15, 20, 45]
那么整个批次的索引和长度将是
- 索引 = [10, 20, 5, 9, 77, 81, 15, 20, 45]
- 长度 = [2, 4, 3]
整个批次的嵌入表查询输出将是一个BxD矩阵。
3. 三种FX转换
我们开发了三种FX转换来加速嵌入表的访问。第3.1节讨论了一种将多个小型输入张量合并为一个大型张量的转换;第3.2节讨论了一种将多个并行的计算链融合为一个计算链的转换;第3.3节讨论了一种重叠通信与计算的转换。
3.1 组合输入稀疏特征
回想一下,批处理中的输入稀疏特征由两个列表表示:一个索引列表和一个长度为B的列表,其中B是批处理大小。在PyTorch中,这两个列表被实现为两个张量。当PyTorch模型在GPU上运行时,嵌入表通常存储在GPU内存中(这更接近GPU,并且比CPU内存具有更高的读/写带宽)。要使用输入稀疏特征,它的两个张量需要首先从CPU复制到GPU。然而,每次主机到设备内存复制都需要启动一个内核,这相对于实际数据传输时间而言是相对昂贵的。如果一个模型使用许多输入稀疏特征,这种复制可能会成为性能瓶颈(例如,1000个输入稀疏特征需要从主机复制2000个张量到设备)。
一种减少主机到设备内存复制(memcpy)次数的优化方法是在将多个输入稀疏特征发送到设备之前将其组合。例如,给定以下三个输入特征:
- 特征_A:索引 = [106, 211, 7],长度 = [2, 1]
- 特征_B:索引 = [52, 498, 616, 870, 1013],长度 = [3, 2]
- 特征_C:索引 = [2011, 19, 351, 790],长度 = [1, 3]
组合形式为
- 特征_A_B_C:索引 = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790],长度 = [2, 1, 3, 2, 1, 3]
因此,我们只需要复制2个张量,而不是从主机复制3×2=6个张量。
图3(b)描述了这种优化的实现,它包含两个组件:
- 在CPU端:输入流水线被修改,将稀疏特征的所有索引合并到一个张量中,同样地,将所有长度合并到另一个张量中。然后这两个张量被复制到GPU。
- 在GPU端:使用FX,我们在模型图中插入一个Permute_and_Split操作,从组合张量中恢复单个特征的索引和长度张量,并将它们路由到相应的下游节点。

(a). 未经优化

(b). 经过优化
图3:组合输入稀疏特征
3.2 嵌入表访问开始的计算链的横向融合
在生产模型中,每个GPU上拥有数十个嵌入表是相当常见的。出于性能考虑,对这些表的查找被分组在一起,以便它们的输出被连接成一个大的张量(参见图4(a)中的红色部分)。为了对单个特征输出进行计算,使用Split操作将大张量分割成N个较小的张量(其中N是特征的数量),然后对每个张量应用所需的计算。这在图4(a)中显示,其中应用于每个特征输出O的计算是Tanh(LayerNorm(O))。所有计算结果被连接回一个大张量,然后传递给下游操作(图4(a)中的Op1)。
这里主要的运行时开销是GPU内核启动开销。例如,图4(a)中GPU内核启动的数量是2*N + 3(图中每个椭圆形代表一个GPU内核)。这可能会成为一个性能问题,因为LayerNorm和Tanh在GPU上的执行时间相对于其内核启动时间较短。此外,Split操作可能会创建嵌入输出张量的额外副本,消耗额外的GPU内存。
我们使用FX实现了一种称为水平融合的优化,它显著减少了GPU内核启动的数量(在本例中,优化后的GPU内核启动数量为5,参见图4(b))。我们不进行显式Split,而是使用Add_middle_dim操作将形状为(B, NxD)的2D嵌入张量重塑为形状为(B, N, D)的3D张量。然后对它的最后一个维度应用单个LayerNorm。然后对LayerNorm的结果应用单个Tanh。最后,我们使用Remove_middle_dim操作将Tanh的结果重塑回2D张量。此外,由于Add_middle_dim和Remove_middle_dim只重塑张量而不创建额外副本,因此GPU内存消耗也可以减少。

(a). 未经优化

(b). 经过优化
图4:水平融合
3.3 计算与通信重叠
生产推荐模型的训练通常在分布式GPU系统上进行。由于每个GPU的设备内存容量不足以容纳模型中的所有嵌入表,因此需要将它们分布在不同的GPU之间。
在一个训练步骤中,GPU需要从其他GPU上的嵌入表读写特征值。这被称为all-to-all通信[6],并且可能是一个主要的性能瓶颈。
我们使用FX实现了一种可以使计算与全对全(all-to-all)通信重叠的转换。图5(a)展示了一个模型图的例子,其中包含嵌入表访问(EmbeddingAllToAll)和其他操作。在没有任何优化的情况下,它们在GPU流上按顺序执行,如图5(b)所示。使用FX,我们将EmbeddingAllToAll分解为EmbeddingAllToAll_Request和EmbeddingAllToAll_Wait,并在它们之间调度独立的op。

(a) 模型图

(b) 原始执行顺序

(c) 优化后的执行顺序
图5:计算与通信重叠
3.4 总结
表1总结了本节讨论的优化及其解决的相应性能瓶颈。
优化 | 解决的性能瓶颈 |
组合输入稀疏特征 | 主机到设备内存复制 |
水平融合 | GPU内核启动开销 |
计算与通信重叠 | 嵌入全对全访问时间 |
表1:优化及其解决的性能瓶颈摘要
由于篇幅限制,我们还开发了其他FX转换,但未在本节中讨论。
为了发现哪些模型将从这些转换中受益,我们分析了MAIProf[7]从Meta数据中心运行的模型中收集的性能数据。总而言之,这些转换在一些生产模型上相比即时模式提供了高达2-3倍的加速。
4. 结束语
出于性能原因,在生产环境中,PyTorch的图模式优于即时模式。FX是一个强大的工具,用于捕获和优化PyTorch程序的图。我们展示了在Meta内部用于优化生产推荐模型的三个FX转换。我们希望这篇博客能激励其他PyTorch模型开发者使用图转换来提升其模型的性能。
参考文献
[1] 端到端机器学习框架
[2] DNNFusion: 利用高级算子融合加速深度神经网络执行
[3] Torch.FX: Python中深度学习的实用程序捕获与转换, MLSys 2022。
[5] 大规模多任务学习的特征哈希
[6] NVIDIA 集体通信库文档