这篇博文是一个多系列博客的第一部分,重点介绍如何使用纯原生 PyTorch 加速生成式 AI 模型。我们很高兴能分享一系列新发布的 PyTorch 性能功能,并结合实际示例,展示如何将这些功能结合起来,以最大限度地发挥 PyTorch 的原生性能。
正如在2023 年 PyTorch 开发者大会上宣布的那样,PyTorch 团队重写了 Meta 的 Segment Anything (“SAM”) 模型,代码速度比原始实现快 8 倍,且没有精度损失,所有这些都使用了原生 PyTorch 优化。我们利用了一系列新的 PyTorch 功能:
- Torch.compile:PyTorch 模型的编译器
- GPU 量化:通过降低精度操作来加速模型
- Scaled Dot Product Attention (SDPA):内存高效的注意力实现
- 半结构化 (2:4) 稀疏性:一种 GPU 优化的稀疏内存格式
- 嵌套张量:将不同大小的数据(例如不同大小的图像)批量打包到一个张量中。
- 使用 Triton 的自定义运算符:使用 Triton Python DSL 编写 GPU 操作,并通过自定义运算符注册轻松将其集成到 PyTorch 的各种组件中。
我们鼓励读者从我们在 Github 上的 SAM 实现中复制粘贴代码,并在 Github 上向我们提问。

我们新发布的 PyTorch 原生功能可以快速提高吞吐量并减少内存开销。基准测试在 p4d.24xlarge 实例(8x A100)上运行。
SegmentAnything 模型
SAM 是一个零样本视觉模型,用于生成可提示的图像掩码。

SAM 架构[在其论文中描述]包含多个基于 Transformer 架构的提示和图像编码器。其中,我们测量了最小和最大的视觉 Transformer 主干的性能:ViT-B 和ViT-H。为简单起见,我们只显示 ViT-B 模型的轨迹。
优化
下面我们将讲述 SAM 优化的故事:分析、识别瓶颈,并在 PyTorch 中构建解决这些问题的新功能。在此过程中,我们展示了我们的新 PyTorch 功能:torch.compile、SDPA、Triton 内核、嵌套张量和半结构化稀疏性。以下部分相互逐步构建,最后是我们现在在 Github 上可用的 SAM-fast。我们使用真实内核和内存轨迹,以及完全 PyTorch 原生工具来激励每个功能,并使用Perfetto UI 可视化这些轨迹。
基线
我们的 SAM 基线是 Facebook Research 的未修改模型,使用 float32 dtype 和批大小为 1。经过一些初始预热后,我们可以使用PyTorch Profiler 查看内核轨迹。

我们注意到有两个领域有待优化。
第一个是对 aten::index 的长时间调用,它是张量索引操作(例如 [])导致的底层调用。虽然在 aten::index 上花费的实际 GPU 时间相对较低。aten::index 正在启动两个内核,并且中间发生阻塞的 cudaStreamSynchronize。这意味着 CPU 正在等待 GPU 完成处理,然后才启动第二个内核。为了优化 SAM,我们应该致力于消除导致空闲时间的阻塞 GPU 同步。
第二个是在矩阵乘法上花费了大量 GPU 时间(上面流 7 上的深绿色)。这在 Transformer 中很常见。如果能减少花在矩阵乘法上的 GPU 时间,我们可以显著加快 SAM 的速度。
我们可以测量开箱即用的 SAM 的吞吐量(图像/秒)和内存开销(GiB),以建立基线。

Bfloat16 半精度(+GPU 同步和批处理)
为了解决矩阵乘法时间减少的第一个问题,我们可以转向bfloat16。Bfloat16 是一种常用的半精度类型。通过降低每个参数和激活的精度,我们可以显着节省计算时间和内存。在降低参数精度的同时,验证端到端模型精度至关重要。

此处显示了一个用半精度 bfloat16 替换填充 dtype 的示例。代码在此处。
接下来,除了简单地设置`model.to(torch.bfloat16)`之外,我们还需要更改一些假设默认数据类型的小地方。
现在,为了消除 GPU 同步,我们需要审计导致它们的操作。我们可以通过在 GPU 轨迹中搜索对 `cudaStreamSynchronize` 的调用来找到这些代码片段。实际上,我们找到了两个可以重写为无同步的位置。


具体来说,我们看到在 SAM 的图像编码器中,存在充当坐标缩放器的变量,q_coords 和 k_coords。这些都在 CPU 上分配和处理。然而,一旦这些变量用于在 rel_pos_resized 中进行索引,索引操作会自动将这些变量移动到 GPU。这种复制会导致我们上面观察到的 GPU 同步。我们注意到 SAM 的提示编码器中对 index 的第二次调用:我们可以使用 torch.where 来重写它,如上所示。
内核跟踪
应用这些更改后,我们开始看到单个内核调用之间有显著的时间间隔。这通常在小批量大小(此处为 1)下观察到,因为 GPU 启动内核的开销。为了更仔细地查看实际的优化区域,我们可以开始以批量大小 8 来分析 SAM 推理。

查看每个内核花费的时间,我们发现 SAM 的大部分 GPU 时间都花在逐元素内核和 softmax 操作上。这样我们现在可以看到矩阵乘法已成为一个相对小得多的开销。

结合 GPU 同步和 bfloat16 优化,我们现在已将 SAM 性能提升了 3 倍。

Torch.compile(+图中断和 CUDA 图)
当观察到大量小操作(例如上面分析的逐元素内核)时,使用编译器融合操作可以带来显著的好处。PyTorch 最新发布的 torch.compile 在以下方面做得很好:
- 将一系列操作(例如 nn.LayerNorm 或 nn.GELU)融合成一个被调用的单个 GPU 内核,并且
- 尾声:融合紧随矩阵乘法内核的操作,以减少 GPU 内核调用的数量。
通过这些优化,我们减少了 GPU 全局内存往返的次数,从而加快了推理速度。我们现在可以在 SAM 的图像编码器上尝试 torch.compile。为了最大限度地提高性能,我们使用了一些高级编译技术,例如
- 使用 torch.compile 的 max-autotune 模式启用CUDA 图和带有自定义尾声的特定形状内核。
- 通过设置 TORCH_LOGS=”graph_breaks,recompiles”,我们可以手动验证我们没有遇到图中断或重新编译。
- 用零填充输入到编码器的图像批次确保编译接受静态形状,从而能够始终使用带有自定义尾声的特定形状优化内核,而无需重新编译。
predictor.model.image_encoder = \
torch.compile(predictor.model.image_encoder, mode=use_compile)
内核跟踪

torch.compile 运行得非常出色。我们启动了一个 CUDA 图,它在计时区域内占据了 GPU 时间的很大一部分。让我们再次运行我们的配置文件,查看特定内核中 GPU 时间的百分比。

我们现在看到 softmax 占据了大部分时间,其次是各种 GEMM 变体。总而言之,我们观察到批量大小为 8 及以上更改的以下测量结果。

SDPA:scaled_dot_product_attention
接下来,我们可以解决 Transformer 性能开销最常见的领域之一:注意力机制。朴素的注意力实现的时间和内存与序列长度呈二次方关系。PyTorch 的scaled_dot_product_attention 操作基于 Flash Attention、FlashAttentionV2 和 xFormer 的内存高效注意力的原理,可以显著加快 GPU 注意力。结合 torch.compile,此操作允许我们表达和融合 MultiheadAttention 变体中的常见模式。经过一小部分更改后,我们可以调整模型以使用 scaled_dot_product_attention。

PyTorch 原生注意力实现,请参阅此处代码。
内核跟踪
我们现在可以看到,特别是内存高效的注意力内核在 GPU 上占用了大量的计算时间。

使用 PyTorch 的原生 scaled_dot_product_attention,我们可以显著增加批处理大小。现在我们观察到批处理大小为 32 及以上更改的以下测量结果。

Triton:用于融合相对位置编码的自定义 SDPA
暂时不谈推理吞吐量,我们开始分析整体 SAM 内存。在图像编码器中,我们看到了内存分配的显著峰值。

放大后,我们看到此分配发生在 add_decomposed_rel_pos 内部,在以下行:

这里的 attn 变量是两个较小张量的和:形状为 (B, q_h, q_w, k_h, 1) 的 rel_h 和形状为 (B, q_h, q_w, 1, k_w) 的 rel_w。
注意力偏差大小超过 3.0GiB 时,内存高效的注意力内核(通过 SDPA 使用)需要很长时间也就不足为奇了。如果我们将两个较小的 rel_h 和 rel_w 张量传入 SDPA,而不是分配这个大的 attn 张量,并且只在需要时构建 attn,我们预计会获得显著的性能提升。
不幸的是,这不是一个简单的修改;SDPA 内核是高度优化的,并且是用 CUDA 编写的。我们可以求助于 Triton,他们提供了易于理解和使用的FlashAttention 实现教程。经过大量的深入研究并与 xFormer 的 Daniel Haziza 密切合作,我们发现了一种输入形状的情况,在这种情况下实现内核的融合版本相对简单。详细信息已添加到存储库中。令人惊讶的是,对于推理情况,这可以在 350 行代码内完成。
这是一个很好的例子,说明如何使用 Triton 代码轻松构建和扩展 PyTorch 的新内核。
内核跟踪

使用我们的自定义位置 Triton 内核,我们观察到批量大小为 32 时的以下测量结果。

NT:NestedTensor 和批处理 predict_torch
我们已经在图像编码器上花费了大量时间。这是有道理的,因为它占用了最多的计算时间。然而,到目前为止,它已经得到了相当好的优化,而占用最多时间的操作符将需要大量的额外投资才能改进。
我们发现掩码预测管道有一个有趣的观察结果:对于每个图像,都有一个相关的尺寸、坐标和 fg_labels 张量。这些张量中的每一个都具有不同的批处理大小。每个图像本身也具有不同的大小。这种数据表示类似于不规则数组。借助 PyTorch 最近发布的NestedTensor,我们可以修改我们的数据管道批处理坐标和 fg_labels 张量,将它们合并到一个 NestedTensor 中。这可以为图像编码器之后的提示编码器和掩码解码器带来显著的性能优势。调用
torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged)
内核跟踪


现在我们可以看到,CPU 启动内核的速度比 GPU 处理的速度快得多,而且在计时区域的末尾,CPU 会长时间等待 GPU 完成(cudaDeviceSynchronize)。我们也没有看到 GPU 内核之间再有空闲时间(空白区域)。
使用嵌套张量,我们观察到批量大小为 32 及以上更改的以下测量结果。

int8:量化和近似矩阵乘法
在上面的跟踪中,我们注意到大量时间现在花在了 GEMM 内核上。我们已经优化到足以让矩阵乘法在推理中比缩放点积注意力占用更多时间。
在前期的学习基础上,从 fp32 到 bfloat16,让我们更进一步,用 int8 量化来模拟更低的精度。在量化方法方面,我们专注于动态量化,其中我们的模型观察层可能输入和权重的范围,并将可表达的 int8 范围细分,以均匀地“展开”观察到的值。最终,每个浮点输入将被映射到 [-128, 127] 范围内的单个整数。有关更多信息,请参阅 PyTorch 的量化教程。
降低精度可以立即带来峰值内存节省,但要实现推理加速,我们必须通过 SAM 的操作充分利用 int8。这需要构建一个高效的 int8@int8 矩阵乘法内核,以及从高精度到低精度(量化)以及从低精度到高精度(反量化)的转换逻辑。利用 torch.compile 的强大功能,我们可以将这些量化和反量化例程编译并融合到高效的单个内核和矩阵乘法的尾部。由此产生的实现相当短,不到 250 行代码。有关 API 和用法的更多信息,请参阅pytorch-labs/ao。
虽然在推理时对模型进行量化通常会看到一些精度下降,但 SAM 对较低精度推理表现出特别的鲁棒性,精度损失最小。添加量化后,我们现在观察到批量大小为 32 及以上更改的以下测量结果。

稀疏:半结构化 (2:4) 稀疏性
矩阵乘法仍然是我们的瓶颈。我们可以转向模型加速手册,使用另一种近似矩阵乘法的经典方法:稀疏化。通过稀疏化我们的矩阵(即,将值置为零),理论上我们可以使用更少的位来存储权重和激活张量。我们决定将张量中的哪些权重置为零的过程称为剪枝。剪枝背后的思想是,权重张量中的小权重对层的净输出贡献很小,通常是权重与激活的乘积。剪枝掉小权重可以潜在地减小模型大小而不会显著损失精度。
剪枝方法多种多样,从完全非结构化(权重被贪婪地剪枝)到高度结构化(张量的大的子组件一次被剪枝)。方法的选择并非易事。虽然非结构化剪枝在理论上对精度的影响最小,但 GPU 在乘法大型密集矩阵方面效率很高,在稀疏状态下可能会遭受显著的性能下降。PyTorch 最近支持的一种剪枝方法旨在达到平衡,称为半结构化(或 2:4)稀疏性。这种稀疏存储将原始张量显著减少了 50%,同时产生了一个可以利用高性能 2:4 GPU 内核的密集张量输出。请参阅下图进行说明。

来自developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt
为了使用这种稀疏存储格式和相关的快速内核,我们需要修剪权重,使其符合格式约束。我们选择在 1x4 区域内修剪两个最小的权重,测量性能与精度之间的权衡。将权重从其默认 PyTorch(“跨步”)布局更改为这种新的半结构化稀疏布局很容易。要实现 `apply_sparse(model)`,我们只需要 32 行 Python 代码。
import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
# Sparsity helper functions
def apply_fake_sparsity(model):
"""
This function simulates 2:4 sparsity on all linear layers in a model.
It uses the torch.ao.pruning flow.
"""
# torch.ao.pruning flow
from torch.ao.pruning import WeightNormSparsifier
sparse_config = []
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear):
sparse_config.append({"tensor_fqn": f"{name}.weight"})
sparsifier = WeightNormSparsifier(sparsity_level=1.0,
sparse_block_shape=(1,4),
zeros_per_block=2)
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.step()
sparsifier.squash_mask()
def apply_sparse(model):
apply_fake_sparsity(model)
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear):
mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
凭借 2:4 稀疏性,我们观察到 SAM 在 vit_b 和批处理大小为 32 时达到峰值性能。

结论
最后,我们很高兴地宣布了迄今为止最快的 Segment Anything 实现。我们使用一系列新发布的功能,用纯 PyTorch 重写了 Meta 的原始 SAM,且没有精度损失。
- Torch.compile PyTorch 的原生 JIT 编译器,提供快速、自动化的 PyTorch 操作融合 [教程]
- GPU 量化 通过降低精度操作来加速模型 [api]
- Scaled Dot Product Attention (SDPA) 一种新的、内存高效的注意力实现 [教程]
- 半结构化 (2:4) 稀疏性 使用更少的位存储权重和激活来加速模型 [教程]
- 嵌套张量 高度优化的不规则数组处理,用于非均匀批次和图像大小 [教程]
- Triton 内核。自定义 GPU 操作,通过 Triton 轻松构建和优化。
有关如何重现本博客文章中提供的数据的更多详细信息,请查看segment-anything-fast 的 experiments 文件夹。如果您遇到任何技术问题,请随时联系我们或提出问题。
在我们的下一篇文章中,我们很高兴能分享我们用 PyTorch 原生编写的 LLM 所实现的类似性能提升!
致谢
我们要感谢 Meta 的 xFormers 团队,包括 Daniel Haziza 和 Francisco Massa,感谢他们编写 SDPA 内核并帮助我们设计自定义的一次性 Triton 内核。