作者:PyTorch 团队

本文是多系列博客的第一部分,重点介绍如何使用纯原生 PyTorch 加速生成式 AI 模型。我们很高兴分享一系列新发布的 PyTorch 性能特性以及如何结合这些特性的实际示例,看看我们能将 PyTorch 原生性能推到何种程度。

正如在2023 年 PyTorch 开发者大会上宣布的,PyTorch 团队重写了 Meta 的 Segment Anything (“SAM”) 模型,**代码速度比原始实现快 8 倍**,且准确率无损失,完全使用了原生 PyTorch 优化。我们利用了一系列新的 PyTorch 特性:

我们鼓励读者从我们在 Github 上的 SAM 实现中复制粘贴代码,并在 Github 上向我们提问

A quick glimpse of increasing throughput and decreasing memory overhead

通过我们新发布的 PyTorch 原生特性,快速一览吞吐量的提高和内存开销的降低。基准测试运行在 p4d.24xlarge 实例上(8x A100s)。

SegmentAnything 模型

SAM 是一种用于生成可提示图像掩码的零样本视觉模型。

sam image masks

SAM 架构[在其论文中描述]包含多个基于 Transformer 架构的提示和图像编码器。其中,我们衡量了最小和最大的视觉 Transformer 骨干网络的性能:ViT-BViT-H。为简单起见,我们只展示了 ViT-B 模型的轨迹。

优化

下面我们讲述优化 SAM 的过程:性能分析、识别瓶颈以及在 PyTorch 中构建解决这些问题的新特性。在此过程中,我们展示了我们的新 PyTorch 特性:**torch.compile、SDPA、Triton 核函数、Nested Tensor 和半结构化稀疏性。** 以下各节循序渐进,互相构建,最后是我们的 SAM-fast,现在已在 Github 上可用。我们使用完全 PyTorch 原生工具,并通过真实的核函数和内存轨迹来解释每个特性的动机,并使用Perfetto UI可视化这些轨迹。

基准线

我们的 SAM 基准线是 Facebook Research 的未修改模型,使用 float32 数据类型和批量大小为 1。在一些初始预热后,我们可以使用PyTorch 性能分析器查看核函数轨迹。

kernel trace

我们注意到两个成熟的优化领域。

第一个是对 aten::index 的长时间调用,它是 Tensor 索引操作(例如 [])产生的底层调用。虽然实际花费在 aten::index 上的 GPU 时间相对较低。aten::index 正在启动两个核函数,并且在两者之间发生阻塞的 cudaStreamSynchronize。这意味着 CPU 正在等待 GPU 完成处理,直到它启动第二个核函数。为了优化 SAM,我们应该旨在消除导致空闲时间的阻塞 GPU 同步。

第二个是在 GPU 上花费大量时间进行矩阵乘法(上面流 7 中的深绿色)。这在 Transformers 中很常见。如果我们能减少花费在矩阵乘法上的 GPU 时间,就能显著加速 SAM。

我们可以衡量开箱即用 SAM 的吞吐量(图/秒)和内存开销(GiB)以建立基准线。

throughput (img/s) and memory overhead (GiB) from out of the box SAM

Bfloat16 半精度(+GPU 同步和批处理)

为了解决矩阵乘法花费时间较少的问题,我们可以转向bfloat16。Bfloat16 是一种常用的半精度类型。通过降低每个参数和激活值的精度,我们可以在计算中节省大量时间和内存。在降低参数精度的同时,验证端到端模型的准确性至关重要。

replacing padding dtypes with half precision, bfloat16

此处展示了用半精度 bfloat16 替换填充数据类型的示例。代码在此处

除了简单地设置 model.to(torch.bfloat16) 外,我们还必须更改一些假定默认数据类型的小地方。

现在,为了移除 GPU 同步,我们需要审查导致它们的运算。我们可以通过搜索 GPU 轨迹中的 cudaStreamSynchronize 调用来找到这些代码片段。事实上,我们找到了两个位置,成功地重写为无同步操作。

code sample 1

replacing padding dtypes with half precision, bfloat16

具体来说,我们看到在 SAM 的图像编码器中,有一些变量用作坐标缩放器,q_coords 和 k_coords。这些变量都在 CPU 上分配和处理。然而,一旦这些变量用于在 rel_pos_resized 中进行索引,索引操作会自动将这些变量移动到 GPU。这种复制导致了我们上面观察到的 GPU 同步。我们在 SAM 的提示编码器中注意到了对索引的第二次调用:我们可以使用 torch.where 重写此处,如上所示。

核函数轨迹

应用这些更改后,我们开始看到单个核函数调用之间有大量时间。这通常在批量大小较小(此处为 1)时观察到,这是由于启动核函数的 GPU 开销造成的。为了更仔细地查看实际优化领域,我们可以开始以批量大小 8 对 SAM 推理进行性能分析。

profile SAM inference with batch size 8

查看每个核函数花费的时间,我们观察到 SAM 的大部分 GPU 时间花费在逐元素核函数和 softmax 运算上。由此可见,矩阵乘法已成为相对较小的开销。

matrix multiplications have become a much smaller relative overhead

结合 GPU 同步和 bfloat16 优化,我们现在将 SAM 性能提升了高达 3 倍。

SAM performance by up to 3x

Torch.compile(+图中断和 CUDA 图)

当观察到大量小型运算时(如上面分析的逐元素核函数),转向编译器融合运算可以带来显著的好处。PyTorch 最新发布的 **torch.compile** 在优化方面做得非常好,通过:

  1. 融合 nn.LayerNorm 或 nn.GELU 等一系列运算到单个 GPU 核函数中进行调用
  2. 后处理融合:融合紧随矩阵乘法核函数之后的运算,以减少 GPU 核函数调用的数量。

通过这些优化,我们减少了 GPU 全局内存往返次数,从而加速了推理。我们现在可以在 SAM 的图像编码器上尝试 torch.compile。为最大化性能,我们使用了一些高级编译技术,例如:

  • 使用 torch.compile 的 max-autotune 模式启用CUDA 图和带有自定义后处理融合的形状特定核函数。
  • 通过设置 TORCH_LOGS=”graph_breaks,recompiles”,我们可以手动验证我们没有遇到图中断或重新编译。
  • 对输入到编码器的图像批量用零进行填充,确保编译接受静态形状,从而能够始终使用带有自定义后处理融合的形状特定的优化核函数,而无需重新编译。
predictor.model.image_encoder = \
    torch.compile(predictor.model.image_encoder, mode=use_compile)

核函数轨迹

Kernel trace

torch.compile 工作得非常好。我们启动了一个单独的 CUDA 图,它在计时区域内占 GPU 时间的很大一部分。让我们再次运行性能分析,查看花费在特定核函数上的 GPU 时间百分比。

the percentage of GPU time spent in specific kernels

我们现在看到 softmax 占了很大一部分时间,紧随其后的是各种 GEMM 变体。总之,对于批量大小为 8 及以上更改,我们观察到以下测量结果。

measurements for batch size 8 and above

SDPA: scaled_dot_product_attention

接下来,我们可以解决 Transformer 性能开销中最常见的领域之一:Attention 机制。朴素的 Attention 实现的计算时间和内存会随序列长度呈二次方增长。PyTorch 基于 Flash AttentionFlashAttentionV2xFormer 的内存高效 Attention 原理构建的 scaled_dot_product_attention 运算可以显著加速 GPU Attention。结合 torch.compile,此运算允许我们表达并融合 MultiheadAttention 变体中的一个常见模式。在一小组更改后,我们可以调整模型以使用 scaled_dot_product_attention。

PyTorch native attention implementation

PyTorch 原生的 Attention 实现,在此处查看代码

核函数轨迹

我们现在可以看到,特别是内存高效 Attention 核函数占用了 GPU 上大量的计算时间。

memory efficient attention kernel is taking up a large amount of computational time on the GPU

使用 PyTorch 原生的 scaled_dot_product_attention,我们可以显著增加批量大小。我们现在观察到批量大小为 32 及以上更改的以下测量结果。

batch size 32 and above

Triton: 用于融合相对位置编码的自定义 SDPA

暂时从推理吞吐量转向,我们开始对整个 SAM 内存进行性能分析。在图像编码器中,我们看到了内存分配的显著峰值。

spikes in memory allocation

放大查看,我们看到此分配发生在 add_decomposed_rel_pos 中,在以下行:

we see this allocation happens within add_decomposed_rel_pos

这里的 attn 变量是两个较小张量的相加:形状为 (B, q_h, q_w, k_h, 1) 的 rel_h 和形状为 (B, q_h, q_w, 1, k_w) 的 rel_w。

内存高效 Attention 核函数(通过 SDPA 使用)在 Attention 偏差大小超过 3.0GiB 的情况下花费很长时间,这并不奇怪。如果不是分配这个大型 attn 张量,而是将两个较小的 rel_h 和 rel_w 张量传入 SDPA,只在需要时构建 attn,我们预期会带来显著的性能提升。

不幸的是,这不是一个简单的修改;SDPA 核函数经过高度优化并使用 CUDA 编写。我们可以转向 Triton,它提供易于理解和使用的FlashAttention 实现教程。经过大量深入研究并与 xFormer 的 Daniel Haziza 密切协作,我们发现了一种输入形状的情况,相对直接实现核函数的融合版本。详细信息已添加到仓库。令人惊讶的是,对于推理情况,这可以在不到 350 行代码内完成。

这是扩展 PyTorch 的一个很好的例子,使用 Triton 代码直接构建了一个新核函数。

核函数轨迹

kernel trace

使用我们自定义的位置 Triton 核函数,我们观察到批量大小为 32 的以下测量结果。

we observe the following measurements for batch size 32

NT: NestedTensor 和对 predict_torch 的批处理

我们在图像编码器上花费了大量时间。这是有道理的,因为它占用了最多的计算时间。然而,此时它已经相当优化,花费时间最多的算子需要大量额外投入才能改进。

我们对掩码预测流水线发现了一个有趣的观察结果:对于我们有的每张图像,都有关联的 size、coords 和 fg_labels 张量。这些张量每个的批量大小都不同。每张图像本身尺寸也不同。这种数据表示看起来像Jagged Data(锯齿数组)。使用 PyTorch 最新发布的NestedTensor,我们可以修改我们的数据流水线,将 coords 和 fg_labels 张量批处理到单个 NestedTensor 中。这可以为图像编码器之后的提示编码器和掩码解码器带来显著的性能优势。调用

torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged)

核函数轨迹

Kernel trace

we can launch kernels much faster from the CPU than the GPU can process

我们现在可以看到,我们可以从 CPU 启动核函数的速度远快于 GPU 处理速度,并且在计时区域末尾花费很长时间等待 GPU 完成 (cudaDeviceSynchronize)。在 GPU 上的核函数之间,我们也不再看到任何空闲时间(空白区域)。

使用 Nested Tensor,我们观察到批量大小为 32 及以上更改的以下测量结果。

batch size 32 and above changes

int8:量化和逼近矩阵乘法

我们注意到在上面的轨迹中,现在有大量时间花费在 GEMM 核函数中。我们优化得足够好,现在看到矩阵乘法在推理中占用的时间比 scaled dot product attention 更多。

基于之前从 fp32 到 bfloat16 的学习,让我们更进一步,通过 int8 量化模拟更低精度。查看量化方法,我们专注于动态量化,模型在此过程中观察层中可能的输入和权重的范围,并细分可表达的 int8 范围,以均匀“展开”观测值。最终每个浮点输入将被映射到范围 [-128, 127] 中的一个整数。更多信息请参见 PyTorch 的量化教程

降低精度可立即带来峰值内存节省,但要实现推理加速,我们必须通过 SAM 的运算充分利用 int8。这需要构建一个高效的 int8@int8 矩阵乘法核函数,以及将高精度转换为低精度(量化)以及将低精度反转回高精度(反量化)的类型转换逻辑。利用 torch.compile 的力量,我们可以编译并融合这些量化和反量化例程到高效的单个核函数和矩阵乘法的后处理融合中。由此产生的实现相当短,不到 250 行代码。关于 API 和用法的更多信息,请参见pytorch-labs/ao

虽然常见看到在推理时量化模型会导致一些准确率下降,但 SAM 对于低精度推理特别鲁棒,准确率损失极小。添加量化后,我们现在观察到**批量大小为 32** 及以上更改的以下测量结果。

batch size 32 and above changes

sparse:半结构化(2:4)稀疏性

矩阵乘法仍然是我们的瓶颈。我们可以转向模型加速秘籍,使用另一种经典方法来逼近矩阵乘法:稀疏化。通过稀疏化我们的矩阵(即,将值置零),理论上可以使用更少的位存储权重和激活张量。我们决定将张量中的哪些权重置零的过程称为剪枝。剪枝背后的想法是,权重张量中的小权重对层的净输出(通常是权重与激活值的乘积)贡献很小。剪掉小权重可能在不显著损失准确率的情况下减少模型大小。

剪枝方法多种多样,从完全非结构化(权重被贪婪地剪枝)到高度结构化(张量的大型子组件被一次性剪枝)。方法的选择并非易事。虽然非结构化剪枝对准确率的理论影响最小,但 GPU 在乘大型密集矩阵方面也非常高效,在稀疏模式下可能遭受显著的性能下降。PyTorch 支持的一种最近的剪枝方法力求平衡,称为半结构化(或 2:4)稀疏性。这种稀疏存储将原始张量显著减少 50%,同时产生可以利用高性能 2:4 GPU 核函数的密集张量输出。参见下图进行说明。

dense tensor output that can leverage highly performant, 2:4 GPU kernels

来源:developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt

为了使用这种稀疏存储格式及相关的快速核函数,我们需要修剪我们的权重,使其符合该格式的约束。我们在 1x4 区域中选择最小的两个权重进行剪枝,衡量性能与准确率之间的权衡。将权重从其默认的 PyTorch(“跨步”)布局更改为此新的半结构化稀疏布局很容易。要实现 apply_sparse(model),我们只需要 32 行 Python 代码。

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# Sparsity helper functions
def apply_fake_sparsity(model):
    """
    This function simulates 2:4 sparsity on all linear layers in a model.
    It uses the torch.ao.pruning flow.
    """
    # torch.ao.pruning flow
    from torch.ao.pruning import WeightNormSparsifier
    sparse_config = []
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            sparse_config.append({"tensor_fqn": f"{name}.weight"})

    sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                      sparse_block_shape=(1,4),
                                      zeros_per_block=2)
    sparsifier.prepare(model, sparse_config)
    sparsifier.step()

    sparsifier.step()
    sparsifier.squash_mask()


def apply_sparse(model):
    apply_fake_sparsity(model)
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))

使用 2:4 稀疏性,我们在 SAM 上使用 vit_b 和批量大小 32 时观察到峰值性能。

With 2:4 sparsity, we observe peak performance on SAM with vit_b and batch size 32

结论

总结来说,我们很高兴宣布了迄今为止最快的 Segment Anything 实现。我们使用一系列新发布的特性,在纯 PyTorch 中重写了 Meta 的原始 SAM,且准确率无损失。

  • Torch.compile PyTorch 的原生 JIT 编译器,提供快速、自动化的 PyTorch 运算融合 [教程]
  • GPU 量化 通过低精度运算加速模型 [api]
  • Scaled Dot Product Attention (SDPA) 一种新的内存高效的 Attention 实现 [教程]
  • 半结构化(2:4)稀疏性 使用更少的位存储权重和激活值来加速模型 [教程]
  • Nested Tensor 用于处理非均匀批量和图像尺寸的高度优化的锯齿数组处理 [教程]
  • Triton 核函数。自定义 GPU 运算,通过 Triton 轻松构建和优化

关于如何重现本文数据点的更多详细信息,请查看segment-anything-fast 的 experiments 文件夹。如果您遇到任何技术问题,请随时联系我们或开启议题

在下一篇博文中,我们很高兴分享 PyTorch 原生编写的 LLM 获得的类似性能提升!

致谢

我们要感谢 Meta 的 xFormers 团队,包括 Daniel Haziza 和 Francisco Massa,感谢他们编写 SDPA 核函数并帮助我们设计自定义的一次性 Triton 核函数。