博客

加速生成式AI 第三部分:快速扩散

本文是系列博客的第三篇,专注于如何利用纯原生 PyTorch 加速生成式 AI 模型。我们很高兴能分享一系列新发布的 PyTorch 性能特性,并通过实际示例展示我们能将 PyTorch 的原生性能提升到什么程度。在第一部分中,我们展示了如何仅使用纯原生 PyTorch 将 Segment Anything 加速 8 倍以上。在第二部分中,我们展示了如何仅使用原生 PyTorch 优化将 Llama-7B 加速近 10 倍。在本文中,我们将重点介绍如何将文本到图像扩散模型(text-to-image diffusion models)的生成速度最高提升 3 倍。

我们将利用多种优化手段,包括:

  • 使用 bfloat16 精度运行
  • 缩放点积注意力 (SPDA)
  • torch.compile
  • 合并 q、k、v 投影以进行注意力计算
  • 动态 int8 量化

我们将主要关注 Stable Diffusion XL (SDXL),展示 3 倍的延迟改进。这些技术均属于 PyTorch 原生,这意味着你无需依赖任何第三方库或 C++ 代码即可利用它们。

通过 🤗Diffusers 库启用这些优化只需几行代码。如果你已经迫不及待想要尝试,请查看随附的仓库:https://github.com/huggingface/diffusion-fast

SDXL Chart

(所讨论的技术并非 SDXL 专用,也可用于加速其他文本到图像的扩散系统,详见后文。)

以下是一些关于类似主题的博客文章:

设置

我们将使用 🤗Diffusers 库来演示这些优化及其各自的提速增益。此外,我们还将使用以下 PyTorch 原生库和环境:

  • Torch nightly(以利用最高效注意力机制的最快内核;2.3.0.dev20231218+cu121)
  • 🤗 PEFT(版本:0.7.1)
  • torchao(提交 SHA:54bcd5a10d0abbe7b0c045052029257099f83fd9)
  • CUDA 12.1

为了获得更简便的复现环境,你也可以参考此 Dockerfile。本文提供的基准测试数据来自一块 400W 80GB 的 A100 GPU(其时钟频率已设置为最大容量)。

由于我们这里使用的是 A100 GPU(Ampere 架构),我们可以指定 torch.set_float32_matmul_precision("high") 来受益于 TF32 精度格式

使用降低的精度运行推理

在 Diffusers 中运行 SDXL 只需几行代码:

from diffusers import StableDiffusionXLPipeline

## Load the pipeline in full-precision and place its model components on CUDA.
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to("cuda")

## Run the attention ops without efficiency.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

但这并不实用,因为生成一张 30 步的图像需要 7.36 秒。这是我们的基准线,我们将尝试逐步对其进行优化。

SDXL Chart

这里,我们使用全精度运行流水线。我们可以通过使用较低的精度(如 bfloat16)立即缩短推理时间。此外,现代 GPU 配备了专门用于运行受益于降低精度的加速计算的核心。要以 bfloat16 精度运行流水线计算,我们只需在初始化流水线时指定数据类型即可:

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
	"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

## Run the attention ops without efficiency.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]
SDXL Chart

通过使用降低的精度,我们将推理延迟从 7.36 秒降低到了 4.63 秒

关于使用 bfloat16 的一些说明:

  • 使用较低的数值精度(如 float16、bfloat16)运行推理不会影响生成质量,但会显著改善延迟。
  • 与 float16 相比,使用 bfloat16 数值精度的优势与硬件相关。现代 GPU 更倾向于 bfloat16。
  • 此外,在我们的实验中,相比 float16,我们发现 bfloat16 在与量化结合使用时具有更强的韧性。

(我们后来在 float16 下也进行了实验,发现最新版本的 torchao 在 float16 下不会导致数值问题。)

使用 SDPA 执行注意力计算

默认情况下,当使用 PyTorch 2 时,Diffusers 使用 scaled_dot_product_attention (SDPA) 来执行注意力相关计算。SDPA 提供了更快、更高效的内核来运行密集的注意力操作。要运行 SDPA 流水线,我们只需像这样不设置任何注意力处理器:

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
	"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

SDPA 带来了显著的提升,从 4.63 秒缩短至 3.31 秒

SDXL Chart

编译 UNet 和 VAE

我们可以通过 torch.compile 要求 PyTorch 执行一些低级优化(例如算子融合和使用 CUDA 图启动更快的内核)。对于 StableDiffusionXLPipeline,我们编译去噪器 (UNet) 和 VAE:

from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

## Compile the UNet and VAE.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

## First call to `pipe` will be slow, subsequent ones will be faster.
image = pipe(prompt, num_inference_steps=30).images[0]

同时使用 SDPA 注意力并编译 UNet 和 VAE,可将延迟从 3.31 秒进一步降低至 2.54 秒

SDXL Chart

关于 torch.compile 的说明:

torch.compile 提供不同的后端和模式。由于我们追求极致的推理速度,我们选择了 inductor 后端并使用 "max-autotune"。 "max-autotune" 使用 CUDA 图,并专门针对延迟优化了编译图。使用 CUDA 图极大地减少了启动 GPU 操作的开销,通过一种机制实现通过单个 CPU 操作启动多个 GPU 操作,从而节省时间。

fullgraph 指定为 True 可确保底层模型中没有图断裂(graph breaks),从而最大限度地发挥 torch.compile 的潜力。在我们的案例中,显式设置以下编译器标志也很重要:

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

有关编译器标志的完整列表,请参考 此文件。

在编译 UNet 和 VAE 时,我们还将它们的内存布局更改为“channels_last”,以确保最高速度。

pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

在下一节中,我们将展示如何进一步改善延迟。

额外优化

torch.compile 期间无图断裂

确保底层模型/方法能够完全编译(使用 torch.compile 配合 fullgraph=True)对于性能至关重要。这意味着没有图断裂。我们通过改变访问返回变量的方式对 UNet 和 VAE 实现了这一点。请考虑以下示例:

code example

消除编译后的 GPU 同步

在迭代反向扩散过程中,每当去噪器预测出噪声较小的潜在嵌入后,我们都会在调度器上 调用 step()。在 step() 内部,会对 sigmas 变量进行 索引。如果 sigmas 数组位于 GPU 上,索引操作会导致 CPU 和 GPU 之间的通信同步。这会导致延迟,并且当去噪器被编译后,这种延迟会变得更加明显。

但如果 sigmas 数组始终保留在 CPU 上(参考 此行),则不会发生同步,从而改善了延迟。通常,任何 CPU <-> GPU 通信同步都应避免或保持在最低限度,因为这会影响推理延迟。

为注意力操作使用合并投影

SDXL 中使用的 UNet 和 VAE 都使用了类 Transformer 块。Transformer 块由注意力块和前馈块组成。

在注意力块中,输入通过三个不同的投影矩阵(Q、K 和 V)被投影到三个子空间。在原始实现中,这些投影是分别在输入上执行的。但我们可以将投影矩阵水平合并为一个矩阵,并一次性执行投影。这增加了输入投影矩阵乘法(matmuls)的大小,并改善了量化的效果(下一节讨论)。

在 Diffusers 中启用这种计算只需一行代码:

pipe.fuse_qkv_projections()

这将使 UNet 和 VAE 的注意力操作利用合并投影。对于交叉注意力层,我们仅合并 key 和 value 矩阵。要了解更多信息,请参阅官方文档 此处。值得注意的是,我们在此内部 利用了 PyTorch 的 scaled_dot_product_attention

这些额外技术将推理延迟从 2.54 秒降低到了 2.52 秒

SDXL Chart

动态 int8 量化

我们选择性地对 UNet 和 VAE 应用 动态 int8 量化。这是因为量化会给模型增加额外的转换开销,我们希望通过更快的矩阵乘法(动态量化)来弥补这一点。如果矩阵乘法太小,这些技术可能会降低性能。

通过实验,我们发现 UNet 和 VAE 中的某些线性层并不能从动态 int8 量化中获益。你可以在 此处 查看过滤这些层的完整代码(下文中称为 dynamic_quant_filter_fn)。

我们利用超轻量级纯 PyTorch 库 torchao 来使用其友好的量化 API:

from torchao.quantization import apply_dynamic_quant

apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)

由于此量化支持仅限于线性层,我们还将合适的逐点卷积(pointwise convolution)层转换为线性层,以最大限度地发挥效益。在使用此选项时,我们还指定了以下编译器标志:

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

为了防止量化带来的任何数值问题,我们以 bfloat16 格式运行所有内容。

以这种方式应用量化将延迟从 2.52 秒降低到了 2.43 秒

SDXL Chart

资源

欢迎查看以下代码库,以复现这些数据并将这些技术扩展到其他文本到图像的扩散系统:

其他链接

其他流水线的改进

我们将这些技术应用于其他流水线以测试我们方法的通用性。以下是我们的发现:

SSD-1B

SSD-1B Chart

Stable Diffusion v1-5

Stable Diffusion v1-5 chart

PixArt-alpha/PixArt-XL-2-1024-MS

值得注意的是,PixArt-Alpha 在反向扩散过程中使用基于 Transformer 的架构作为其去噪器,而不是 UNet。

PixArt-alpha/PixArt-XL-2-1024-MS chart

请注意,对于 Stable Diffusion v1-5 和 PixArt-Alpha,我们没有探索应用动态 int8 量化的最佳形状组合标准。如果采用更好的组合,可能获得更好的数据。

总体而言,我们提出的方法在不降低生成质量的情况下,较基准线提供了实质性的速度提升。此外,我们相信这些方法应该能够与社区中流行的其他优化方法(如 DeepCacheStable Fast 等)互补。

结论与后续步骤

在本文中,我们展示了一系列简单而有效的技术,有助于利用纯 PyTorch 改善文本到图像扩散模型的推理延迟。总结如下:

  • 使用降低的精度来执行计算。
  • 使用缩放点积注意力来高效运行注意力块。
  • 使用带有“max-autotune”的 torch.compile 来改善延迟。
  • 将不同的投影合并在一起以计算注意力。
  • 动态 int8 量化

我们认为,关于如何将量化应用于文本到图像的扩散系统,还有很多值得探索的空间。我们没有详尽地探索 UNet 和 VAE 中的哪些层更倾向于从动态量化中获益。通过更好的量化目标层组合,或许还有进一步提速的机会。

除了以 bfloat16 运行外,我们保持了 SDXL 的文本编码器不动。对其进行优化也可能带来延迟方面的改善。

致谢

感谢 Ollin Boer Bohan,我们在整个基准测试过程中使用了其 VAE,因为它在降低的数值精度下表现出更高的数值稳定性。

感谢 Hugging Face 的 Hugo Larcher 在基础设施方面提供的帮助。