概要:PyTorch 2.0 每夜版通过使用新的 torch.compile() 编译器和集成到 PyTorch 2 中的多头注意力优化实现,为生成式扩散模型提供开箱即用的性能提升。
引言
生成式 AI 近期的大部分进展都来自去噪扩散模型,它们能够根据文本提示生成高质量的图像和视频。这个家族包括 Imagen、DALLE、Latent Diffusion 等。然而,这个家族中的所有模型都有一个共同的缺点:由于图像生成过程的迭代性质,生成速度相当慢。这使得优化采样循环中运行的代码变得非常重要。
我们以流行的文本到图像扩散模型的开源实现为起点,并使用 PyTorch 2 中可用的两种优化方法(编译和快速注意力实现)加速了其生成过程。结合代码中一些次要的内存处理改进,这些优化相对于没有 xFormers 的原始实现,推理速度提升高达 49%,相对于使用 xFormers 的原始代码(不包括编译时间),推理速度提升 39%,具体取决于 GPU 架构和批次大小。重要的是,这种加速无需安装 xFormers 或任何其他额外依赖项。
下表显示了安装了 xFormers 的原始实现与我们使用 PyTorch 集成的内存高效注意力(最初为 xFormers 库开发并发布)和 PyTorch 编译的优化版本之间的运行时改进。不包括编译时间。
与原始 + xFormers 相比的运行时改进 (%)
请参阅“基准测试设置和结果摘要”部分中的绝对运行时数字。
| GPU | 批次大小 1 | 批次大小 2 | 批次大小 4 | 
| P100(无编译) | -3.8 | 0.44 | 5.47 | 
| T4 | 2.12 | 10.51 | 14.2 | 
| A10 | -2.34 | 8.99 | 10.57 | 
| V100 | 18.63 | 6.39 | 10.43 | 
| A100 | 38.5 | 20.33 | 12.17 | 
我们可以注意到以下几点:
- 对于 A100 和 V100 等强大的 GPU,改进非常显著。对于这些 GPU,批次大小为 1 时改进最为明显。
- 对于性能较低的 GPU,我们观察到较小的加速(或者在两种情况下略有下降)。批次大小趋势在这里是相反的:批次越大,改进越大。
在以下部分中,我们将描述所应用的优化并提供详细的基准测试数据,比较在启用/禁用各种优化功能时的生成时间。
具体来说,我们对 5 种配置进行了基准测试,下面的图表比较了它们在不同 GPU 和批次大小下的绝对性能。有关这些配置的定义,请参阅“基准测试设置和结果”部分。



优化
在这里,我们将更详细地介绍模型代码中引入的优化。这些优化依赖于最近发布的 PyTorch 2.0 的功能。
优化注意力
我们优化的一部分代码是缩放点积注意力。注意力被认为是一个重量级操作:朴素的实现会实例化注意力矩阵,导致时间和内存复杂度与序列长度呈二次方关系。扩散模型通常在 U-Net 的多个部分中将注意力(CrossAttention)作为 Transformer 块的一部分。由于 U-Net 在每个采样步骤都运行,因此这成为一个关键的优化点。与其使用自定义注意力实现,不如使用 torch.nn.MultiheadAttention,PyTorch 2 中已将其集成了优化的注意力实现。这种优化示意性地归结为以下伪代码:
class CrossAttention(nn.Module):
    def __init__(self, ...):
        # Create matrices: Q, K, V, out_proj
        ...
    def forward(self, x, context=None, mask=None):
       # Compute out = SoftMax(Q*K/sqrt(d))V
       # Return out_proj(out)
       …
被替换为
class CrossAttention(nn.Module):
    def __init__(self, ...):
        self.mha = nn.MultiheadAttention(...)
    def forward(self, x, context):
	return self.mha(x, context, context)
注意力的优化实现已在 PyTorch 1.13 中提供(参见此处),并被广泛采用(参见例如 HuggingFace transformers 库示例)。特别是,它集成了来自 xFormers 库的内存高效注意力,以及来自 https://arxiv.org/abs/2205.14135 的 Flash Attention。PyTorch 2.0 将此扩展到额外的注意力函数,例如交叉注意力和自定义内核以进一步加速,使其适用于扩散模型。
Flash attention 适用于计算能力为 SM 7.5 或 SM 8.x 的 GPU——例如,T4、A10 和 A100,这些都包含在我们的基准测试中(您可以在此处查看每个 NVIDIA GPU 的计算能力)。然而,在我们在 A100 上的测试中,由于注意力头的数量较少和批次大小较小,内存高效注意力在扩散模型的特定情况下表现优于 Flash Attention。PyTorch 理解这一点,在这种情况下,当两者都可用时,会选择内存高效注意力而不是 Flash Attention(请参阅此处的逻辑)。为了完全控制注意力后端(内存高效注意力、Flash Attention、“普通数学”或任何未来的后端),高级用户可以使用上下文管理器 torch.backends.cuda.sdp_kernel 手动启用和禁用它们。
编译
编译是 PyTorch 2.0 的新功能,它通过非常简单的用户体验实现显著加速。要调用默认行为,只需将 PyTorch 模块或函数包装到 torch.compile 中即可。
model = torch.compile(model)
PyTorch 编译器随后将 Python 代码转换为一组指令,这些指令可以高效执行,而无需 Python 开销。编译在代码首次执行时动态发生。在默认行为下,PyTorch 在底层利用 TorchDynamo 编译代码,并利用 TorchInductor 进一步优化它。有关更多详细信息,请参阅 此教程。
尽管上述一行代码足以进行编译,但代码中的某些修改可以获得更大的加速。特别是,应该避免所谓的图中断——PyTorch 无法编译的代码位置。与以前的 PyTorch 编译方法(如 TorchScript)不同,PyTorch 2 编译器在这种情况下不会中断。相反,它会回退到即时执行——因此代码会运行,但性能会降低。我们对模型代码进行了一些微小的更改,以消除图中断。这包括消除编译器不支持的库中的函数,例如 inspect.isfunction 和 einops.rearrange。请参阅此 文档 以了解有关图中断以及如何消除它们的更多信息。
理论上,可以将 torch.compile 应用于整个扩散采样循环。然而,实际上,仅编译 U-Net 就足够了。原因是 torch.compile 还没有循环分析器,并且会为采样循环的每次迭代重新编译代码。此外,编译后的采样器代码很可能会生成图中断——因此如果想从编译版本中获得良好的性能,就需要对其进行调整。
请注意,编译要求 GPU 计算能力 >= SM 7.0 才能在非即时模式下运行。这涵盖了我们基准测试中的所有 GPU——T4、V100、A10、A100——除了 P100(请参阅完整列表)。
其他优化
此外,我们通过消除一些常见陷阱来提高 GPU 内存操作的效率,例如直接在 GPU 上创建张量,而不是在 CPU 上创建后将其移动到 GPU。确定需要进行此类优化的位置是通过逐行分析和查看 CPU/GPU 跟踪以及 火焰图 来完成的。
基准测试设置和结果总结
我们有两个版本的代码进行比较:原始版本和优化版本。在此基础上,可以开启/关闭几个优化功能(xFormers、PyTorch 内存高效注意力、编译)。总的来说,正如引言中所述,我们将对 5 种配置进行基准测试。
- 没有 xFormers 的原始代码
- 有 xFormers 的原始代码
- 使用普通数学注意力后端且未编译的优化代码
- 使用内存高效注意力后端且未编译的优化代码
- 使用内存高效注意力后端并编译的优化代码
作为原始版本,我们使用了使用 PyTorch 1.12 和自定义注意力实现的代码版本。优化版本在 CrossAttention 中使用 nn.MultiheadAttention 和 PyTorch 2.0.0.dev20230111+cu117。它还在与 PyTorch 相关的代码中进行了一些其他小的优化。
下表显示了每个代码版本的运行时间(秒),以及相对于原始版本(带 xFormers)的百分比改进。不包括编译时间。
批次大小为 1 时的运行时间。括号中为相对于“原始版本(带 xFormers)”行的相对改进。
| 配置 | P100 | T4 | A10 | V100 | A100 | 
| 无 xFormers 的原始版本 | 30.4 秒 (-19.3%) | 29.8 秒 (-77.3%) | 13.0 秒 (-83.9%) | 10.9 秒 (-33.1%) | 8.0 秒 (-19.3%) | 
| 带 xFormers 的原始版本 | 25.5 秒 (0.0%) | 16.8 秒 (0.0%) | 7.1 秒 (0.0%) | 8.2 秒 (0.0%) | 6.7 秒 (0.0%) | 
| 优化版本(普通数学注意力,无编译) | 27.3 秒 (-7.0%) | 19.9 秒 (-18.7%) | 13.2 秒 (-87.2%) | 7.5 秒 (8.7%) | 5.7 秒 (15.1%) | 
| 优化版本(内存高效注意力,无编译) | 26.5 秒 (-3.8%) | 16.8 秒 (0.2%) | 7.1 秒 (-0.8%) | 6.9 秒 (16.0%) | 5.3 秒 (20.6%) | 
| 优化版本(内存高效注意力并编译) | – | 16.4 秒(2.1%) | 7.2 秒 (-2.3%) | 6.6 秒 (18.6%) | 4.1 秒 (38.5%) | 
批次大小为 2 时的运行时间
| 配置 | P100 | T4 | A10 | V100 | A100 | 
| 无 xFormers 的原始版本 | 58.0 秒 (-21.6%) | 57.6 秒 (-84.0%) | 24.4 秒 (-95.2%) | 18.6 秒 (-63.0%) | 12.0 秒 (-50.6%) | 
| 带 xFormers 的原始版本 | 47.7 秒 (0.0%) | 31.3 秒 (0.0%) | 12.5 秒 (0.0%) | 11.4 秒 (0.0%) | 8.0 秒 (0.0%) | 
| 优化版本(普通数学注意力,无编译) | 49.3 秒 (-3.5%) | 37.9 秒 (-21.0%) | 17.8 秒 (-42.2%) | 12.7 秒 (-10.7%) | 7.8 秒 (1.8%) | 
| 优化版本(内存高效注意力,无编译) | 47.5 秒(0.4%) | 31.2 秒 (0.5%) | 12.2 秒 (2.6%) | 11.5 秒 (-0.7%) | 7.0 秒 (12.6%) | 
| 优化版本(内存高效注意力并编译) | – | 28.0 秒 (10.5%) | 11.4 秒 (9.0%) | 10.7 秒(6.4%) | 6.4 秒 (20.3%) | 
批次大小为 4 时的运行时间
| 配置 | P100 | T4 | A10 | V100 | A100 | 
| 无 xFormers 的原始版本 | 117.9 秒 (-20.0%) | 112.4 秒 (-81.8%) | 47.2 秒 (-101.7%) | 35.8 秒 (-71.9%) | 22.8 秒 (-78.9%) | 
| 带 xFormers 的原始版本 | 98.3 秒 (0.0%) | 61.8 秒 (0.0%) | 23.4 秒 (0.0%) | 20.8 秒 (0.0%) | 12.7 秒 (0.0%) | 
| 优化版本(普通数学注意力,无编译) | 101.1 秒 (-2.9%) | 73.0 秒 (-18.0%) | 28.3 秒 (-21.0%) | 23.3 秒 (-11.9%) | 14.5 秒 (-13.9%) | 
| 优化版本(内存高效注意力,无编译) | 92.9 秒(5.5%) | 61.1 秒 (1.2%) | 23.9 秒 (-1.9%) | 20.8 秒 (-0.1%) | 12.8 秒 (-0.9%) | 
| 优化版本(内存高效注意力并编译) | – | 53.1 秒(14.2%) | 20.9 秒 (10.6%) | 18.6 秒 (10.4%) | 11.2 秒 (12.2%) | 
为了尽量减少波动和外部对基准代码性能的影响,我们依次运行每个代码版本,然后重复此序列 10 次:A、B、C、D、E、A、B……因此,典型运行的结果如下图所示。请注意,不应依赖不同图表之间绝对运行时间的比较,但由于我们的基准测试设置,同一图表内的运行时间比较是相当可靠的。

文本到图像生成脚本的每次运行都会生成多个批次,其数量由 CLI 参数 --n_iter 控制。在基准测试中,我们使用了 n_iter = 2,但引入了一个额外的“预热”迭代,它不计入运行时间。这对于需要编译的运行是必要的,因为编译在代码首次运行时发生,因此第一次迭代比所有后续迭代都要长得多。为了公平比较,我们还将此额外的“预热”迭代引入到所有其他运行中。
上表中的数字适用于迭代次数 2(加上一个“预热”迭代),提示语“A photo”,种子 1,PLMS 采样器,并开启 autocast。
基准测试使用 P100、V100、A100、A10 和 T4 GPU 完成。T4 基准测试在 Google Colab Pro 中完成。A10 基准测试在带有 1 个 GPU 的 g5.4xlarge AWS 实例上完成。
结论和下一步
我们已经证明,PyTorch 2 的新功能——编译器和优化的注意力实现——带来了性能改进,这些改进超越或可与之前需要安装外部依赖项(xFormers)的方案相媲美。PyTorch 之所以能够做到这一点,特别是通过将 xFormers 的内存高效注意力集成到其代码库中。鉴于 xFormers 作为最先进的库,在许多情况下需要自定义安装过程和漫长的构建,这对于用户体验来说是一个显著的改进。
这项工作可以从几个自然的方向继续进行
- 我们在此实施和描述的优化目前仅针对文本到图像推理进行了基准测试。看看它们如何影响训练性能将很有趣。PyTorch 编译可以直接应用于训练;启用使用 PyTorch 优化注意力的训练正在路线图上。
- 我们有意地最小化了对原始模型代码的更改。进一步的性能分析和优化可能会带来更多的改进。
- 目前,编译仅应用于采样器内部的 U-Net 模型。由于 U-Net 外部发生了很多事情(例如直接在采样循环中的操作),因此编译整个采样器将是有益的。然而,这需要分析编译过程以避免在每个采样步骤重新编译。
- 当前代码仅在 PLMS 采样器中应用编译,但将其扩展到其他采样器应该很简单。
- 除了文本到图像生成,扩散模型还应用于其他任务——图像到图像和图像修复。衡量它们的性能如何通过 PyTorch 2 优化得到改善将很有趣。
看看您是否可以使用我们描述的方法提高开源扩散模型的性能,并分享结果!
资源
- PyTorch 2.0 概述,其中包含大量关于 torch.compile的信息:https://pytorch.ac.cn/get-started/pytorch-2.0/
- torch.compile教程:https://pytorch.ac.cn/tutorials/intermediate/torch_compile_tutorial.html
- 一般编译故障排除:https://pytorch.ac.cn/docs/stable/torch.compiler_troubleshooting.html
- 图中断详情:https://pytorch.ac.cn/docs/stable/torch.compiler_faq.html#identifying-the-cause-of-a-graph-break
- 守卫详情:https://pytorch.ac.cn/docs/stable/torch.compiler_guards_overview.html
- TorchDynamo 深度解析视频:https://www.youtube.com/watch?v=egZB5Uxki0I
- PyTorch 1.12 中优化注意力的教程:https://pytorch.ac.cn/tutorials/beginner/bettertransformer_tutorial.html
致谢
我们要感谢 Geeta Chauhan、Natalia Gimelshein、Patrick Labatut、Bert Maher、Mark Saroufim、Michael Voznesensky 和 Francisco Massa,感谢他们宝贵的建议和对文本的早期反馈。
特别感谢 Yudong Tao 启动了在扩散模型中使用 PyTorch 原生注意力方面的工作。