作者:Sayak Paul 和 Patrick von Platen (Hugging Face 🤗)

本文是多系列博客的第三部分,重点介绍如何使用纯原生PyTorch加速生成式AI模型。我们很高兴分享一系列新发布的PyTorch性能特性以及实际示例,以便了解我们能在多大程度上提升PyTorch原生性能。在第一部分中,我们展示了如何仅使用纯原生PyTorch将Segment Anything加速8倍以上。在第二部分中,我们展示了如何仅使用PyTorch原生优化将Llama-7B加速近10倍。在本博客中,我们将重点介绍如何将文本到图像扩散模型加速多达3倍。

我们将利用一系列优化措施,包括

  • 使用bfloat16精度运行
  • scaled_dot_product_attention (SPDA)
  • torch.compile
  • 组合注意力计算中的q,k,v投影
  • 动态int8量化

我们将主要关注Stable Diffusion XL (SDXL),展示延迟降低3倍的效果。这些技术是PyTorch原生的,这意味着您无需依赖任何第三方库或C++代码即可利用它们。

使用🤗Diffusers库启用这些优化只需几行代码。如果您已经很兴奋并迫不及待想看代码,请在此处查看附带的仓库:https://github.com/huggingface/diffusion-fast

SDXL Chart

(所讨论的技术并非SDXL特有,可用于加速其他文本到图像扩散系统,详见后文。)

下方是一些关于类似主题的博客文章

设置

我们将使用🤗Diffusers库展示优化及其相应的加速效果。此外,我们将使用以下PyTorch原生库和环境

  • Torch nightly(受益于高效注意力计算的最快内核;2.3.0.dev20231218+cu121)
  • 🤗 PEFT(版本:0.7.1)
  • torchao(提交SHA:54bcd5a10d0abbe7b0c045052029257099f83fd9)
  • CUDA 12.1

为了更方便地复现环境,您还可以参考此Dockerfile。本文中提供的基准测试数据来自于一块400W 80GB A100 GPU(时钟频率设置为最大)。

由于此处我们使用A100 GPU(Ampere架构),我们可以指定torch.set_float32_matmul_precision("high")来受益于TF32精度格式

使用降低的精度运行推理

在Diffusers中运行SDXL只需几行代码

from diffusers import StableDiffusionXLPipeline

## Load the pipeline in full-precision and place its model components on CUDA.
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to("cuda")

## Run the attention ops without efficiency.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

但这并不是很实用,因为它需要7.36秒才能生成一张包含30个步骤的图像。这是我们的基准,我们将尝试一步步进行优化。

SDXL Chart

这里,我们以全精度运行流水线。通过使用bfloat16等降低的精度,我们可以立即缩短推理时间。此外,现代GPU配备了专用核心,可受益于降低的精度进行加速计算。要在bfloat16精度下运行流水线计算,我们只需在初始化流水线时指定数据类型

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
	"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

## Run the attention ops without efficiency.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

SDXL Chart

通过使用降低的精度,我们将推理延迟从7.36秒缩短到4.63秒

关于bfloat16使用的一些注意事项

  • 使用降低的数值精度(如float16、bfloat16)运行推理不会影响生成质量,但能显著改善延迟。
  • 使用bfloat16数值精度相对于float16的优势取决于硬件。现代GPU倾向于支持bfloat16。
  • 此外,在我们的实验中,我们发现bfloat16在与量化结合使用时比float16更具鲁棒性。

(我们后来在float16下进行了实验,发现torchao的最新版本在使用float16时不会产生数值问题。)

使用SDPA执行注意力计算

默认情况下,当使用PyTorch 2时,Diffusers使用scaled_dot_product_attention (SDPA)执行注意力相关计算。SDPA提供了更快、更高效的内核来运行密集型注意力相关操作。要使用SDPA运行流水线,我们只需不设置任何注意力处理器,如下所示

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
	"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

SDPA带来了显著提升,从4.63秒缩短到3.31秒

SDXL Chart

编译UNet和VAE

我们可以通过使用torch.compile让PyTorch执行一些低级优化(例如算子融合和使用CUDA graphs启动更快内核)。对于StableDiffusionXLPipeline,我们编译去噪器(UNet)和VAE

from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

## Compile the UNet and VAE.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

## First call to `pipe` will be slow, subsequent ones will be faster.
image = pipe(prompt, num_inference_steps=30).images[0]

结合使用SDPA注意力和编译UNet及VAE,将延迟从3.31秒降低到2.54秒

SDXL Chart

关于torch.compile的注意事项

torch.compile提供不同的后端和模式。由于我们追求最大的推理速度,我们选择使用“max-autotune”模式的inductor后端。“max-autotune”使用CUDA graphs,并专门针对延迟优化编译图。使用CUDA graphs大大减少了启动GPU操作的开销。它通过一种机制,通过单个CPU操作启动多个GPU操作,从而节省时间。

fullgraph指定为True可以确保底层模型中没有图断点,从而充分发挥torch.compile的潜力。在我们的例子中,显式设置以下编译器标志也很重要

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

有关编译器标志的完整列表,请参阅此文件。

我们还在编译UNet和VAE时将其内存布局更改为“channels_last”,以确保最大速度

pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

在下一节中,我们将展示如何进一步改善延迟。

额外优化

torch.compile期间无图断点

确保底层模型/方法可以完全编译对于性能至关重要(torch.compile配合fullgraph=True)。这意味着不能有图断点。我们通过更改访问返回变量的方式,为UNet和VAE实现了这一点。考虑以下示例

code example

编译后消除GPU同步

在迭代反向扩散过程中,每次去噪器预测出噪声较少的潜在嵌入后,我们都会在调度器上调用step()。在step()内部,sigmas变量会被索引。如果sigmas数组位于GPU上,索引操作会导致CPU和GPU之间的通信同步。这会引起延迟,并且在去噪器已经编译后更加明显。

但是如果sigmas数组始终保留在CPU上(请参阅此行),则不会发生此同步,从而提高了延迟。一般来说,任何CPU <-> GPU通信同步都应该避免或降至最低,因为它会影响推理延迟。

对注意力操作使用组合投影

SDXL中使用的UNet和VAE都使用了类似Transformer的块。一个Transformer块由注意力块和前馈块组成。

在注意力块中,输入通过三个不同的投影矩阵(Q、K和V)投影到三个子空间。在朴素实现中,这些投影是在输入上分别执行的。但是我们可以将这些投影矩阵水平组合成一个矩阵,一次性执行投影。这会增加输入投影的矩阵乘法(matmuls)大小,并提升量化效果(将在下一节讨论)。

在Diffusers中启用这种计算只需一行代码

pipe.fuse_qkv_projections()

这将使UNet和VAE的注意力操作都能利用组合投影。对于交叉注意力层,我们仅组合key和value矩阵。要了解更多信息,您可以参考官方文档此处。值得注意的是,我们在内部利用了PyTorch的scaled_dot_product_attention

这些额外技术将推理延迟从2.54秒改善到2.52秒

SDXL Chart

动态int8量化

我们对UNet和VAE选择性地应用了动态int8量化。这是因为量化会给模型增加额外的转换开销,但希望通过更快的矩阵乘法(动态量化)来弥补。如果矩阵乘法过小,这些技术可能会降低性能。

通过实验,我们发现UNet和VAE中的某些线性层无法从动态int8量化中受益。您可以在此处查看过滤这些层的完整代码(下方称为dynamic_quant_filter_fn)。

我们利用超轻量级纯PyTorch库torchao,使用其用户友好的API进行量化

from torchao.quantization import apply_dynamic_quant

apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)

由于此量化支持仅限于线性层,我们还将合适的逐点卷积层转换为线性层以最大化收益。使用此选项时,我们还指定了以下编译器标志

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

为了防止由量化引起的任何数值问题,我们在bfloat16格式下运行所有内容。

以这种方式应用量化将延迟从2.52秒改善到2.43秒

SDXL Chart

资源

欢迎您查阅以下代码库,以复现这些数据并将技术扩展到其他文本到图像扩散系统

其他链接

其他流水线的改进

我们将这些技术应用于其他流水线,以测试我们方法的通用性。以下是我们的发现

SSD-1B

SSD-1B Chart

Stable Diffusion v1-5

Stable Diffusion v1-5 chart

PixArt-alpha/PixArt-XL-2-1024-MS

值得注意的是,PixArt-Alpha在反向扩散过程中使用基于Transformer的架构作为去噪器,而不是UNet。

PixArt-alpha/PixArt-XL-2-1024-MS chart

请注意,对于Stable Diffusion v1-5和PixArt-Alpha,我们没有探索应用动态int8量化的最佳形状组合标准。通过更好的组合可能会获得更好的结果。

总而言之,我们提出的方法在不降低生成质量的情况下,相对于基线提供了显著的加速。此外,我们相信这些方法应该可以补充社区中流行的其他优化方法(例如DeepCacheStable Fast等)。

结论与下一步

在本文中,我们介绍了一系列简单而有效的技术,可以帮助提高纯PyTorch中文本到图像扩散模型的推理延迟。总结如下

  • 使用降低的精度进行计算
  • 使用scaled-dot product attention高效运行注意力块
  • 使用带有“max-autotune”的torch.compile来改善延迟
  • 组合不同的投影以计算注意力
  • 动态int8量化

我们认为在如何将量化应用于文本到图像扩散系统方面还有很多值得探索的地方。我们没有穷尽式地探究UNet和VAE中哪些层倾向于从动态量化中受益。通过更好地组合目标量化层,可能还有进一步加速的机会。

除了在bfloat16中运行外,我们没有对SDXL的文本编码器进行任何改动。优化它们也可能带来延迟的改善。

致谢

感谢Ollin Boer Bohan,整个基准测试过程中使用了他的VAE,因为它在降低数值精度下更加数值稳定。

感谢来自Hugging Face的Hugo Larcher在基础设施方面的帮助。