作者:Jade Nie, CK Luk, Xiaodong Wang, Jackie (Jiaqi) Xu

1. 引言

PyTorch 支持两种执行模式 [1]:即时模式(eager mode)和图模式(graph mode)。在即时模式下,模型中的算子在遇到时会立即执行。相比之下,在图模式下,算子首先被合成为图,然后作为一个整体进行编译和执行。即时模式更易于使用,更适合机器学习研究人员,因此是默认的执行模式。另一方面,图模式通常能提供更高的性能,因此在生产环境中被广泛使用。

具体来说,图模式支持算子融合(operator fusion)[2],即将一个算子与另一个算子合并,以减少/局部化内存读取以及总体的核函数启动开销。融合可以是水平的——将一个独立应用于多个操作数的单一操作(例如 BatchNorm)及其操作数合并到数组中;融合也可以是垂直的——将一个核函数与消耗第一个核函数输出的另一个核函数合并(例如卷积后接 ReLU)。

Torch.FX [3, 4](简称 FX)是 PyTorch 包中一个公开可用的工具包,支持图模式执行。具体来说,它 (1) 从 PyTorch 程序中捕获图,并且 (2) 允许开发者在捕获的图上编写转换。它在 Meta 内部被用于优化生产模型的训练吞吐量。通过介绍 Meta 开发的一些基于 FX 的优化,我们展示了使用图转换来优化 PyTorch 性能以用于生产环境的方法。

2. 背景

嵌入表(Embedding tables)在推荐系统中无处不在。第 3 节将讨论三种优化嵌入表访问的 FX 转换。在本节中,我们将提供关于 FX(第 2.1 节)和嵌入表(第 2.2 节)的一些背景知识。

2.1 FX

图 1 是一个改编自 [3] 的简单示例,说明了如何使用 FX 转换 PyTorch 程序。它包含三个步骤:(1) 从程序中捕获图,(2) 修改图(在此示例中,所有使用 RELU 的地方都被 GELU 替换),以及 (3) 根据修改后的图生成新程序。

图 1:一个 FX 示例,演示了如何在 PyTorch 模块中用 GELU 替换所有使用 RELU 的地方。

FX API [4] 提供了更多用于检查和转换 PyTorch 程序图的功能。

2.2 嵌入表

图 2:批量大小为 1 的稀疏特征的嵌入表示例

在推荐系统中,稀疏特征(例如用户 ID、故事 ID)通过嵌入表表示。嵌入表 E 是一个 HxD 的矩阵,其中 H 是哈希大小,D 是嵌入维度。E 的每一行都是一个浮点向量。特征哈希(Feature hashing)[5] 用于将稀疏特征映射到 E 的索引列表,例如 [S1,S2, …, Sk],其中 0<=Si<H。其输出值计算为 f(E[S1], E[S2], …, E[Sk]),其中 E[Si] 是第 Si 行的向量,f 称为池化函数(pooling function),通常是以下函数之一:求和、平均、最大值。参见图 2 了解示例。

为了充分利用 GPU,稀疏特征通常以批量(batch)形式处理。批量中的每个实体都有自己的索引列表。如果一个批量有 B 个实体,朴素的表示方式是 B 个索引列表。更紧凑的表示方式是将 B 个索引列表合并成一个索引列表,并添加一个索引长度列表(批量中每个实体对应一个长度)。例如,如果一个批量有 3 个实体,其索引列表如下所示:

  • 实体 1:索引 = [10, 20]
  • 实体 2:索引 = [5, 9, 77, 81]
  • 实体 3:索引 = [15, 20, 45]

那么整个批量的索引和长度将是:

  • 索引 = [10, 20, 5, 9, 77, 81, 15, 20, 45]
  • 长度 = [2, 4, 3]

整个批量进行嵌入表查找的输出是一个 BxD 矩阵。

3. 三种 FX 转换

我们开发了三种 FX 转换,可加速对嵌入表的访问。第 3.1 节讨论了一种将多个小型输入张量组合成单个大型张量的转换;第 3.2 节讨论了一种将多个并行计算链融合成单个计算链的转换;第 3.3 节讨论了一种重叠通信和计算的转换。

3.1 组合输入稀疏特征

回想一下,批量中的输入稀疏特征由两个列表表示:一个索引列表和一个包含 B 个长度的列表,其中 B 是批量大小。在 PyTorch 中,这两个列表作为两个张量实现。当 PyTorch 模型在 GPU 上运行时,嵌入表通常存储在 GPU 内存中(它更靠近 GPU,读写带宽远高于 CPU 内存)。要使用输入稀疏特征,其两个张量需要先从 CPU 复制到 GPU。然而,每次主机到设备的内存复制都需要启动一个核函数,这相对于实际的数据传输时间来说是相对昂贵的。如果一个模型使用许多输入稀疏特征,这种复制可能会成为性能瓶颈(例如,1000 个输入稀疏特征将需要从主机复制 2000 个张量到设备)。

一种减少主机到设备 memcpy 次数的优化方法是在将多个输入稀疏特征发送到设备之前将它们组合起来。例如,给定以下三个输入特征:

  • 特征 A:索引 = [106, 211, 7],长度 = [2, 1]
  • 特征 B:索引 = [52, 498, 616, 870, 1013],长度 = [3, 2]
  • 特征 C:索引 = [2011, 19, 351, 790],长度 = [1, 3]

组合后的形式是:

  • 特征 A_B_C:索引 = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790],长度 = [2, 1, 3, 2, 1, 3]

因此,我们只需要复制 2 个张量,而不是从主机复制 3x2=6 个张量到设备。

图 3(b) 描述了此优化的实现,它包含两个组成部分:

  • 在 CPU 端:输入管道被修改,将所有稀疏特征的索引组合成一个张量,类似地将所有长度组合成另一个张量。然后将这两个张量复制到 GPU。
  • 在 GPU 端:使用 FX,我们在模型图中插入一个 Permute_and_Split 算子,从组合的张量中恢复单个特征的索引和长度张量,并将它们路由到下游对应的节点。

(a). 未经优化

(b). 经过优化

图 3:组合输入稀疏特征

3.2 对始于嵌入表访问的计算链进行水平融合

在生产模型中,每个 GPU 上驻留数十个嵌入表是相当常见的。出于性能考虑,对这些表的查找操作被分组在一起,以便它们的输出被连接成一个大的张量(参见图 4(a) 中的红色部分)。为了对单个特征输出应用计算,使用 Split 算子将大张量分成 N 个较小的张量(其中 N 是特征数量),然后对每个张量应用所需的计算。如图 4(a) 所示,对每个特征输出 O 应用的计算是 Tanh(LayerNorm(O))。所有计算结果被连接回一个大张量,然后传递给下游算子(图 4(a) 中的 Op1)。

这里主要的运行时开销是 GPU 核函数启动开销。例如,图 4(a) 中 GPU 核函数启动次数为 2*N + 3(图中的每个椭圆代表一个 GPU 核函数)。这可能成为一个性能问题,因为 LayerNorm 和 Tanh 在 GPU 上的执行时间相对于其核函数启动时间来说较短。此外,Split 算子可能会创建嵌入输出张量的额外副本,消耗额外的 GPU 内存。

我们使用 FX 实现了一种称为水平融合的优化,它显著减少了 GPU 核函数启动次数(在此示例中,优化后的 GPU 核函数启动次数为 5,参见图 4(b))。我们不进行显式的 Split 操作,而是使用 Add_middle_dim 算子将形状为 (B, NxD) 的二维嵌入张量重塑(reshape)为形状为 (B, N, D) 的三维张量。然后对它的最后一个维度应用单个 LayerNorm 操作。接着对 LayerNorm 的结果应用单个 Tanh 操作。最后,我们使用 Remove_middle_dim 算子将 Tanh 的结果重塑回二维张量。此外,由于 Add_middle_dim 和 Remove_middle_dim 只重塑张量而不创建额外副本,GPU 内存消耗量也可能降低。

(a). 未经优化

(b). 经过优化

图 4:水平融合

3.3 重叠计算与通信

生产推荐模型的训练通常在分布式 GPU 系统上完成。由于每个 GPU 的设备内存容量不足以容纳模型中的所有嵌入表,因此需要将它们分布在各个 GPU 上。

在训练步骤中,一个 GPU 需要从/向其他 GPU 上的嵌入表读取/写入特征值。这被称为 all-to-all 通信 [6],并且可能是一个主要的性能瓶颈。

我们使用 FX 实现了一种可以将计算与 all-to-all 通信重叠的转换。图 5(a) 显示了一个模型图示例,其中包含嵌入表访问(EmbeddingAllToAll)和其他算子。如 图 5(b) 所示,未经任何优化时,它们在 GPU 流上顺序执行。使用 FX,我们将 EmbeddingAllToAll 分解为 EmbeddingAllToAll_Request 和 EmbeddingAllToAll_Wait,并在它们之间调度独立的算子。

(a) 模型图

(b) 原始执行顺序

(c)优化后的执行顺序

图 5:重叠计算与通信

3.4 总结

表 1 总结了本节讨论的优化及其解决的相应性能瓶颈。

优化 解决的性能瓶颈
组合输入稀疏特征 主机到设备内存复制
水平融合 GPU 核函数启动开销
重叠计算与通信 嵌入 all-to-all 访问时间

表 1:优化及其解决的性能瓶颈总结

由于篇幅限制,我们还开发了其他 FX 转换,但未在本节中讨论。

为了发现哪些模型可以从这些转换中受益,我们分析了 MAIProf [7] 从 Meta 数据中心运行的模型收集的性能数据。总的来说,这些转换在一定数量的生产模型上,与即时模式相比,提供了高达 2-3 倍的加速。

4. 结论

出于性能原因,在生产环境中更倾向于使用 PyTorch 的图模式而非即时模式。FX 是一个强大的工具,用于捕获和优化 PyTorch 程序的图。我们展示了在 Meta 内部用于优化生产推荐模型的三种 FX 转换。我们希望这篇博客能够激励其他 PyTorch 模型开发者使用图转换来提升其模型的性能。

参考文献

[1] 端到端机器学习框架

[2] DNNFusion:通过高级算子融合加速深度神经网络执行

[3] Torch.FX:适用于 Python 深度学习的实用程序捕获和转换, MLSys 2022.

[4] Torch.fx—PyTorch 1.12 文档

[5] 用于大规模多任务学习的特征哈希

[6] NVIDIA 集合通信库文档

[7] Meta 生产级 PyTorch 模型的性能调试