博客

通过图转换优化生产 PyTorch 模型的性能

作者: 2022年11月28日2024年11月14日暂无评论

1. 简介

PyTorch 支持两种执行模式 [1]:即时模式(Eager Mode)和图模式(Graph Mode)。在即时模式下,模型中的算子会在遇到时立即执行。相反,在图模式下,算子首先被合成为一个图,然后进行编译并作为一个整体执行。即时模式更易于使用,更适合机器学习研究人员,因此是默认的执行模式。另一方面,图模式通常能提供更高的性能,因此在生产环境中被大量使用。

具体而言,图模式支持算子融合 [2],即将一个算子与另一个算子合并,以减少/本地化内存读取以及总的内核启动开销。融合可以是水平的——获取一个独立应用于多个操作数的单一操作(例如 BatchNorm),并将这些操作数合并为一个数组;也可以是垂直的——将一个内核与消耗该内核输出的另一个内核合并(例如卷积层后接 ReLU)。

Torch.FX [3, 4] (简称 FX) 是 PyTorch 包中公开提供的一个支持图模式执行的工具包。特别是,它 (1) 从 PyTorch 程序中捕获图,以及 (2) 允许开发者对捕获的图编写转换逻辑。它在 Meta 内部被用于优化生产模型的训练吞吐量。通过介绍 Meta 开发的一系列基于 FX 的优化方案,我们展示了使用图转换来优化 PyTorch 生产环境性能的方法。

2. 背景

Embedding 表在推荐系统中无处不在。第 3 节将讨论三种优化 Embedding 表访问的 FX 转换。在本节中,我们将介绍 FX(第 2.1 节)和 Embedding 表(第 2.2 节)的相关背景。

2.1 FX

图 1 是取自 [3] 的一个简单示例,展示了如何使用 FX 转换 PyTorch 程序。它包含三个步骤:(1) 从程序中捕获图,(2) 修改图(在此示例中,将所有 RELU 的使用替换为 GELU),以及 (3) 从修改后的图中生成新程序。

图 1:一个将 PyTorch 模块中所有 RELU 替换为 GELU 的 FX 示例。

FX API [4] 提供了更多用于检查和转换 PyTorch 程序图的功能。

2.2 Embedding 表

图 2:Batch Size = 1 的稀疏特征 Embedding 表示意图

在推荐系统中,稀疏特征(例如用户 ID、故事 ID)由 Embedding 表表示。Embedding 表 E 是一个 HxD 矩阵,其中 H 是哈希大小,D 是 Embedding 维度。E 的每一行都是一个浮点向量。特征哈希 [5] 用于将稀疏特征映射到 E 的索引列表,记为 [S1,S2, …, Sk],其中 0<=Si<H。其输出值计算为 f(E[S1], E[S2], …, E[Sk]),其中 E[Si] 是第 Si 行的向量,f 被称为池化函数,通常是以下函数之一:求和、平均、最大值。示意图请见图 2。

为了充分利用 GPU,稀疏特征通常以 Batch 的形式处理。Batch 中的每个实体都有其自己的索引列表。如果一个 Batch 有 B 个实体,朴素的表示法包含 B 个索引列表。一种更紧凑的表示法是将 B 个索引列表合并为一个索引列表,并添加一个索引长度列表(Batch 中每个实体对应一个长度)。例如,如果一个 Batch 有 3 个实体,其索引列表如下

  • 实体 1:索引 = [10, 20]
  • 实体 2:索引 = [5, 9, 77, 81]
  • 实体 3:索引 = [15, 20, 45]

那么整个 Batch 的索引和长度将是

  • 索引 = [10, 20, 5, 9, 77, 81, 15, 20, 45]
  • 长度 = [2, 4, 3]

整个 Batch 的 Embedding 表查找输出是一个 BxD 的矩阵。

3. 三种 FX 转换

我们开发了三种加速 Embedding 表访问的 FX 转换。第 3.1 节讨论了一种将多个小输入张量合并为单个大张量的转换;第 3.2 节讨论了一种将多个并行计算链融合为单个计算链的转换;第 3.3 节讨论了一种将通信与计算重叠(Overlap)的转换。

3.1 合并输入稀疏特征

回想一下,Batch 中的输入稀疏特征由两个列表表示:索引列表和 B 个长度的列表,其中 B 是 Batch 大小。在 PyTorch 中,这两个列表被实现为两个张量。当 PyTorch 模型在 GPU 上运行时,Embedding 表通常存储在 GPU 内存中(它比 CPU 内存离 GPU 更近,并且拥有高得多的读写带宽)。要使用输入稀疏特征,其两个张量需要先从 CPU 复制到 GPU。然而,每次主机到设备的内存拷贝都需要启动一个内核,与实际的数据传输时间相比,这相对昂贵。如果模型使用许多输入稀疏特征,这种拷贝可能会成为性能瓶颈(例如,1000 个输入稀疏特征将需要从主机到设备拷贝 2000 个张量)。

一种减少主机到设备内存拷贝(memcpy)次数的优化方法是在发送到设备之前合并多个输入稀疏特征。例如,给定以下三个输入特征

  • 特征_A:索引 = [106, 211, 7],长度 = [2, 1]
  • 特征_B:索引 = [52, 498, 616, 870, 1013],长度 = [3, 2]
  • 特征_C:索引 = [2011, 19, 351, 790],长度 = [1, 3]

合并后的形式为

  • 特征_A_B_C:索引 = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790],长度 = [2, 1, 3, 2, 1, 3]

因此,我们不再需要从主机到设备拷贝 3×2=6 个张量,而只需拷贝 2 个张量。

图 3(b) 描述了这种优化的实现,它包含两个部分

  • 在 CPU 端:修改输入流水线,将所有稀疏特征的索引合并为一个张量,类似地,将所有长度合并为另一个张量。然后将这两个张量拷贝到 GPU。
  • 在 GPU 端:使用 FX,我们在模型图中插入一个 Permute_and_Split 算子,从合并后的张量中恢复各个特征的索引和长度张量,并将它们路由到下游对应的节点。

(a). 未优化

(b). 已优化

图 3:合并输入稀疏特征

3.2 以 Embedding 表访问为起点的计算链水平融合

在生产模型中,每个 GPU 上驻留数十个 Embedding 表的情况非常普遍。出于性能考虑,这些表的查找被分组在一起,以便它们的输出被连接成一个单一的大张量(见图 4(a) 中的红色部分)。为了对各个特征输出应用计算,使用 Split 算子将大张量拆分为 N 个较小的张量(N 是特征数量),然后对每个张量应用所需的计算。这在图 4(a) 中显示,其中应用于每个特征输出 O 的计算是 Tanh(LayerNorm(O))。所有计算结果被拼接回一个大张量,然后传递给下游算子(图 4(a) 中的 Op1)。

这里的主要运行时成本是 GPU 内核启动开销。例如,图 4(a) 中的 GPU 内核启动次数为 2*N + 3(图中的每个椭圆都是一个 GPU 内核)。这可能成为性能问题,因为 GPU 上 LayerNorm 和 Tanh 的执行时间相对于它们的内核启动时间很短。此外,Split 算子可能会创建 Embedding 输出张量的额外副本,消耗额外的 GPU 内存。

我们使用 FX 实现了名为“水平融合”的优化,极大地减少了 GPU 内核启动次数(在本例中,优化后的 GPU 内核启动次数为 5,见图 4(b))。我们不进行显式的 Split,而是使用 Add_middle_dim 算子将形状为 (B, NxD) 的二维 Embedding 张量重塑为 (B, N, D) 的三维张量。然后对它的最后一个维度应用单个 LayerNorm。接着对 LayerNorm 的结果应用单个 Tanh。最后,使用 Remove_middle_dim 算子将 Tanh 的结果重塑回二维张量。此外,由于 Add_middle_dim 和 Remove_middle_dim 只是重塑张量而不创建额外副本,因此 GPU 内存消耗也可以降低。

(a). 未优化

(b). 已优化

图 4:水平融合

3.3 将计算与通信重叠

生产推荐模型的训练通常在分布式 GPU 系统上进行。由于每个 GPU 的设备内存容量不足以容纳模型中的所有 Embedding 表,因此它们需要在 GPU 之间进行分布。

在训练步骤中,GPU 需要从其他 GPU 的 Embedding 表读取特征值或向其写入特征值。这被称为 all-to-all 通信 [6],并且可能成为主要的性能瓶颈。

我们使用 FX 实现了一个可以将计算与 all-to-all 通信重叠的转换。图 5(a) 显示了一个具有 Embedding 表访问 (EmbeddingAllToAll) 和其他算子的模型图示例。在没有任何优化的情况下,它们在 GPU 流上按顺序执行,如图 5(b) 所示。使用 FX,我们将 EmbeddingAllToAll 拆分为 EmbeddingAllToAll_Request 和 EmbeddingAllToAll_Wait,并在它们之间调度独立的算子。

(a) 模型图

(b) 原始执行顺序

(c) 优化后的执行顺序

图 5:将计算与通信重叠

3.4 总结

表 1 总结了本节讨论的优化方案及其解决的相应性能瓶颈。

优化所解决的性能瓶颈
合并输入稀疏特征主机到设备的内存拷贝
水平融合GPU 内核启动开销
将计算与通信重叠Embedding all-to-all 访问时间

表 1:优化方案及其解决的性能瓶颈总结

由于篇幅限制,我们还开发了其他未在本节讨论的 FX 转换。

为了发掘哪些模型可以从这些转换中受益,我们分析了 MAIProf [7] 从 Meta 数据中心运行的模型中收集的性能数据。总的来说,在一组生产模型上,这些转换相较于即时模式提供了最高 2-3 倍的加速。

4. 结束语

出于性能考虑,PyTorch 中的图模式在生产环境中优于即时模式。FX 是一个功能强大的工具,用于捕获和优化 PyTorch 程序的图。我们展示了三种用于优化 Meta 内部生产推荐模型的 FX 转换。我们希望这篇博客能够激励其他 PyTorch 模型开发者使用图转换来提升其模型的性能。

参考文献

[1] 端到端机器学习框架

[2] DNNFusion:通过先进算子融合加速深度神经网络执行

[3] Torch.FX:Python 中深度学习的实用程序捕获与转换,MLSys 2022。

[4] Torch.fx—PyTorch 1.12 文档

[5] 大规模多任务学习的特征哈希

[6] NVIDIA 集体通信库文档 (NCCL)

[7] Meta 生产环境 PyTorch 模型的性能调试