TL;DR: PyTorch 2.0 nightly 版本通过使用新的 torch.compile()
编译器和与 PyTorch 2 集成的优化 Multihead Attention 实现,为生成式扩散模型提供了开箱即用的性能改进。
引言
生成式 AI 最近的很大一部分进展来自于去噪扩散模型,该模型能够根据文本提示生成高质量的图像和视频。此系列包括 Imagen、DALLE、Latent Diffusion 等。然而,此系列中的所有模型都有一个共同的缺点:由于图像生成过程的迭代性质,生成速度相当慢。这使得优化采样循环中运行的代码变得非常重要。
我们以一个流行的文本到图像扩散模型的开源实现为起点,并使用 PyTorch 2 中提供的两项优化加速了其生成:编译和快速注意力实现。这些优化与代码中一些小的内存处理改进相结合,相对于原始实现(不使用 xFormers)提供了高达 49% 的推理速度提升,相对于使用 xFormers 的原始代码(不包括编译时间)提供了 39% 的推理速度提升,具体取决于 GPU 架构和批处理大小。重要的是,速度提升无需安装 xFormers 或任何其他额外依赖项。
下表显示了安装 xFormers 的原始实现与我们使用 PyTorch 集成的内存高效注意力(最初为 xFormers 库开发并发布)和 PyTorch 编译的优化版本之间运行时间的改进。编译时间不包括在内。
与原始版本+xFormers 相比的运行时改进百分比
请参阅“基准测试设置和结果摘要”部分中的绝对运行时间数据。
GPU | 批处理大小 1 | 批处理大小 2 | 批处理大小 4 |
P100 (无编译) | -3.8 | 0.44 | 5.47 |
T4 | 2.12 | 10.51 | 14.2 |
A10 | -2.34 | 8.99 | 10.57 |
V100 | 18.63 | 6.39 | 10.43 |
A100 | 38.5 | 20.33 | 12.17 |
可以注意到以下几点
- 对于 A100 和 V100 等强大 GPU,改进显著。对于这些 GPU,批处理大小为 1 时改进最为明显。
- 对于性能较弱的 GPU,我们观察到较小的加速(或在两种情况下略有下降)。这里的批处理大小趋势是相反的:批处理越大,改进越大。
在以下部分中,我们将描述所应用的优化,并提供详细的基准测试数据,比较在启用/禁用各种优化功能情况下的生成时间。
具体来说,我们对 5 种配置进行了基准测试,下面的图表比较了它们在不同 GPU 和批处理大小下的绝对性能。有关这些配置的定义,请参见“基准测试设置和结果”部分。



优化
在这里,我们将详细介绍模型代码中引入的优化。这些优化依赖于最近发布的 PyTorch 2.0 的功能。
优化注意力
我们优化了代码的一部分是缩放点积注意力。注意力是一种繁重的操作:朴素的实现会实例化注意力矩阵,导致时间和内存复杂度与序列长度呈二次方关系。扩散模型通常在 U-Net 的多个部分中使用注意力 (CrossAttention
) 作为 Transformer 块的一部分。由于 U-Net 在每个采样步骤都运行,因此这成为一个关键的优化点。与其使用自定义注意力实现,不如使用 torch.nn.MultiheadAttention
,PyTorch 2 中集成了优化的注意力实现。这种优化可以概括为以下伪代码:
class CrossAttention(nn.Module):
def __init__(self, ...):
# Create matrices: Q, K, V, out_proj
...
def forward(self, x, context=None, mask=None):
# Compute out = SoftMax(Q*K/sqrt(d))V
# Return out_proj(out)
…
被替换为
class CrossAttention(nn.Module):
def __init__(self, ...):
self.mha = nn.MultiheadAttention(...)
def forward(self, x, context):
return self.mha(x, context, context)
注意力的优化实现已在 PyTorch 1.13 中可用(参见此处),并被广泛采用(参见例如 HuggingFace transformers 库示例)。特别是,它集成了来自 xFormers 库的内存高效注意力以及来自 https://arxiv.org/abs/2205.14135 的 Flash Attention。PyTorch 2.0 将此扩展到其他注意力功能,例如交叉注意力和用于进一步加速的自定义内核,使其适用于扩散模型。
Flash attention 在计算能力为 SM 7.5 或 SM 8.x 的 GPU 上可用——例如,T4、A10 和 A100,这些都包含在我们的基准测试中(您可以在此处查看每个 NVIDIA GPU 的计算能力)。然而,在我们在 A100 上的测试中,由于注意力头数量少且批处理大小小,内存高效注意力在扩散模型的特定情况下表现优于 Flash attention。PyTorch 理解这一点,并且在这种情况下,当两者都可用时,PyTorch 会选择内存高效注意力而不是 Flash attention(参见此处的逻辑)。为了完全控制注意力后端(内存高效注意力、Flash attention、“vanilla math”或任何未来的后端),高级用户可以在上下文管理器 torch.backends.cuda.sdp_kernel 的帮助下手动启用和禁用它们。
编译
编译是 PyTorch 2.0 的新功能,它以非常简单的用户体验实现显著的加速。要调用默认行为,只需将 PyTorch 模块或函数包装到 torch.compile
中即可
model = torch.compile(model)
PyTorch 编译器随后将 Python 代码转换为一组指令,这些指令可以高效执行,而无需 Python 开销。编译在代码首次执行时动态进行。在默认行为下,PyTorch 在底层利用 TorchDynamo 编译代码,并利用 TorchInductor 进一步优化它。有关更多详细信息,请参见 本教程。
虽然上面的一行代码足以进行编译,但代码中的某些修改可以获得更大的加速。特别是,应该避免所谓的图中断——PyTorch 无法编译代码的地方。与以前的 PyTorch 编译方法(如 TorchScript)不同,PyTorch 2 编译器在这种情况下不会中断。相反,它会回退到即时执行——因此代码会运行,但性能会降低。我们对模型代码进行了一些微小的更改,以消除图中断。这包括消除编译器不支持的库中的函数,例如 inspect.isfunction
和 einops.rearrange
。请参阅 此文档,了解有关图中断以及如何消除它们的更多信息。
理论上,可以将 torch.compile
应用于整个扩散采样循环。然而,实际上,只需编译 U-Net 就足够了。原因在于 torch.compile
尚未拥有循环分析器,并且会为采样循环的每次迭代重新编译代码。此外,编译后的采样器代码很可能会生成图中断——因此如果想要从编译版本获得良好的性能,则需要对其进行调整。
请注意,编译 需要 GPU 计算能力 >= SM 7.0 才能在非即时模式下运行。这涵盖了我们基准测试中的所有 GPU——T4、V100、A10、A100——除了 P100(请参阅 完整列表)。
其他优化
此外,我们通过消除一些常见的陷阱,例如直接在 GPU 上创建张量而不是在 CPU 上创建后移动到 GPU,提高了 GPU 内存操作的效率。需要进行此类优化的地方是通过逐行分析和查看 CPU/GPU 跟踪以及 Flame Graphs 确定的。
基准测试设置和结果总结
我们有两个版本的代码进行比较:原始版本和优化版本。在此之上,可以开启/关闭多个优化功能(xFormers、PyTorch 内存高效注意力、编译)。总的来说,如引言中所述,我们将对 5 种配置进行基准测试
- 不使用 xFormers 的原始代码
- 使用 xFormers 的原始代码
- 使用香草数学注意力后端且未编译的优化代码
- 使用内存高效注意力后端且未编译的优化代码
- 使用内存高效注意力后端和编译的优化代码
作为原始版本,我们使用了 PyTorch 1.12 和自定义注意力实现的代码版本。优化版本在 CrossAttention
中使用 nn.MultiheadAttention
和 PyTorch 2.0.0.dev20230111+cu117。它还在 PyTorch 相关代码中进行了一些其他小优化。
下表显示了每个代码版本的运行时间(秒),以及与 _原始带 xFormers_ 相比的改进百分比。编译时间不包括在内。
批处理大小为 1 的运行时间。括号中为相对于“原始带 xFormers”行的相对改进
配置 | P100 | T4 | A10 | V100 | A100 |
不带 xFormers 的原始版本 | 30.4 秒 (-19.3%) | 29.8 秒 (-77.3%) | 13.0 秒 (-83.9%) | 10.9 秒 (-33.1%) | 8.0 秒 (-19.3%) |
带 xFormers 的原始版本 | 25.5 秒 (0.0%) | 16.8 秒 (0.0%) | 7.1 秒 (0.0%) | 8.2 秒 (0.0%) | 6.7 秒 (0.0%) |
带 vanilla math attention 的优化版本,未编译 | 27.3 秒 (-7.0%) | 19.9 秒 (-18.7%) | 13.2 秒 (-87.2%) | 7.5 秒 (8.7%) | 5.7 秒 (15.1%) |
带内存高效注意力的优化版本,未编译 | 26.5 秒 (-3.8%) | 16.8 秒 (0.2%) | 7.1 秒 (-0.8%) | 6.9 秒 (16.0%) | 5.3 秒 (20.6%) |
带内存高效注意力和编译的优化版本 | -- | 16.4 秒(2.1%) | 7.2 秒 (-2.3%) | 6.6 秒 (18.6%) | 4.1 秒 (38.5%) |
批处理大小为 2 的运行时间
配置 | P100 | T4 | A10 | V100 | A100 |
不带 xFormers 的原始版本 | 58.0 秒 (-21.6%) | 57.6 秒 (-84.0%) | 24.4 秒 (-95.2%) | 18.6 秒 (-63.0%) | 12.0 秒 (-50.6%) |
带 xFormers 的原始版本 | 47.7 秒 (0.0%) | 31.3 秒 (0.0%) | 12.5 秒 (0.0%) | 11.4 秒 (0.0%) | 8.0 秒 (0.0%) |
带 vanilla math attention 的优化版本,未编译 | 49.3 秒 (-3.5%) | 37.9 秒 (-21.0%) | 17.8 秒 (-42.2%) | 12.7 秒 (-10.7%) | 7.8 秒 (1.8%) |
带内存高效注意力的优化版本,未编译 | 47.5 秒(0.4%) | 31.2 秒 (0.5%) | 12.2 秒 (2.6%) | 11.5 秒 (-0.7%) | 7.0 秒 (12.6%) |
带内存高效注意力和编译的优化版本 | -- | 28.0 秒 (10.5%) | 11.4 秒 (9.0%) | 10.7 秒(6.4%) | 6.4 秒 (20.3%) |
批处理大小为 4 的运行时间
配置 | P100 | T4 | A10 | V100 | A100 |
不带 xFormers 的原始版本 | 117.9 秒 (-20.0%) | 112.4 秒 (-81.8%) | 47.2 秒 (-101.7%) | 35.8 秒 (-71.9%) | 22.8 秒 (-78.9%) |
带 xFormers 的原始版本 | 98.3 秒 (0.0%) | 61.8 秒 (0.0%) | 23.4 秒 (0.0%) | 20.8 秒 (0.0%) | 12.7 秒 (0.0%) |
带 vanilla math attention 的优化版本,未编译 | 101.1 秒 (-2.9%) | 73.0 秒 (-18.0%) | 28.3 秒 (-21.0%) | 23.3 秒 (-11.9%) | 14.5 秒 (-13.9%) |
带内存高效注意力的优化版本,未编译 | 92.9 秒(5.5%) | 61.1 秒 (1.2%) | 23.9 秒 (-1.9%) | 20.8 秒 (-0.1%) | 12.8 秒 (-0.9%) |
带内存高效注意力和编译的优化版本 | -- | 53.1 秒(14.2%) | 20.9 秒 (10.6%) | 18.6 秒 (10.4%) | 11.2 秒 (12.2%) |
为了最大程度地减少波动和外部对基准代码性能的影响,我们依次运行每个版本的代码,然后重复此序列 10 次:A、B、C、D、E、A、B……因此,典型运行的结果如下图所示。请注意,不应依赖不同图表之间绝对运行时间的比较,但由于我们的基准测试设置,同一图表内的运行时间比较是相当可靠的。

文本到图像生成脚本的每次运行都会产生若干批次,其数量由 CLI 参数 --n_iter
控制。在基准测试中,我们使用了 n_iter = 2
,但引入了一个额外的“预热”迭代,该迭代不计入运行时间。这对于编译运行是必要的,因为编译发生在代码首次运行时,因此第一次迭代比所有后续迭代都要长得多。为了公平比较,我们还将此额外的“预热”迭代引入到所有其他运行中。
上表中的数字适用于迭代次数为 2(外加一次“预热”)、提示“一张照片”、种子 1、PLMS 采样器和开启自动转换的情况。
基准测试使用 P100、V100、A100、A10 和 T4 GPU 完成。T4 基准测试在 Google Colab Pro 中完成。A10 基准测试在具有 1 个 GPU 的 g5.4xlarge AWS 实例上完成。
结论和下一步
我们已经证明,PyTorch 2 的新功能——编译器和优化的注意力实现——提供了超越或媲美之前需要安装外部依赖项 (xFormers) 才能获得的性能改进。PyTorch 之所以能做到这一点,特别是通过将 xFormers 的内存高效注意力集成到其代码库中。鉴于 xFormers 作为最先进的库,在许多情况下需要自定义安装过程和长时间构建,这对用户体验来说是一个显著的改进。
这项工作可以沿着以下几个自然方向继续进行
- 我们在此实施和描述的优化目前仅针对文本到图像推理进行了基准测试。了解它们如何影响训练性能将是很有趣的。PyTorch 编译可以直接应用于训练;支持使用 PyTorch 优化注意力进行训练已列入路线图。
- 我们特意将对原始模型代码的更改最小化。进一步的剖析和优化可能会带来更多的改进。
- 目前,编译仅应用于采样器内部的 U-Net 模型。由于 U-Net 之外发生了很多事情(例如直接在采样循环中的操作),因此编译整个采样器将是有益的。然而,这需要对编译过程进行分析,以避免在每个采样步骤重新编译。
- 当前代码仅在 PLMS 采样器中应用编译,但将其扩展到其他采样器应该很简单。
- 除了文本到图像生成,扩散模型还应用于其他任务——图像到图像和图像修补。衡量它们的性能如何通过 PyTorch 2 优化得到改善将是很有趣的。
看看您是否可以使用我们描述的方法提高开源扩散模型的性能,并分享结果!
资源
- PyTorch 2.0 概述,其中包含大量关于
torch.compile
的信息:https://pytorch.ac.cn/get-started/pytorch-2.0/ torch.compile
教程:https://pytorch.ac.cn/tutorials/intermediate/torch_compile_tutorial.html- 通用编译故障排除:https://pytorch.ac.cn/docs/stable/torch.compiler_troubleshooting.html
- 图中断详情:https://pytorch.ac.cn/docs/stable/torch.compiler_faq.html#identifying-the-cause-of-a-graph-break
- 防护详情:https://pytorch.ac.cn/docs/stable/torch.compiler_guards_overview.html
- TorchDynamo 视频深度解析 https://www.youtube.com/watch?v=egZB5Uxki0I
- PyTorch 1.12 优化注意力教程:https://pytorch.ac.cn/tutorials/beginner/bettertransformer_tutorial.html
致谢
我们衷心感谢 Geeta Chauhan、Natalia Gimelshein、Patrick Labatut、Bert Maher、Mark Saroufim、Michael Voznesensky 和 Francisco Massa 提出的宝贵建议和早期反馈。
特别感谢 Yudong Tao 发起在扩散模型中使用 PyTorch 原生注意力方面的工作。