作者:Pedro Cuenca、Patrick von Platen、Suraj Patil、Sayak Paul

PyTorch 2.0 刚刚发布。其旗舰新功能是 torch.compile(),只需一行代码更改,即可自动提高整个代码库的性能。我们之前已经在 Hugging Face Transformers 和 TIMM 模型中验证了这一承诺,并深入探讨了其 动机、架构和未来发展方向

torch.compile() 非常重要,但 PyTorch 2.0 的功能远不止于此。值得注意的是,PyTorch 2.0 融入了多种加速 Transformer 模块的策略,这些改进对于扩散模型也非常重要。例如,FlashAttention 等技术由于能够显著加速 Stable Diffusion 并实现更大的批大小,因此在扩散社区中非常受欢迎,现在它们已成为 PyTorch 2.0 的一部分。

在这篇文章中,我们将讨论如何在 PyTorch 2.0 中优化注意力层,以及如何将这些优化应用于流行的 🧨 Diffusers 库。最后,我们将进行基准测试,展示 PyTorch 2.0 和 Diffusers 的使用如何在不同的硬件上立即转化为显著的性能提升。

更新(2023 年 6 月):添加了一个新章节,以展示在修复 diffusers 代码库中的图形中断后,最新版本的 PyTorch (2.0.1) 中 torch.compile() 的显著性能提升。有关如何查找和修复图形中断的更详细分析将在另一篇文章中发布。

加速 Transformer 模块

PyTorch 2.0 包含一个缩放点积注意力函数,作为 torch.nn.functional 的一部分。此函数包含多种实现,可以根据输入和使用的硬件应用。在 PyTorch 2.0 之前,您必须搜索第三方实现并安装单独的软件包才能利用内存优化的算法,例如 FlashAttention。可用的实现包括

所有这些方法默认都可用,PyTorch 将尝试通过使用新的缩放点积注意力 (SDPA) API 自动选择最佳方法。您也可以单独切换它们以进行更精细的控制,请参阅 文档 了解详细信息。

在 diffusers 中使用缩放点积注意力

通过使用 set_attn_processor 方法,将加速的 PyTorch 2.0 Transformer 注意力融入 Diffusers 库得以实现,该方法允许配置可插拔的注意力模块。在这种情况下,创建了一个新的注意力处理器,当 PyTorch 2.0 可用时,默认启用。为了清晰起见,以下是如何手动启用它(但通常没有必要,因为 diffusers 会自动处理它)

from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

Stable Diffusion 基准测试

我们在 Diffusers 中使用来自 PyTorch 2.0 的加速点积注意力运行了多项测试。我们通过 pip 安装了 diffusers,并使用了 PyTorch 2.0 的 nightly 版本,因为我们的测试是在正式发布之前进行的。我们还使用了 torch.set_float32_matmul_precision('high') 来启用其他快速矩阵乘法算法。

我们将结果与 diffusers 中的传统注意力实现(以下称为 vanilla)以及 pre-2.0 PyTorch 中性能最佳的解决方案进行了比较:安装了 xFormers 软件包 (v0.0.16) 的 PyTorch 1.13.1。

结果是在未编译的情况下(即,完全没有代码更改)以及通过单次调用 torch.compile() 来包装 UNet 模块的情况下测量的。我们没有编译图像解码器,因为大部分时间都花在了运行 UNet 评估的 50 次去噪迭代中。

float32 中的结果

Diffusers Speedup vs xFormers float32

下图探讨了不同代代表性 GPU 的性能提升与批大小的关系。我们收集了每种组合的数据,直到达到最大内存利用率。Vanilla 注意力比 xFormers 或 PyTorch 2.0 更早耗尽内存,这解释了较大批大小的缺失条形。同样,A100(我们使用了 40 GB 版本)能够运行 64 的批大小,但其他 GPU 在我们的测试中只能达到 32。

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (V100, float32)

我们发现,即使不使用 torch.compile(),在所有方面,性能都比 vanilla 注意力有了非常显著的提升。开箱即用的 PyTorch 2.0 和 diffusers 安装在 A100 上产生了约 50% 的加速,在 4090 GPU 上产生了 35% 到 50% 的加速,具体取决于批大小。对于 Ada (4090) 或 Ampere (A100) 等现代 CUDA 架构,性能提升更为明显,但对于云服务中仍在大量使用的旧架构,性能提升仍然非常显著。

除了更快的速度外,PyTorch 2.0 中的加速 Transformer 实现还允许使用更大的批大小。单个 40GB A100 GPU 在批大小为 10 时会耗尽内存,而 3090 和 4090 等 24 GB 高端消费级显卡一次无法生成 8 张图像。使用 PyTorch 2.0 和 diffusers,我们可以为 3090 和 4090 实现 48 的批大小,为 A100 实现 64 的批大小。这对于云服务和应用程序具有重要意义,因为它们可以一次高效地处理更多图像。

与 PyTorch 1.13.1 + xFormers 相比,新的加速 Transformer 实现仍然更快,并且不需要额外的软件包或依赖项。在这种情况下,我们发现数据中心显卡(如 A100 或 T4)的加速适中,高达 2%,但在最新两代消费级显卡上的性能非常出色:3090 的速度提升高达 20%,4090 的速度提升在 10% 到 45% 之间,具体取决于批大小。

当使用 torch.compile() 时,我们获得了(通常)比之前的改进额外 2% 和 3% 的性能提升。由于编译需要一些时间,因此更适合面向用户的推理服务或训练。更新:当图形中断最小化时,torch.compile() 实现的改进要大得多,请参阅新章节了解详细信息

float16 中的结果

Diffusers Speedup vs xFormers float16

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float16)

当我们考虑 float16 推理时,PyTorch 2.0 中加速 Transformer 实现的性能提升在所有我们测试的 GPU 上都比标准注意力高出 20% 到 28%,除了属于更现代 Ada 架构的 4090。这款 GPU 在使用 PyTorch 2.0 nightly 版本时受益于显著的性能提升。关于优化的 SDPA 与 xFormers,大多数 GPU 的结果通常相当,除了 4090 再次例外。在组合中添加 torch.compile() 会使整体性能再提高几个百分点。

最小化图形中断后 torch.compile() 的性能

在前面的章节中,我们看到,与早期版本的 PyTorch(无论是否使用 xFormers)相比,使用 PyTorch 2.0 的加速 Transformer 实现提供了重要的性能提升。然而,torch.compile() 仅贡献了适度的边际改进。在 PyTorch 团队的帮助下,我们发现这些适度改进的原因是 diffusers 源代码中的某些操作导致了图形中断,这阻止了 torch.compile() 充分利用图形优化。

在修复图形中断后(有关详细信息,请参阅 这些 PR),我们测量了 torch.compile() 与未编译版本的 PyTorch 2 相比的额外改进,并且我们看到了非常重要的增量性能提升。下图是使用 2023 年 5 月 1 日下载的 PyTorch 2 nightly 版本获得的,它显示大多数工作负载的改进范围约为 13% 到 22%。对于现代 GPU 系列,性能提升更加明显,A100 的性能提升超过 30%。图表中还有两个异常值。首先,我们看到 T4 在批大小为 16 时性能下降,这对该显卡造成了巨大的内存压力。在频谱的另一端,我们看到 A100 在仅使用批大小为 1 时性能提升超过 100%,这很有趣,但不能代表具有如此大 RAM 容量的 GPU 的实际使用情况——对于 A100 上的服务部署,能够服务多个客户的较大批大小通常会更有意义。

Diffusers Speedup using torch.compile() in float16

再次强调,这些性能提升是额外的,是在迁移到 PyTorch 2 并使用加速 Transformer 缩放点积注意力实现所取得的性能提升之外的。我们建议在生产环境中部署 diffusers 时使用 torch.compile()

结论

PyTorch 2.0 具有多项功能,可以优化基础 Transformer 模块的关键组件,并且可以通过使用 torch.compile 进一步改进这些组件。这些优化显着提高了扩散模型的内存和时间性能,并消除了对第三方库安装的需求。

要利用这些速度和内存改进,您只需升级到 PyTorch 2.0 并使用 diffusers >= 0.13.0。

有关更多示例和详细的基准测试数字,另请参阅 Diffusers with PyTorch 2.0 文档。

致谢

作者感谢 PyTorch 团队创建了如此出色的软件。