跳转到主要内容
博客

使用 PyTorch 2.0 加速 Diffusers

作者: 2023 年 3 月 16 日2024 年 11 月 14 日暂无评论

PyTorch 2.0 刚刚发布。其旗舰新功能是 torch.compile(),只需一行代码即可自动改进代码库的性能。我们之前已经在 Hugging Face Transformers 和 TIMM 模型中验证了这一承诺,并深入探讨了其 动机、架构和未来发展方向

尽管 torch.compile() 很重要,但 PyTorch 2.0 还有更多功能。值得注意的是,PyTorch 2.0 融合了几种加速 Transformer 模块的策略,这些改进对扩散模型也非常重要。例如,FlashAttention 等技术因其显著加速 Stable Diffusion 并实现更大批量大小的能力而在扩散社区中广受欢迎,现在它们已成为 PyTorch 2.0 的一部分。

在这篇文章中,我们将讨论 PyTorch 2.0 中如何优化注意力层以及如何将这些优化应用于流行的 🧨 Diffusers 库。最后,我们通过一个基准测试来展示使用 PyTorch 2.0 和 Diffusers 如何立即带来不同硬件上的显著性能提升。

更新(2023 年 6 月):已添加新章节,展示了在修复 Diffusers 代码库中的图中断后,torch.compile() 在最新版 PyTorch (2.0.1) 中带来的显著性能改进。关于如何查找和修复图中断的更详细分析将在另一篇文章中发布。

加速 Transformer 模块

PyTorch 2.0 包含一个作为 torch.nn.functional 一部分的“缩放点积注意力”函数。此函数包含多种实现,可根据输入和所使用的硬件进行应用。在 PyTorch 2.0 之前,您必须搜索第三方实现并安装单独的软件包才能利用内存优化算法,例如 FlashAttention。可用的实现包括:

所有这些方法都默认可用,PyTorch 将通过使用新的缩放点积注意力 (SDPA) API 自动尝试选择最佳方法。您也可以单独切换它们以进行更精细的控制,详情请参见 文档

在 Diffusers 中使用缩放点积注意力

将加速 PyTorch 2.0 Transformer 注意力整合到 Diffusers 库是通过使用 set_attn_processor 方法 实现的,该方法允许配置可插拔的注意力模块。在这种情况下,创建了一个新的注意力处理器,当 PyTorch 2.0 可用时,该处理器默认启用。为了清晰起见,您可以手动启用它(但通常不需要,因为 Diffusers 会自动处理):

from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

Stable Diffusion 基准测试

我们使用 PyTorch 2.0 中的加速点积注意力在 Diffusers 中进行了一系列测试。我们从 pip 安装了 Diffusers,并使用了 PyTorch 2.0 的每夜版本,因为我们的测试是在正式发布之前进行的。我们还使用了 torch.set_float32_matmul_precision('high') 来启用额外的快速矩阵乘法算法。

我们将结果与 diffusers 中传统的注意力实现(下文称为 vanilla)以及 PyTorch 2.0 之前表现最佳的解决方案(PyTorch 1.13.1 安装了 xFormers 包 (v0.0.16))进行了比较。

结果在没有编译(即,完全没有代码更改)的情况下进行测量,并且还通过对 torch.compile() 的单次调用来封装 UNet 模块。我们没有编译图像解码器,因为大部分时间都花在运行 UNet 评估的 50 次去噪迭代中。

Float32 结果

Diffusers Speedup vs xFormers float32

下图探讨了不同世代代表性 GPU 的性能提升与批量大小的关系。我们为每种组合收集数据,直到达到最大内存利用率。Vanilla 注意力比 xFormers 或 PyTorch 2.0 更早耗尽内存,这解释了更大批量大小的缺失条形图。同样,A100(我们使用了 40 GB 版本)能够运行 64 的批量大小,但在我们的测试中,其他 GPU 只能达到 32。

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float32)
Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float32)
Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float32)
Diffusers Inference Speedup vs Vanilla and xFormers Attention (V100, float32)

我们发现,即使不使用 torch.compile(),与 vanilla 注意力相比,整体性能也有显著提升。PyTorch 2.0 和 Diffusers 的开箱即用安装在 A100 上带来了约 50% 的加速,在 4090 GPU 上根据批量大小的不同,加速介于 35% 到 50% 之间。性能提升在 Ada (4090) 或 Ampere (A100) 等现代 CUDA 架构上更为显著,但对于仍在云服务中大量使用的旧架构来说,仍然非常显著。

除了更快的速度,PyTorch 2.0 中加速的 transformer 实现允许使用更大的批量大小。单个 40GB A100 GPU 在批量大小为 10 时会耗尽内存,而 24 GB 的高端消费级显卡(如 3090 和 4090)无法一次生成 8 张图像。使用 PyTorch 2.0 和 Diffusers,我们可以在 3090 和 4090 上实现 **48** 的批量大小,在 A100 上实现 **64** 的批量大小。这对于云服务和应用程序具有重要意义,因为它们可以一次高效地处理更多图像。

与 PyTorch 1.13.1 + xFormers 相比,新的加速 Transformer 实现仍然更快,并且无需额外的软件包或依赖项。在这种情况下,我们发现在 A100 或 T4 等数据中心卡上实现了高达 2% 的适度加速,但在两代最新的消费级卡上表现出色:在 3090 上实现了高达 20% 的速度提升,在 4090 上根据批量大小的不同,实现了 10% 到 45% 的提升。

当使用 torch.compile() 时,我们在之前的改进基础上额外获得了(通常)2% 和 3% 的性能提升。由于编译需要一些时间,这更适合面向用户的推理服务或训练。更新:当图中断最小化时,torch.compile() 所实现的改进要大得多,详情请参见新章节

Float16 结果

Diffusers Speedup vs xFormers float16
Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float16)
Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float16)
Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float16)

当我们考虑 float16 推理时,PyTorch 2.0 中加速 transformer 实现的性能提升在标准注意力之上达到了 20% 到 28%,这适用于我们测试的所有 GPU,除了属于更现代的 Ada 架构的 4090。当使用 PyTorch 2.0 每夜版本时,这款 GPU 受益于显著的性能提升。至于优化后的 SDPA 与 xFormers 的比较,除了 4090 之外,大多数 GPU 的结果通常持平。将 torch.compile() 添加到其中,整体性能会再提升几个百分点。

torch.compile() 在最小化图中断后的性能

在前面的章节中,我们看到使用 PyTorch 2.0 的加速 Transformer 实现比早期版本的 PyTorch(无论是否使用 xFormers)提供了重要的性能改进。然而,torch.compile() 仅带来了适度的边际改进。在 PyTorch 团队的帮助下,我们发现这些适度改进的原因是 Diffusers 源代码中的某些操作导致了图中断,这使得 torch.compile() 无法充分利用图优化。

修复图中断后(详情请参见 这些 PRs),我们测量了 torch.compile() 相对于未编译的 PyTorch 2 版本的额外改进,我们看到了非常重要的增量性能增益。下图是使用 2023 年 5 月 1 日下载的 PyTorch 2 每夜版本获得的,它显示了大多数工作负载的改进范围约为 ~13% 到 22%。现代 GPU 系列的性能增益更好,A100 达到了 30% 以上。图表中还有两个异常值。首先,在 T4 上,当批量大小为 16 时,我们看到性能下降,这给该卡带来了巨大的内存压力。另一方面,当批量大小仅为 1 时,A100 的性能提升超过 100%,这很有趣,但并不代表具有如此大内存的 GPU 的实际使用情况——能够服务多个客户的更大批量大小通常对 A100 上的服务部署更感兴趣。

Diffusers Speedup using torch.compile() in float16

再次强调,这些性能提升是迁移到 PyTorch 2 并使用加速 Transformer 缩放点积注意力实现所实现的性能提升的**额外**部分。我们建议在生产环境中部署 Diffusers 时使用 torch.compile()

结论

PyTorch 2.0 具有多项功能,可优化基础 Transformer 模块的关键组件,并且可以通过使用 torch.compile 进一步改进。这些优化为扩散模型带来了显著的内存和时间改进,并消除了对第三方库安装的需求。

要利用这些速度和内存改进,您只需升级到 PyTorch 2.0 并使用 Diffusers >= 0.13.0。

有关更多示例和详细的基准测试数据,请参阅 PyTorch 2.0 Diffusers 文档

致谢

作者感谢 PyTorch 团队创建了如此出色的软件。