跳转到主要内容
博客

通过图转换优化生产 PyTorch 模型的性能

作者: 2022年11月28日2024年11月14日暂无评论

1. 引言

PyTorch支持两种执行模式[1]:即时模式(eager mode)和图模式(graph mode)。在即时模式下,模型中的运算符在遇到时立即执行。相反,在图模式下,运算符首先被合成为一个图,然后作为一个整体进行编译和执行。即时模式更易于使用,更适合机器学习研究人员,因此是默认的执行模式。另一方面,图模式通常提供更高的性能,因此在生产环境中被大量使用。

具体来说,图模式支持算子融合[2],其中一个算子与另一个算子合并,以减少/局部化内存读取以及总内核启动开销。融合可以是水平的——将单个操作(例如BatchNorm)独立应用于多个操作数,并将这些操作数合并到一个数组中;也可以是垂直的——将一个内核与消耗第一个内核输出的另一个内核合并(例如,卷积后跟ReLU)。

Torch.FX[3, 4](简称FX)是PyTorch软件包中公开可用的工具包,支持图模式执行。它特别具备以下功能:(1)从PyTorch程序中捕获图,以及(2)允许开发人员对捕获的图进行转换。它在Meta内部用于优化生产模型的训练吞吐量。通过引入Meta开发的一些基于FX的优化,我们演示了使用图转换来优化PyTorch生产性能的方法。

2. 背景

嵌入表在推荐系统中无处不在。第3节将讨论三个优化嵌入表访问的FX转换。本节将提供关于FX(第2.1节)和嵌入表(第2.2节)的一些背景信息。

2.1 FX

图1是一个简单的例子,来源于[3],它演示了如何使用FX转换PyTorch程序。它包含三个步骤:(1)从程序中捕获图,(2)修改图(在这个例子中,所有使用RELU的地方都被GELU替换),以及(3)从修改后的图中生成一个新的程序。

图1:一个FX示例,将PyTorch模块中所有RELU的使用替换为GELU。

FX API[4]提供了更多用于检查和转换PyTorch程序图的功能。

2.2 嵌入表

图2:稀疏特征的嵌入表示例,批大小为1

在推荐系统中,稀疏特征(例如,用户ID、故事ID)通过嵌入表表示。嵌入表E是一个HxD矩阵,其中H是哈希大小,D是嵌入维度。E的每一行是一个浮点向量。特征哈希[5]用于将稀疏特征映射到E的索引列表,例如[S1,S2, …, Sk],其中0<=Si1], E[S2], …, E[Sk]),其中E[Si]是第Si行的向量,f称为池化函数,通常是以下函数之一:求和(sum)、平均(average)、最大值(maximum)。参见图2的示例。

为了充分利用GPU,稀疏特征通常以批次方式处理。批次中的每个实体都有自己的索引列表。如果一个批次有B个实体,朴素的表示方式是B个索引列表。一种更紧凑的表示方式是将B个索引列表合并成一个单一的索引列表,并添加一个索引长度列表(批次中每个实体一个长度)。例如,如果一个批次有3个实体,它们的索引列表如下:

  • 实体1:索引 = [10, 20]
  • 实体2:索引 = [5, 9, 77, 81]
  • 实体3:索引 = [15, 20, 45]

那么整个批次的索引和长度将是

  • 索引 = [10, 20, 5, 9, 77, 81, 15, 20, 45]
  • 长度 = [2, 4, 3]

整个批次的嵌入表查询输出将是一个BxD矩阵。

3. 三种FX转换

我们开发了三种FX转换来加速嵌入表的访问。第3.1节讨论了一种将多个小型输入张量合并为一个大型张量的转换;第3.2节讨论了一种将多个并行的计算链融合为一个计算链的转换;第3.3节讨论了一种重叠通信与计算的转换。

3.1 组合输入稀疏特征

回想一下,批处理中的输入稀疏特征由两个列表表示:一个索引列表和一个长度为B的列表,其中B是批处理大小。在PyTorch中,这两个列表被实现为两个张量。当PyTorch模型在GPU上运行时,嵌入表通常存储在GPU内存中(这更接近GPU,并且比CPU内存具有更高的读/写带宽)。要使用输入稀疏特征,它的两个张量需要首先从CPU复制到GPU。然而,每次主机到设备内存复制都需要启动一个内核,这相对于实际数据传输时间而言是相对昂贵的。如果一个模型使用许多输入稀疏特征,这种复制可能会成为性能瓶颈(例如,1000个输入稀疏特征需要从主机复制2000个张量到设备)。

一种减少主机到设备内存复制(memcpy)次数的优化方法是在将多个输入稀疏特征发送到设备之前将其组合。例如,给定以下三个输入特征:

  • 特征_A:索引 = [106, 211, 7],长度 = [2, 1]
  • 特征_B:索引 = [52, 498, 616, 870, 1013],长度 = [3, 2]
  • 特征_C:索引 = [2011, 19, 351, 790],长度 = [1, 3]

组合形式为

  • 特征_A_B_C:索引 = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790],长度 = [2, 1, 3, 2, 1, 3]

因此,我们只需要复制2个张量,而不是从主机复制3×2=6个张量。

图3(b)描述了这种优化的实现,它包含两个组件:

  • 在CPU端:输入流水线被修改,将稀疏特征的所有索引合并到一个张量中,同样地,将所有长度合并到另一个张量中。然后这两个张量被复制到GPU。
  • 在GPU端:使用FX,我们在模型图中插入一个Permute_and_Split操作,从组合张量中恢复单个特征的索引和长度张量,并将它们路由到相应的下游节点。

(a). 未经优化

(b). 经过优化

图3:组合输入稀疏特征

3.2 嵌入表访问开始的计算链的横向融合

在生产模型中,每个GPU上拥有数十个嵌入表是相当常见的。出于性能考虑,对这些表的查找被分组在一起,以便它们的输出被连接成一个大的张量(参见图4(a)中的红色部分)。为了对单个特征输出进行计算,使用Split操作将大张量分割成N个较小的张量(其中N是特征的数量),然后对每个张量应用所需的计算。这在图4(a)中显示,其中应用于每个特征输出O的计算是Tanh(LayerNorm(O))。所有计算结果被连接回一个大张量,然后传递给下游操作(图4(a)中的Op1)。

这里主要的运行时开销是GPU内核启动开销。例如,图4(a)中GPU内核启动的数量是2*N + 3(图中每个椭圆形代表一个GPU内核)。这可能会成为一个性能问题,因为LayerNorm和Tanh在GPU上的执行时间相对于其内核启动时间较短。此外,Split操作可能会创建嵌入输出张量的额外副本,消耗额外的GPU内存。

我们使用FX实现了一种称为水平融合的优化,它显著减少了GPU内核启动的数量(在本例中,优化后的GPU内核启动数量为5,参见图4(b))。我们不进行显式Split,而是使用Add_middle_dim操作将形状为(B, NxD)的2D嵌入张量重塑为形状为(B, N, D)的3D张量。然后对它的最后一个维度应用单个LayerNorm。然后对LayerNorm的结果应用单个Tanh。最后,我们使用Remove_middle_dim操作将Tanh的结果重塑回2D张量。此外,由于Add_middle_dim和Remove_middle_dim只重塑张量而不创建额外副本,因此GPU内存消耗也可以减少。

(a). 未经优化

(b). 经过优化

图4:水平融合

3.3 计算与通信重叠

生产推荐模型的训练通常在分布式GPU系统上进行。由于每个GPU的设备内存容量不足以容纳模型中的所有嵌入表,因此需要将它们分布在不同的GPU之间。

在一个训练步骤中,GPU需要从其他GPU上的嵌入表读写特征值。这被称为all-to-all通信[6],并且可能是一个主要的性能瓶颈。

我们使用FX实现了一种可以使计算与全对全(all-to-all)通信重叠的转换。图5(a)展示了一个模型图的例子,其中包含嵌入表访问(EmbeddingAllToAll)和其他操作。在没有任何优化的情况下,它们在GPU流上按顺序执行,如图5(b)所示。使用FX,我们将EmbeddingAllToAll分解为EmbeddingAllToAll_Request和EmbeddingAllToAll_Wait,并在它们之间调度独立的op。

(a) 模型图

(b) 原始执行顺序

(c) 优化后的执行顺序

图5:计算与通信重叠

3.4 总结

表1总结了本节讨论的优化及其解决的相应性能瓶颈。

优化解决的性能瓶颈
组合输入稀疏特征主机到设备内存复制
水平融合GPU内核启动开销
计算与通信重叠嵌入全对全访问时间

表1:优化及其解决的性能瓶颈摘要

由于篇幅限制,我们还开发了其他FX转换,但未在本节中讨论。

为了发现哪些模型将从这些转换中受益,我们分析了MAIProf[7]从Meta数据中心运行的模型中收集的性能数据。总而言之,这些转换在一些生产模型上相比即时模式提供了高达2-3倍的加速。

4. 结束语

出于性能原因,在生产环境中,PyTorch的图模式优于即时模式。FX是一个强大的工具,用于捕获和优化PyTorch程序的图。我们展示了在Meta内部用于优化生产推荐模型的三个FX转换。我们希望这篇博客能激励其他PyTorch模型开发者使用图转换来提升其模型的性能。

参考文献

[1] 端到端机器学习框架

[2] DNNFusion: 利用高级算子融合加速深度神经网络执行

[3] Torch.FX: Python中深度学习的实用程序捕获与转换, MLSys 2022。

[4] Torch.fx—PyTorch 1.12 文档

[5] 大规模多任务学习的特征哈希

[6] NVIDIA 集体通信库文档

[7] Meta生产环境PyTorch模型的性能调试