作者:Pedro Cuenca, Patrick von Platen, Suraj Patil, Sayak Paul

PyTorch 2.0 刚刚发布。其旗舰新特性是 torch.compile(),这是一个只需改动一行代码就能承诺自动提升跨代码库性能的功能。我们之前已在 Hugging Face Transformers 和 TIMM 模型中验证了这一承诺,并深入探讨了其动机、架构和未来发展方向

尽管 torch.compile() 十分重要,但 PyTorch 2.0 的亮点远不止于此。值得注意的是,PyTorch 2.0 集成了多种策略来加速 Transformer 块,这些改进对扩散模型也非常重要。例如,FlashAttention 等技术因其能够显著加速 Stable Diffusion 并实现更大批量而备受扩散社区欢迎,它们现在已成为 PyTorch 2.0 的一部分。

在这篇文章中,我们将讨论 PyTorch 2.0 中如何优化注意力层,以及这些优化如何应用于流行的 🧨 Diffusers 库。最后,我们通过一个基准测试展示了使用 PyTorch 2.0 和 Diffusers 如何立即转化为在不同硬件上的显著性能提升。

更新(2023 年 6 月):新增了一个章节,展示了在修复 Diffusers 代码库中的图断点(graph breaks)后,最新版本 PyTorch (2.0.1) 中 torch.compile() 带来的显著性能提升。关于如何查找和修复图断点的更详细分析将在另一篇文章中发布。

加速 Transformer 块

PyTorch 2.0 包含一个 *scaled dot-product attention*(缩放点积注意力)函数,作为 torch.nn.functional 的一部分。此函数包含了多种实现,可根据输入和所使用的硬件应用不同的实现。在 PyTorch 2.0 之前,您必须搜索第三方实现并安装单独的软件包才能利用内存优化算法,例如 FlashAttention。可用的实现包括

  • FlashAttention,来自官方的 FlashAttention 项目
  • Memory-Efficient Attention(内存高效注意力),来自 xFormers 项目
  • 一种原生的 C++ 实现,适用于非 CUDA 设备或需要高精度时。

所有这些方法默认可用,PyTorch 将尝试通过使用新的缩放点积注意力 (SDPA) API 自动选择最优的方法。您还可以单独切换它们以进行更精细的控制,详情请参阅文档

在 Diffusers 中使用缩放点积注意力

加速的 PyTorch 2.0 Transformer 注意力功能整合到 Diffusers 库是通过使用 set_attn_processor 方法实现的,该方法允许配置可插拔的注意力模块。在这种情况下,创建了一个新的注意力处理器,当 PyTorch 2.0 可用时,该处理器默认启用。为了清晰起见,您可以这样手动启用它(但这通常不是必需的,因为 diffusers 会自动处理)

from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

Stable Diffusion 基准测试

我们在 Diffusers 中使用 PyTorch 2.0 的加速点积注意力运行了多项测试。我们通过 pip 安装了 diffusers,并使用了 PyTorch 2.0 的每夜构建版本,因为我们的测试是在正式发布前进行的。我们还使用了 torch.set_float32_matmul_precision('high') 来启用额外的快速矩阵乘法算法。

我们将结果与 diffusers 中的传统注意力实现(下文称为 vanilla)以及 PyTorch 2.0 之前表现最佳的解决方案进行了比较:即安装了 xFormers 软件包 (v0.0.16) 的 PyTorch 1.13.1。

结果测量是在未进行编译(即完全没有代码改动)的情况下进行的,也包括仅调用一次 torch.compile() 包装 UNet 模块的情况。我们没有编译图像解码器,因为大部分时间都花在运行 UNet 评估的 50 次去噪迭代中。

Float32 结果

Diffusers Speedup vs xFormers float32

下图展示了不同代系代表性 GPU 的性能提升与批量大小的关系。我们收集了每种组合的数据,直到达到最大内存利用率。Vanilla 注意力比 xFormers 或 PyTorch 2.0 更早耗尽内存,这解释了较大批量大小的缺失条形图。同样,A100(我们使用了 40 GB 版本)能够运行 64 的批量大小,但在我们的测试中,其他 GPU 只能达到 32。

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (V100, float32)

我们发现在未甚至使用 torch.compile() 的情况下,相对于 vanilla 注意力,性能提升非常显著。PyTorch 2.0 和 diffusers 的开箱即用安装在 A100 上实现了约 50% 的加速,在 4090 GPU 上根据批量大小的不同实现了 35% 到 50% 的加速。性能提升在现代 CUDA 架构(如 Ada (4090) 或 Ampere (A100))上更为明显,但在云服务中仍大量使用的较旧架构上,性能提升仍然非常显著。

除了更快的速度外,PyTorch 2.0 中的加速 Transformer 实现允许使用更大的批量大小。单个 40GB A100 GPU 在批量大小为 10 时内存不足,而 24 GB 高端消费级显卡(如 3090 和 4090)无法一次生成 8 张图像。使用 PyTorch 2.0 和 diffusers,我们可以在 3090 和 4090 上实现 48 的批量大小,在 A100 上实现 64 的批量大小。这对于云服务和应用具有重要意义,因为它们可以一次高效地处理更多图像。

与 PyTorch 1.13.1 + xFormers 相比,新的加速 Transformer 实现仍然更快,并且无需额外的软件包或依赖项。在这种情况下,我们在数据中心显卡(如 A100 或 T4)上发现了高达 2% 的中等加速,但在最近两代消费级显卡上表现出色:在 3090 上性能提升高达 20%,在 4090 上根据批量大小不同提升 10% 到 45%。

当使用 torch.compile() 时,我们在之前的改进基础上获得了额外的性能提升(通常为 2% 和 3%)。由于编译需要一些时间,这更适合面向用户的推理服务或训练。更新:当图断点(graph breaks)最小化后,torch.compile() 带来的改进要大得多,详情请参见新章节

Float16 结果

Diffusers Speedup vs xFormers float16

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float16)

当我们考虑 float16 推理时,PyTorch 2.0 中加速 Transformer 实现的性能提升在所有测试的 GPU 上相对于标准注意力提高了 20% 到 28%,但 4090 除外,它属于更现代的 Ada 架构。使用 PyTorch 2.0 每夜构建版本时,这款 GPU 的性能获得了显著提升。与优化后的 SDPA 对比 xFormers 相比,大多数 GPU 的结果通常不相上下,但 4090 除外。将 torch.compile() 加入其中,整体性能又提升了几个百分点。

最小化图断点(graph breaks)后 torch.compile() 的性能

在之前的章节中,我们看到使用 PyTorch 2.0 的加速 Transformer 实现相对于早期版本的 PyTorch(无论是否使用 xFormers)提供了重要的性能提升。然而,torch.compile() 仅贡献了适度的边际改进。在 PyTorch 团队的帮助下,我们发现这些适度改进的原因是 diffusers 源代码中的一些操作导致了图断点(graph breaks),这使得 torch.compile() 无法充分利用图优化。

在修复了图断点(graph breaks)后(详情请参阅这些 PRs),我们测量了 torch.compile() 相对于未编译的 PyTorch 2 版本的额外改进,并看到了非常显著的增量性能提升。下图是使用 2023 年 5 月 1 日下载的 PyTorch 2 每夜构建版本获得的,它显示在大多数工作负载下有约 13% 到 22% 的改进范围。对于现代 GPU 系列,性能提升更好,在 A100 上达到 30% 以上。图表中还有两个异常值。首先,我们在 T4 上看到批量大小为 16 时性能下降,这对该显卡造成了巨大的内存压力。另一方面,我们在 A100 上看到批量大小仅为 1 时性能提升超过 100%,这很有趣,但并不能代表拥有如此大内存 GPU 的实际应用——能够服务多个客户的更大批量大小通常对 A100 上的服务部署更具吸引力。

Diffusers Speedup using torch.compile() in float16

再次强调,这些性能提升是在迁移到 PyTorch 2 并使用加速 Transformer 缩放点积注意力实现所获得的提升基础之上的*额外*提升。我们建议在生产环境中部署 diffusers 时使用 torch.compile()

结论

PyTorch 2.0 带来了多项功能,可以优化基础 Transformer 块的关键组件,并且可以通过使用 torch.compile 进一步改进。这些优化为扩散模型带来了显著的内存和时间改进,并消除了对第三方库安装的需求。

要利用这些速度和内存改进,您只需升级到 PyTorch 2.0 并使用 diffusers >= 0.13.0。

有关更多示例和详细的基准测试数据,请参阅 PyTorch 2.0 with Diffusers 文档。

致谢

作者感谢 PyTorch 团队创造了如此出色的软件。