概要:PyTorch 2.0 nightly 通过使用新的 torch.compile()
编译器以及集成到 PyTorch 2 中的优化多头注意力实现,为生成式扩散模型提供了开箱即用的性能提升。
引言
生成式 AI 近期取得的很大一部分进展来自于去噪扩散模型,这些模型可以根据文本提示生成高质量的图像和视频。这一系列模型包括 Imagen、DALLE、Latent Diffusion 等。然而,这一系列的所有模型都有一个共同的缺点:由于生成图像的采样过程具有迭代性质,生成速度相当慢。因此,优化采样循环中运行的代码非常重要。
我们以一个流行的文本到图像扩散模型的开源实现为起点,并使用 PyTorch 2 中提供的两项优化来加速其生成:编译和快速注意力实现。结合代码中一些细微的内存处理改进,这些优化相对于不使用 xFormers 的原始实现可提供高达 49% 的推理速度提升,相对于使用 xFormers 的原始代码可提供 39% 的推理速度提升(不包括编译时间),具体取决于 GPU 架构和批量大小。重要的是,这些速度提升无需安装 xFormers 或任何其他额外依赖项即可获得。
下表显示了安装了 xFormers 的原始实现与我们使用 PyTorch 集成内存高效注意力(最初为 xFormers 库开发并发布)和 PyTorch 编译的优化版本之间的运行时改进。不包括编译时间。
与 original+xFormers 相比的运行时改进 (%)
绝对运行时数字请参见“基准测试设置和结果总结”部分
GPU | 批量大小 1 | 批量大小 2 | 批量大小 4 |
P100 (无编译) | -3.8 | 0.44 | 5.47 |
T4 | 2.12 | 10.51 | 14.2 |
A10 | -2.34 | 8.99 | 10.57 |
V100 | 18.63 | 6.39 | 10.43 |
A100 | 38.5 | 20.33 | 12.17 |
可以注意到以下几点
- 对于 A100 和 V100 等高性能 GPU,改进非常显著。对于这些 GPU,批量大小为 1 时的改进最为明显。
- 对于性能较低的 GPU,我们观察到较小的加速(或在两种情况下略有退步)。这里的批量大小趋势相反:批量越大,改进越大。
在以下部分,我们将介绍应用的优化并提供详细的基准测试数据,比较开启/关闭不同优化功能时的生成时间。
具体来说,我们对 5 种配置进行了基准测试,下面的图表比较了它们在不同 GPU 和批量大小下的绝对性能。这些配置的定义请参见“基准测试设置和结果”部分。
优化
在这里,我们将更详细地介绍模型代码中引入的优化。这些优化依赖于最近发布的 PyTorch 2.0 的特性。
优化的注意力
我们优化的代码的一部分是缩放点积注意力。注意力是众所周知的计算密集型操作:朴素实现会具体化注意力矩阵,导致时间和内存复杂度与序列长度呈平方关系。扩散模型通常在 U-Net 的多个部分的 Transformer 块中使用注意力(CrossAttention
)。由于 U-Net 在每个采样步骤都会运行,因此这成为一个关键的优化点。可以使用 torch.nn.MultiheadAttention
代替自定义注意力实现,PyTorch 2 中已将优化的注意力实现集成到其中。这种优化可以大致归结为以下伪代码:
class CrossAttention(nn.Module):
def __init__(self, ...):
# Create matrices: Q, K, V, out_proj
...
def forward(self, x, context=None, mask=None):
# Compute out = SoftMax(Q*K/sqrt(d))V
# Return out_proj(out)
…
被替换为
class CrossAttention(nn.Module):
def __init__(self, ...):
self.mha = nn.MultiheadAttention(...)
def forward(self, x, context):
return self.mha(x, context, context)
注意力的优化实现已在 PyTorch 1.13 中可用(参见此处)并被广泛采用(例如参见HuggingFace transformers 库示例)。特别是,它集成了来自xFormers库的内存高效注意力和来自https://arxiv.org/abs/2205.14135 的 Flash 注意力。PyTorch 2.0 将此扩展到更多的注意力函数,例如交叉注意力和自定义内核,以进一步加速,使其适用于扩散模型。
Flash 注意力在计算能力为 SM 7.5 或 SM 8.x 的 GPU 上可用,例如 T4、A10 和 A100,这些都包含在我们的基准测试中(您可以在此处查看每种 NVIDIA GPU 的计算能力)。然而,在我们对 A100 的测试中,对于扩散模型的特定情况,内存高效注意力表现优于 Flash 注意力,原因是注意力头数量少且批量大小小。PyTorch 理解这一点,在这种情况下,当两者都可用时,它会选择内存高效注意力而非 Flash 注意力(参见此处逻辑)。为了完全控制注意力后端(内存高效注意力、Flash 注意力、“原生数学”或任何未来的后端),高级用户可以使用上下文管理器 torch.backends.cuda.sdp_kernel 手动启用和禁用它们。
编译
编译是 PyTorch 2.0 的一项新功能,它提供了非常简单的用户体验,同时实现了显著的加速。要调用默认行为,只需将 PyTorch 模块或函数包装在 torch.compile
中即可。
model = torch.compile(model)
PyTorch 编译器随后将 Python 代码转换为一组指令,这些指令可以高效执行而没有 Python 开销。编译在代码首次执行时动态发生。在默认行为下,PyTorch 在底层利用 TorchDynamo 编译代码,并利用 TorchInductor 进一步优化代码。有关更多详细信息,请参见此教程。
尽管上面的一行代码足以进行编译,但代码中的某些修改可以实现更大的加速。特别是,应该避免所谓的“图断点”——PyTorch 无法编译的代码片段。与以前的 PyTorch 编译方法(如 TorchScript)不同,PyTorch 2 编译器在这种情况下不会中断。相反,它会回退到 Eager 执行——因此代码仍然运行,但性能会降低。我们对模型代码进行了一些微小的更改,以消除图断点。这包括删除编译器不支持的库中的函数,例如 inspect.isfunction
和 einops.rearrange
。请参见此文档,了解更多关于图断点以及如何消除它们的信息。
理论上,可以将 torch.compile
应用于整个扩散采样循环。然而,在实践中,只需编译 U-Net 就足够了。原因是 torch.compile
尚未包含循环分析器,并且会在采样循环的每次迭代中重新编译代码。此外,编译后的采样器代码很可能产生图断点——因此如果想从编译版本中获得良好的性能,就需要对其进行调整。
请注意,编译要求 GPU 计算能力 >= SM 7.0 才能在非 eager 模式下运行。这包括我们基准测试中的所有 GPU - T4、V100、A10、A100 - 但 P100 除外(参见完整列表)。
其他优化
此外,我们通过消除一些常见陷阱,例如直接在 GPU 上创建张量而不是在 CPU 上创建后再移动到 GPU,提高了 GPU 内存操作的效率。需要进行此类优化的位置是通过逐行分析以及查看 CPU/GPU 跟踪和 Flame Graphs 确定的。
基准测试设置和结果总结
我们有两个版本的代码进行比较:原始版本和优化版本。此外,可以开启/关闭多个优化功能(xFormers、PyTorch 内存高效注意力、编译)。总的来说,如引言中所述,我们将对 5 种配置进行基准测试。
- 未使用 xFormers 的原始代码
- 使用 xFormers 的原始代码
- 使用原生数学注意力后端且未编译的优化代码
- 使用内存高效注意力后端且未编译的优化代码
- 使用内存高效注意力后端且已编译的优化代码
作为原始版本,我们采用了使用 PyTorch 1.12 和自定义注意力实现的代码版本。优化版本在 CrossAttention
中使用了 nn.MultiheadAttention
并使用 PyTorch 2.0.0.dev20230111+cu117。它还在 PyTorch 相关代码中包含一些其他细微优化。
下表显示了每种代码版本的运行时(以秒为单位),以及与使用 xFormers 的原始版本相比的百分比改进。不包括编译时间。
批量大小为 1 时的运行时。括号中为相对于“使用 xFormers 的原始版本”行的相对改进
配置 | P100 | T4 | A10 | V100 | A100 |
未使用 xFormers 的原始版本 | 30.4s (-19.3%) | 29.8s (-77.3%) | 13.0s (-83.9%) | 10.9s (-33.1%) | 8.0s (-19.3%) |
使用 xFormers 的原始版本 | 25.5s (0.0%) | 16.8s (0.0%) | 7.1s (0.0%) | 8.2s (0.0%) | 6.7s (0.0%) |
使用原生数学注意力、未编译的优化版本 | 27.3s (-7.0%) | 19.9s (-18.7%) | 13.2s (-87.2%) | 7.5s (8.7%) | 5.7s (15.1%) |
使用内存高效注意力、未编译的优化版本 | 26.5s (-3.8%) | 16.8s (0.2%) | 7.1s (-0.8%) | 6.9s (16.0%) | 5.3s (20.6%) |
使用内存高效注意力且已编译的优化版本 | - | 16.4s(2.1%) | 7.2s (-2.3%) | 6.6s (18.6%) | 4.1s (38.5%) |
批量大小为 2 时的运行时
配置 | P100 | T4 | A10 | V100 | A100 |
未使用 xFormers 的原始版本 | 58.0s (-21.6%) | 57.6s (-84.0%) | 24.4s (-95.2%) | 18.6s (-63.0%) | 12.0s (-50.6%) |
使用 xFormers 的原始版本 | 47.7s (0.0%) | 31.3s (0.0%) | 12.5s (0.0%) | 11.4s (0.0%) | 8.0s (0.0%) |
使用原生数学注意力、未编译的优化版本 | 49.3s (-3.5%) | 37.9s (-21.0%) | 17.8s (-42.2%) | 12.7s (-10.7%) | 7.8s (1.8%) |
使用内存高效注意力、未编译的优化版本 | 47.5s(0.4%) | 31.2s (0.5%) | 12.2s (2.6%) | 11.5s (-0.7%) | 7.0s (12.6%) |
使用内存高效注意力且已编译的优化版本 | - | 28.0s (10.5%) | 11.4s (9.0%) | 10.7s(6.4%) | 6.4s (20.3%) |
批量大小为 4 时的运行时
配置 | P100 | T4 | A10 | V100 | A100 |
未使用 xFormers 的原始版本 | 117.9s (-20.0%) | 112.4s (-81.8%) | 47.2s (-101.7%) | 35.8s (-71.9%) | 22.8s (-78.9%) |
使用 xFormers 的原始版本 | 98.3s (0.0%) | 61.8s (0.0%) | 23.4s (0.0%) | 20.8s (0.0%) | 12.7s (0.0%) |
使用原生数学注意力、未编译的优化版本 | 101.1s (-2.9%) | 73.0s (-18.0%) | 28.3s (-21.0%) | 23.3s (-11.9%) | 14.5s (-13.9%) |
使用内存高效注意力、未编译的优化版本 | 92.9s(5.5%) | 61.1s (1.2%) | 23.9s (-1.9%) | 20.8s (-0.1%) | 12.8s (-0.9%) |
使用内存高效注意力且已编译的优化版本 | - | 53.1s(14.2%) | 20.9s (10.6%) | 18.6s (10.4%) | 11.2s (12.2%) |
为了尽量减少波动和外部因素对基准测试代码性能的影响,我们按顺序依次运行每个版本的代码,然后重复这个序列 10 次:A、B、C、D、E、A、B……因此,典型运行的结果将如下图所示。请注意,不应依赖于不同图表之间绝对运行时间的比较,但由于我们的基准测试设置,同一图表内的运行时间比较是相当可靠的。
每次运行文本到图像生成脚本都会产生几个批次,其数量由 CLI 参数 --n_iter
控制。在基准测试中,我们使用了 n_iter = 2
,但引入了一个额外的“预热”迭代,这个迭代不计入运行时间。对于使用编译的运行,这是必需的,因为编译在代码第一次运行时发生,因此第一次迭代比所有后续迭代都要长得多。为了公平比较,我们也对所有其他运行引入了这个额外的“预热”迭代。
上表中的数字对应于迭代次数为 2(加上一个“预热”迭代),提示词为“A photo”,种子为 1,使用 PLMS 采样器,并且开启了 autocast。
基准测试使用了 P100、V100、A100、A10 和 T4 GPU 进行。T4 基准测试在 Google Colab Pro 中进行。A10 基准测试在带有 1 个 GPU 的 g5.4xlarge AWS 实例上进行。
结论与下一步
我们已经表明,PyTorch 2 的新特性——编译器和优化的注意力实现——提供了超越或媲美先前需要安装外部依赖项(xFormers)才能获得的性能提升。PyTorch 实现这一点,特别是通过将 xFormers 中的内存高效注意力集成到其代码库中。考虑到 xFormers 作为最先进的库,在许多场景下需要自定义安装过程和长时间的构建,这对用户体验是一个显著的改进。
这项工作可以在以下几个自然方向上继续进行
- 我们在此实现和描述的优化目前仅对文本到图像推理进行了基准测试。了解它们如何影响训练性能将会很有趣。PyTorch 编译可以直接应用于训练;在路线图上已经计划了使用 PyTorch 优化的注意力进行训练。
- 我们有意将对原始模型代码的修改降至最低。进一步的性能分析和优化可能会带来更多改进。
- 目前,编译仅应用于采样器内部的 U-Net 模型。由于 U-Net 之外有很多操作(例如直接在采样循环中的操作),编译整个采样器会很有益处。然而,这需要对编译过程进行分析,以避免在每个采样步骤都重新编译。
- 当前代码仅在 PLMS 采样器中应用编译,但将其扩展到其他采样器应该非常容易。
- 除了文本到图像生成,扩散模型还应用于其他任务——图像到图像和图像修复。衡量它们通过 PyTorch 2 优化能带来多少性能提升将会很有趣。
看看您是否可以使用我们描述的方法来提高开源扩散模型的性能,并分享结果!
资源
- PyTorch 2.0 概述,其中包含大量关于
torch.compile
的信息:https://pytorch.ac.cn/get-started/pytorch-2.0/ torch.compile
教程:https://pytorch.ac.cn/tutorials/intermediate/torch_compile_tutorial.html- 常见编译故障排除:https://pytorch.ac.cn/docs/stable/torch.compiler_troubleshooting.html
- 图断点详细信息:https://pytorch.ac.cn/docs/stable/torch.compiler_faq.html#identifying-the-cause-of-a-graph-break
- Guards 详细信息:https://pytorch.ac.cn/docs/stable/torch.compiler_guards_overview.html
- 关于 TorchDynamo 的深度解析视频 https://www.youtube.com/watch?v=egZB5Uxki0I
- PyTorch 1.12 中优化的注意力教程:https://pytorch.ac.cn/tutorials/beginner/bettertransformer_tutorial.html
致谢
我们要感谢 Geeta Chauhan、Natalia Gimelshein、Patrick Labatut、Bert Maher、Mark Saroufim、Michael Voznesensky 和 Francisco Massa 对本文提出的宝贵建议和早期反馈。
特别感谢 Yudong Tao 启动了在扩散模型中使用 PyTorch 原生注意力方面的工作。