跳转到主要内容
博客

使用 PyTorch 加速生成式 AI:Segment Anything,Fast

作者: 2023 年 11 月 16 日2024 年 11 月 14 日暂无评论

这篇文章是多系列博客的第一部分,重点介绍如何使用纯原生 PyTorch 加速生成式 AI 模型。我们很高兴能分享大量新发布的 PyTorch 性能功能,以及这些功能如何结合使用的实际示例,以展示我们能将 PyTorch 原生性能推向多远。

正如PyTorch 开发者大会 2023上宣布的,PyTorch 团队重写了 Meta 的 Segment Anything (“SAM”) 模型使代码比原始实现快 8 倍,且未损失准确性,所有这些都使用了原生 PyTorch 优化。我们利用了大量新的 PyTorch 功能:

  • Torch.compile:一个用于 PyTorch 模型的编译器
  • GPU 量化:通过降低精度操作来加速模型
  • 缩放点积注意力 (SDPA):内存高效的注意力实现
  • 半结构化 (2:4) 稀疏性:一种 GPU 优化的稀疏内存格式
  • 嵌套张量:将不同大小的非均匀数据(如不同大小的图像)批量处理成单个张量。
  • 带 Triton 的自定义操作符:使用 Triton Python DSL 编写 GPU 操作,并通过自定义操作符注册轻松将其集成到 PyTorch 的各种组件中。

我们鼓励读者从我们在 Github 上的 SAM 实现复制粘贴代码,并在 Github 上向我们提问

A quick glimpse of increasing throughput and decreasing memory overhead

通过我们新发布的 PyTorch 原生功能,吞吐量增加,内存开销减少的快速一瞥。基准测试在 p4d.24xlarge 实例(8x A100)上运行。

SegmentAnything 模型

SAM 是一个用于生成可提示图像掩码的零样本视觉模型。

sam image masks

SAM 架构[在其论文中描述]包含多个基于 Transformer 架构的提示和图像编码器。其中,我们测量了最小和最大的视觉 transformer 主干网的性能:ViT-BViT-H。为简单起见,我们仅显示 ViT-B 模型的跟踪。

优化

下面我们讲述优化 SAM 的故事:分析、识别瓶颈,并将解决这些问题的新功能构建到 PyTorch 中。在此过程中,我们展示了新的 PyTorch 功能:torch.compile、SDPA、Triton 内核、嵌套张量和半结构化稀疏性。以下部分相互渐进式地构建,最终形成我们的 SAM-fast,现已在Github 上提供。我们使用完全 PyTorch 原生工具,通过真实的内核和内存跟踪来激励每个功能,并使用Perfetto UI 可视化这些跟踪。

基线

我们的 SAM 基线是 Facebook Research 的未修改模型,使用 float32 数据类型和批量大小为 1。经过一些初始预热后,我们可以使用PyTorch Profiler 查看内核跟踪。

kernel trace

我们注意到两个有待优化的领域。

首先是长时间调用 aten::index,这是张量索引操作(例如,[])导致的底层调用。虽然 aten::index 实际花费的 GPU 时间相对较低。aten::index 启动了两个内核,并且在两者之间发生了阻塞的 cudaStreamSynchronize。这意味着 CPU 正在等待 GPU 完成处理,直到它启动第二个内核。为了优化 SAM,我们应该致力于消除导致空闲时间的阻塞性 GPU 同步。

其次是 GPU 在矩阵乘法上花费了大量时间(上图 stream 7 7 上的深绿色)。这在 Transformer 中很常见。如果我们能减少 GPU 在矩阵乘法上花费的时间,我们可以显著加速 SAM。

我们可以测量开箱即用 SAM 的吞吐量(img/s)和内存开销(GiB),以建立基线。

throughput (img/s) and memory overhead (GiB) from out of the box SAM

Bfloat16 半精度(+GPU 同步和批处理)

为了解决矩阵乘法时间较短的第一个问题,我们可以转向bfloat16。Bfloat16 是一种常用的半精度类型。通过降低每个参数和激活的精度,我们可以在计算中节省大量时间和内存。在降低参数精度的同时,验证端到端模型准确性至关重要。

replacing padding dtypes with half precision, bfloat16

此处显示了一个将填充数据类型替换为半精度 bfloat16 的示例。代码在此

除了简单地设置 model.to(torch.bfloat16) 之外,我们还必须更改一些假定默认数据类型的小地方。

现在,为了消除 GPU 同步,我们需要审计导致它们的各种操作。我们可以通过在 GPU 跟踪中搜索对 cudaStreamSynchronize 的调用来找到这些代码片段。实际上,我们找到了两个可以重写为无同步的位置。

code sample 1
replacing padding dtypes with half precision, bfloat16

具体来说,我们看到在 SAM 的图像编码器中,有变量充当坐标缩放器,q_coords 和 k_coords。这些变量都在 CPU 上分配和处理。然而,一旦这些变量用于在 rel_pos_resized 中进行索引,索引操作会自动将这些变量移动到 GPU。这种复制会导致我们上面观察到的 GPU 同步。我们注意到 SAM 的提示编码器中对索引的第二次调用:我们可以使用 torch.where 重写它,如上所示。

内核跟踪

应用这些更改后,我们开始看到各个内核调用之间存在显著的时间。这通常在批次大小较小(此处为 1)时观察到,这是由于启动内核的 GPU 开销。为了更仔细地查看实际优化区域,我们可以开始以批次大小为 8 对 SAM 推理进行分析。

profile SAM inference with batch size 8

查看每个内核花费的时间,我们发现 SAM 的 GPU 时间大部分花在元素级内核和 softmax 操作上。通过这些,我们现在看到矩阵乘法的相对开销变得小得多。

matrix multiplications have become a much smaller relative overhead

综合 GPU 同步和 bfloat16 优化,我们现在已将 SAM 性能提升了多达 3 倍。

SAM performance by up to 3x

Torch.compile(+图中断和 CUDA 图)

当观察到大量小操作(例如上面分析的元素级内核)时,转向编译器来融合操作可以带来显著的好处。PyTorch 最近发布的 torch.compile 在以下方面做得很好:

  1. 将一系列操作(如 nn.LayerNorm 或 nn.GELU)融合到一个被调用的单个 GPU 内核中,并且
  2. 尾声:融合紧随矩阵乘法内核的操作,以减少 GPU 内核调用的数量。

通过这些优化,我们减少了 GPU 全局内存往返的次数,从而加速了推理。我们现在可以在 SAM 的图像编码器上尝试 torch.compile。为了最大化性能,我们使用了一些高级编译技术,例如:

  • 使用 torch.compile 的 max-autotune 模式可以启用CUDA 图和带有自定义尾声的特定形状内核。
  • 通过设置 TORCH_LOGS=”graph_breaks,recompiles”,我们可以手动验证我们没有遇到图中断或重新编译。
  • 用零填充编码器输入图像的批次,确保编译器接受静态形状,从而能够始终使用带有自定义尾声的特定形状优化内核,而无需重新编译。
predictor.model.image_encoder = \
    torch.compile(predictor.model.image_encoder, mode=use_compile)

内核跟踪

Kernel trace

torch.compile 工作得非常好。我们启动了一个 CUDA 图,它在计时区域内占据了 GPU 时间的很大一部分。让我们再次运行我们的配置文件,看看 GPU 时间花在特定内核上的百分比。

the percentage of GPU time spent in specific kernels

我们现在看到 softmax 占据了大部分时间,其次是各种 GEMM 变体。总而言之,我们观察到批次大小为 8 及以上更改的以下测量结果。

measurements for batch size 8 and above

SDPA:scaled_dot_product_attention

接下来,我们可以解决 Transformer 性能开销最常见的领域之一:注意力机制。朴素的注意力实现与序列长度在时间上和内存上呈二次方增长。PyTorch 的scaled_dot_product_attention 操作基于 Flash AttentionFlashAttentionV2xFormer 的内存高效注意力 的原理构建,可以显著加速 GPU 注意力。结合 torch.compile,此操作允许我们表达和融合 MultiheadAttention 变体中的常见模式。经过一小组更改,我们可以使模型适应使用 scaled_dot_product_attention。

PyTorch native attention implementation

PyTorch 原生注意力实现,代码在此处查看

内核跟踪

我们现在可以看到,特别是内存高效的注意力内核在 GPU 上占用了大量的计算时间。

memory efficient attention kernel is taking up a large amount of computational time on the GPU

使用 PyTorch 原生的 scaled_dot_product_attention,我们可以显著增加批次大小。我们现在观察到批次大小为 32 及以上更改的以下测量结果。

batch size 32 and above

Triton:用于融合相对位置编码的自定义 SDPA

暂时离开推理吞吐量,我们开始分析 SAM 的整体内存。在图像编码器中,我们发现内存分配出现显著峰值。

spikes in memory allocation

放大来看,我们看到此分配发生在 add_decomposed_rel_pos 中,在以下行:

we see this allocation happens within add_decomposed_rel_pos

这里的 `attn` 变量是两个较小张量的和:形状为 (B, q_h, q_w, k_h, 1) 的 `rel_h` 和形状为 (B, q_h, q_w, 1, k_w) 的 `rel_w`。

注意力偏置大小超过 3.0GiB 时,内存高效注意力内核(通过 SDPA 使用)需要很长时间也就不足为奇了。如果我们不分配这个巨大的 `attn` 张量,而是将两个较小的 `rel_h` 和 `rel_w` 张量线程化到 SDPA 中,并且只在需要时构造 `attn`,我们预计会获得显著的性能提升。

不幸的是,这不是一个简单的修改;SDPA 内核是高度优化的,并且是用 CUDA 编写的。我们可以转向 Triton,其易于理解和使用的FlashAttention 实现教程。经过大量的深入研究并与 xFormer 的 Daniel Haziza 密切合作,我们发现了一种输入形状的情况,在这种情况下,实现内核的融合版本相对简单。这些详细信息已添加到存储库中。令人惊讶的是,对于推理情况,这可以在 350 行代码内完成。

这是一个很好的例子,说明如何使用 Triton 代码轻松地扩展 PyTorch,增加新的内核。

内核跟踪

kernel trace

通过我们自定义的位置 Triton 内核,我们观察到批次大小为 32 时的以下测量结果。

we observe the following measurements for batch size 32

NT:NestedTensor 和批量 predict_torch

我们已经在图像编码器上花费了大量时间。这是有道理的,因为它占据了大部分计算时间。然而,到目前为止,它已经得到了很好的优化,而最耗时的操作需要大量的额外投资才能改进。

我们对掩码预测流水线有了一个有趣的发现:对于每个图像,都有一个相关的 `size`、`coords` 和 `fg_labels` 张量。这些张量中的每一个都具有不同的批次大小。每个图像本身也具有不同的大小。这种数据表示类似于不规则数组。借助 PyTorch 最近发布的NestedTensor,我们可以修改数据流水线,将 `coords` 和 `fg_labels` 张量批量处理成一个 NestedTensor。这可以为图像编码器之后的提示编码器和掩码解码器带来显著的性能优势。调用

torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged)

内核跟踪

Kernel trace
we can launch kernels much faster from the CPU than the GPU can process

我们现在可以看到,CPU 启动内核的速度比 GPU 处理内核的速度快得多,并且在我们的计时区域结束时,CPU 会长时间等待 GPU 完成处理(cudaDeviceSynchronize)。我们也不会在 GPU 上看到内核之间有任何空闲时间(空白)。

使用 Nested Tensor,我们观察到批次大小为 32 及以上更改的以下测量结果。

batch size 32 and above changes

int8:量化和近似矩阵乘法

我们注意到在上面的跟踪中,现在大量时间花费在 GEMM 内核中。我们已经优化到足以让矩阵乘法在推理中占用的时间超过缩放点积注意力。

在从 fp32 到 bfloat16 的早期学习基础上,我们再进一步,通过 int8 量化模拟更低的精度。在量化方法中,我们专注于动态量化,其中模型观察层可能输入和权重的范围,并细分可表达的 int8 范围以均匀地“分散”观察到的值。最终,每个浮点输入都将被映射到 [-128, 127] 范围内的单个整数。有关更多信息,请参阅 PyTorch 的量化教程

降低精度可以立即带来峰值内存节省,但要实现推理加速,我们必须通过 SAM 的操作充分利用 int8。这需要构建一个高效的 int8@int8 矩阵乘法内核,以及从高精度到低精度(量化)以及从低精度到高精度(反量化)的转换逻辑。利用 torch.compile 的强大功能,我们可以将这些量化和反量化例程编译并融合到高效的单个内核和矩阵乘法的尾声中。生成的实现相当短,不到 250 行代码。有关 API 和用法的更多信息,请参阅pytorch-labs/ao

虽然在推理时量化模型通常会导致一些精度回归,但 SAM 对低精度推理特别健壮,精度损失极小。添加量化后,我们现在观察到 批次大小 32 及以上更改的以下测量结果。

batch size 32 and above changes

稀疏:半结构化 (2:4) 稀疏性

矩阵乘法仍然是我们的瓶颈。我们可以转向模型加速策略,采用另一种经典的近似矩阵乘法方法:稀疏化。通过稀疏化我们的矩阵(即,将值置零),我们理论上可以使用更少的比特来存储权重和激活张量。我们决定将张量中哪些权重置零的过程称为剪枝。剪枝背后的思想是,权重张量中的小权重对层(通常是权重与激活的乘积)的净输出贡献很小。剪除小权重可以潜在地减小模型大小,而不会显著损失准确性。

剪枝方法多种多样,从完全非结构化(贪婪地剪枝权重)到高度结构化(一次剪枝张量的大子组件)。方法的选择并非易事。虽然非结构化剪枝在理论上对准确性的影响最小,但 GPU 在乘以大型密集矩阵时效率很高,在稀疏情况下可能会遭受显著的性能下降。PyTorch 中支持的一种最新剪枝方法试图在两者之间取得平衡,称为半结构化(或 2:4)稀疏性。这种稀疏存储将原始张量显著减少了 50%,同时产生了一个密集张量输出,可以利用高性能的 2:4 GPU 内核。请看下图以进行说明。

dense tensor output that can leverage highly performant, 2:4 GPU kernels

来自developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt

为了使用这种稀疏存储格式和相关的快速内核,我们需要对权重进行剪枝,使其符合格式的约束。我们在 1x4 区域中选择两个最小的权重进行剪枝,测量性能与精度之间的权衡。将权重从其默认的 PyTorch(“跨步”)布局更改为这种新的半结构化稀疏布局很容易。要实现 apply_sparse(model),我们只需要 32 行 Python 代码。

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# Sparsity helper functions
def apply_fake_sparsity(model):
    """
    This function simulates 2:4 sparsity on all linear layers in a model.
    It uses the torch.ao.pruning flow.
    """
    # torch.ao.pruning flow
    from torch.ao.pruning import WeightNormSparsifier
    sparse_config = []
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            sparse_config.append({"tensor_fqn": f"{name}.weight"})

    sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                      sparse_block_shape=(1,4),
                                      zeros_per_block=2)
    sparsifier.prepare(model, sparse_config)
    sparsifier.step()

    sparsifier.step()
    sparsifier.squash_mask()


def apply_sparse(model):
    apply_fake_sparsity(model)
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))

通过 2:4 稀疏性,我们观察到 SAM 在 vit_b 和批次大小为 32 时达到峰值性能。

With 2:4 sparsity, we observe peak performance on SAM with vit_b and batch size 32

总结

总结一下,我们很高兴宣布了迄今为止最快的 Segment Anything 实现。我们使用大量新发布的功能,用纯 PyTorch 重写了 Meta 的原始 SAM,且未损失准确性。

  • Torch.compile PyTorch 的原生 JIT 编译器,提供快速、自动化的 PyTorch 操作融合 [教程]
  • GPU 量化 通过降低精度操作加速模型 [API]
  • 缩放点积注意力 (SDPA) 一种新的、内存高效的注意力实现 [教程]
  • 半结构化 (2:4) 稀疏性 使用更少的比特存储权重和激活来加速模型 [教程]
  • 嵌套张量 高度优化的锯齿状数组处理,适用于非均匀批次和图像大小 [教程]
  • Triton 内核。自定义 GPU 操作,通过 Triton 轻松构建和优化。

有关如何重现此博客文章中数据的更多详细信息,请查看segment-anything-fast 的 experiments 文件夹。如果您遇到任何技术问题,请随时与我们联系或提出问题

在我们的下一篇文章中,我们很高兴能分享我们 PyTorch 原生 LLM 类似的性能提升!

致谢

我们感谢 Meta 的xFormers团队,包括 Daniel Haziza 和 Francisco Massa,他们编写了 SDPA 内核并帮助我们设计了定制的一次性 Triton 内核。