本文是多系列博客的第一部分,重点介绍如何使用纯原生 PyTorch 加速生成式 AI 模型。我们很高兴分享一系列新发布的 PyTorch 性能特性,并提供实用示例,说明如何组合这些特性以了解我们能够将 PyTorch 原生性能提升到何种程度。
正如在 2023 年 PyTorch 开发者大会上宣布的那样,PyTorch 团队重写了 Meta 的 Segment Anything (“SAM”) 模型,代码速度比原始实现快 8 倍,且精度没有损失,全部使用原生 PyTorch 优化。我们利用了一系列新的 PyTorch 特性:
- Torch.compile:PyTorch 模型的编译器
- GPU 量化:通过降低精度运算加速模型
- 缩放点积注意力 (SDPA):内存高效的注意力实现
- 半结构化 (2:4) 稀疏性: GPU 优化的稀疏内存格式
- 嵌套张量: 将大小不均匀的数据(例如不同尺寸的图像)批量处理成单个张量。
- 使用 Triton 的自定义算子: 使用 Triton Python DSL 编写 GPU 运算,并通过自定义算子注册轻松将其集成到 PyTorch 的各种组件中。
我们鼓励读者从我们在 Github 上的 SAM 实现中复制粘贴代码,并在 Github 上向我们提问。
快速了解使用我们新发布的 PyTorch 原生特性提高吞吐量和降低内存开销的情况。基准测试在 p4d.24xlarge 实例 (8x A100) 上运行。
SegmentAnything 模型
SAM 是一种用于生成可提示图像掩码的零样本视觉模型。
SAM 架构[在其论文中描述]包括基于 Transformer 架构的多个提示和图像编码器。 在此基础上,我们测量了最小和最大的视觉 Transformer 主干的性能:ViT-B 和 ViT-H。为了简单起见,我们仅显示 ViT-B 模型的跟踪。
优化
下面我们讲述优化 SAM 的故事:性能分析、识别瓶颈以及在 PyTorch 中构建解决这些问题的新特性。 在整个过程中,我们将展示我们新的 PyTorch 特性:torch.compile、SDPA、Triton 内核、嵌套张量和半结构化稀疏性。 以下各节在彼此的基础上逐步构建,最终形成我们的 SAM-fast,现在已在 Github 上提供。 我们使用真实的内核和内存跟踪来论证每个特性的必要性,使用完全 PyTorch 原生的工具,并使用 Perfetto UI 可视化这些跟踪。
基线
我们的 SAM 基线是 Facebook Research 的未修改模型,使用 float32 dtype 和批量大小 1。在进行一些初始预热后,我们可以使用 PyTorch Profiler 查看内核跟踪
我们注意到两个可以优化的领域。
第一个是对 aten::index 的长时间调用,这是由张量索引操作(例如 [])产生的底层调用。 虽然 aten::index 上花费的实际 GPU 时间相对较短。 aten::index 正在启动两个内核,并且在两者之间发生了阻塞式 cudaStreamSynchronize。 这意味着 CPU 正在等待 GPU 完成处理,然后才启动第二个内核。 为了优化 SAM,我们应该力求消除导致空闲时间的阻塞式 GPU 同步。
第二个是在 GPU 上花费大量时间进行矩阵乘法(上面流 7 7 上的深绿色)。 这在 Transformer 中很常见。 如果我们能够减少在矩阵乘法上花费的 GPU 时间,我们可以显着加速 SAM。
我们可以测量开箱即用的 SAM 的吞吐量(图像/秒)和内存开销 (GiB),以建立基线
Bfloat16 半精度(+GPU 同步和批处理)
为了解决矩阵乘法花费时间较少的第一项问题,我们可以转向 bfloat16。 Bfloat16 是一种常用的半精度类型。 通过降低每个参数和激活的精度,我们可以节省大量的计算时间和内存。 由于降低了参数精度,因此验证端到端模型精度至关重要。
此处展示的是用半精度 bfloat16 替换填充 dtype 的示例。 代码在此。
除了简单地设置 model.to(torch.bfloat16)
之外,我们还需要更改一些假设默认 dtype 的小地方。
现在,为了消除 GPU 同步,我们需要审核导致 GPU 同步的操作。 我们可以通过在 GPU 跟踪中搜索对 cudaStreamSynchronize
的调用来找到这些代码段。 实际上,我们找到了两个可以重写为无同步的位置。
具体而言,我们看到在 SAM 的图像编码器中,存在充当坐标缩放器的变量,q_coords 和 k_coords。 这些变量都在 CPU 上分配和处理。 但是,一旦这些变量用于在 rel_pos_resized 中进行索引,索引操作就会自动将这些变量移动到 GPU。 这种复制会导致我们上面观察到的 GPU 同步。 我们注意到 SAM 的提示编码器中对索引的第二次调用:我们可以使用 torch.where 将其重写为如上所示。
内核跟踪
应用这些更改后,我们开始看到各个内核调用之间的时间间隔显着增加。 这通常在小批量大小(此处为 1)的情况下观察到,原因是启动内核的 GPU 开销。 为了更仔细地了解实际的优化领域,我们可以开始分析批量大小为 8 的 SAM 推理性能
查看每个内核花费的时间,我们观察到 SAM 的大部分 GPU 时间都花费在逐元素内核和 softmax 运算上。 这样,我们现在看到矩阵乘法已成为相对较小的开销。
将 GPU 同步和 bfloat16 优化结合在一起,我们现在已将 SAM 性能提升了高达 3 倍
Torch.compile(+图中断和 CUDA 图)
当观察到大量小操作(例如上面分析的逐元素内核)时,转向编译器来融合操作可能具有很大的好处。 PyTorch 最近发布的 torch.compile 在以下方面做得非常出色:
- 将 nn.LayerNorm 或 nn.GELU 等操作序列融合到单个 GPU 内核中,并调用该内核,以及
- 尾声:融合紧跟矩阵乘法内核之后的操作,以减少 GPU 内核调用的次数。
通过这些优化,我们减少了 GPU 全局内存往返次数,从而加快了推理速度。 我们现在可以在 SAM 的 图像编码器上尝试 torch.compile。 为了最大限度地提高性能,我们使用了一些高级编译技术,例如
- 使用 torch.compile 的 max-autotune 模式启用 CUDA 图 和具有自定义尾声的形状特定内核
- 通过设置 TORCH_LOGS=”graph_breaks,recompiles”,我们可以手动验证我们是否没有遇到图中断或重新编译。
- 用零填充输入到编码器的图像批次可确保编译接受静态形状,从而始终能够使用具有自定义尾声的形状特定优化内核,而无需重新编译。
predictor.model.image_encoder = \
torch.compile(predictor.model.image_encoder, mode=use_compile)
内核跟踪
torch.compile 运行良好。 我们启动单个 CUDA 图,该图在定时区域内占用了 GPU 时间的很大一部分。 让我们再次运行我们的分析,看看特定内核中花费的 GPU 时间百分比
我们现在看到 softmax 占用了很大一部分时间,其次是各种 GEMM 变体。 总之,我们观察到批量大小为 8 及以上更改的以下测量结果。
SDPA:scaled_dot_product_attention
接下来,我们可以解决 Transformer 性能开销最常见的领域之一:注意力机制。 朴素的注意力实现会随着序列长度的增加而在时间和内存上呈二次方扩展。 PyTorch 的 scaled_dot_product_attention 运算建立在 Flash Attention、FlashAttentionV2 和 xFormer 的内存高效注意力的原理之上,可以显着加速 GPU 注意力。 结合 torch.compile,此运算使我们能够表达和融合 MultiheadAttention 变体中的常见模式。 在进行一小组更改后,我们可以调整模型以使用 scaled_dot_product_attention。
PyTorch 原生注意力实现,请参阅此处的代码。
内核跟踪
我们现在可以看到,特别是内存高效的注意力内核占用了 GPU 大量的计算时间
使用 PyTorch 的原生 scaled_dot_product_attention,我们可以显着增加批量大小。 我们现在观察到批量大小为 32 及以上更改的以下测量结果。
Triton:用于融合相对位置编码的自定义 SDPA
暂时离开推理吞吐量,我们开始分析整体 SAM 内存。 在图像编码器中,我们看到了显着的内存分配峰值
放大后,我们看到此分配发生在 add_decomposed_rel_pos 中,在以下行:
此处的 attn 变量是两个较小张量的加法:形状为 (B, q_h, q_w, k_h, 1) 的 rel_h 和形状为 (B, q_h, q_w, 1, k_w) 的 rel_w。
毫不奇怪,内存高效的注意力内核(通过 SDPA 使用)在注意力偏差大小超过 3.0GiB 的情况下花费了很长时间。 如果我们不分配这个大的 attn 张量,而是将两个较小的 rel_h 和 rel_w 张量线程化到 SDPA 中,并且仅在需要时才构造 attn,我们预计会获得显着的性能提升。
遗憾的是,这不是一个简单的修改; SDPA 内核经过高度优化,并以 CUDA 编写。 我们可以转向 Triton,他们提供了易于理解和使用的 FlashAttention 实现教程。 经过一番深入研究,并与 xFormer 的 Daniel Haziza 密切合作,我们发现了一种输入形状的情况,在这种情况下,相对容易实现内核的融合版本。 详细信息已添加到存储库中。 令人惊讶的是,对于推理用例,这可以在 350 行代码内完成。
这是使用 Triton 代码直接构建新内核来扩展 PyTorch 的一个很好的例子。
内核跟踪
使用我们的自定义位置 Triton 内核,我们观察到批量大小为 32 的以下测量结果。
NT:NestedTensor 和批量处理 predict_torch
我们在图像编码器上花费了大量时间。 这是有道理的,因为它占用了最多的计算时间。 然而,此时它已经得到了很好的优化,而占用最多时间的操作员将需要大量的额外投资才能得到改进。
我们通过 掩码预测管道发现了一个有趣的观察结果:对于我们拥有的每个图像,都有一个关联的大小、坐标和 fg_labels 张量。 这些张量中的每一个都具有不同的批量大小。 每个图像本身的大小也不同。 这种数据表示形式看起来像锯齿状数据。 借助 PyTorch 最近发布的 NestedTensor,我们可以将我们的数据管道批次坐标和 fg_labels 张量修改为单个 NestedTensor。 这可以为图像编码器之后的提示编码器和掩码解码器带来显着的性能优势。 调用
torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged)
内核跟踪
我们现在可以看到,我们从 CPU 启动内核的速度比 GPU 处理的速度快得多,并且在我们的定时区域结束时,它花费了很长时间等待 GPU 完成(cudaDeviceSynchronize)。 我们也没有看到 GPU 上内核之间有更多的空闲时间(白色空间)。
借助嵌套张量,我们观察到批量大小为 32 及以上更改的以下测量结果。
int8:量化和近似 matmul
我们在上面的跟踪中注意到,现在大量时间花费在 GEMM 内核中。 我们已经进行了足够的优化,现在我们看到矩阵乘法在推理中所占的时间比缩放点积注意力更多。
在早期从 fp32 过渡到 bfloat16 的经验基础上,让我们更进一步,使用 int8 量化来模拟更低的精度。 在查看量化方法时,我们专注于动态量化,其中我们的模型观察层的可能输入和权重的范围,并将可表达的 int8 范围细分,以均匀地“分散”观察到的值。 最终,每个浮点输入都将映射到 [-128, 127] 范围内的单个整数。 有关更多信息,请参阅 PyTorch 的量化教程
降低精度可以立即节省峰值内存,但要实现推理速度的提升,我们必须通过 SAM 的操作充分利用 int8。 这需要构建一个高效的 int8@int8 矩阵乘法内核,以及转换逻辑,以实现从高精度到低精度(量化)的转换,以及从低精度到高精度(反量化)的反向转换。 利用 torch.compile 的强大功能,我们可以编译并将这些量化和反量化例程融合到高效的单个内核和矩阵乘法的尾声中。 生成的实现相当简短,并且少于 250 行代码。 有关 API 和用法的更多信息,请参阅 pytorch-labs/ao。
虽然在推理时量化模型时通常会看到一些精度下降,但 SAM 对较低精度推理的鲁棒性特别强,精度损失极小。 添加量化后,我们现在观察到 批量大小为 32 及以上更改的以下测量结果。
稀疏性:半结构化 (2:4) 稀疏性
矩阵乘法仍然是我们的瓶颈。 我们可以借助模型加速剧本,使用另一种经典方法来近似矩阵乘法:稀疏化。 通过稀疏化我们的矩阵(即,将值置零),我们理论上可以使用更少的位来存储权重和激活张量。 我们决定将张量中的哪些权重设置为零的过程称为剪枝。 剪枝背后的想法是,权重张量中的小权重对层的净输出贡献很小,通常是权重与激活的乘积。 剪除小权重可能会减小模型大小,而精度损失不大。
剪枝方法多种多样,从完全非结构化(其中权重被贪婪地剪除)到高度结构化(其中张量的大的子组件一次被剪除)。 方法的选择并非易事。 虽然非结构化剪枝可能在理论上对精度的影响最小,但 GPU 在乘以大型密集矩阵时也非常高效,并且在稀疏状态下可能会遭受显着的性能下降。 PyTorch 支持的一种最新的剪枝方法力求达到平衡,称为半结构化(或 2:4)稀疏性。 这种稀疏存储将原始张量减少了显着的 50%,同时产生了密集的张量输出,可以利用高性能的 2:4 GPU 内核。 请参见下图进行说明。
来自 developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt
为了使用这种稀疏存储格式和相关的快速内核,我们需要剪除我们的权重,使其符合格式的约束。 我们选择在一个 1x4 区域中剪除两个最小的权重,从而衡量性能与精度的权衡。 将权重从其默认 PyTorch(“步幅式”)布局更改为这种新的半结构化稀疏布局非常容易。 为了实现 apply_sparse(model)
,我们只需要 32 行 Python 代码
import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
# Sparsity helper functions
def apply_fake_sparsity(model):
"""
This function simulates 2:4 sparsity on all linear layers in a model.
It uses the torch.ao.pruning flow.
"""
# torch.ao.pruning flow
from torch.ao.pruning import WeightNormSparsifier
sparse_config = []
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear):
sparse_config.append({"tensor_fqn": f"{name}.weight"})
sparsifier = WeightNormSparsifier(sparsity_level=1.0,
sparse_block_shape=(1,4),
zeros_per_block=2)
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.step()
sparsifier.squash_mask()
def apply_sparse(model):
apply_fake_sparsity(model)
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear):
mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
借助 2:4 稀疏性,我们观察到 vit_b 和批量大小 32 的 SAM 的峰值性能
结论
总结一下,我们很高兴宣布了迄今为止我们最快的 Segment Anything 实现。 我们使用一系列新发布的特性,以纯 PyTorch 重写了 Meta 的原始 SAM,且精度没有损失:
- Torch.compile PyTorch 的原生 JIT 编译器,提供快速、自动化的 PyTorch 运算融合 [教程]
- GPU 量化 通过降低精度运算加速模型 [api]
- 缩放点积注意力 (SDPA) 一种新的、内存高效的注意力实现 [教程]
- 半结构化 (2:4) 稀疏性 使用更少的位来存储权重和激活,从而加速模型 [教程]
- 嵌套张量 高度优化的、用于处理非均匀批次和图像大小的参差不齐的数组 [教程]
- Triton 内核。 自定义 GPU 运算,通过 Triton 轻松构建和优化
有关如何重现此博文中呈现的数据的更多详细信息,请查看 segment-anything-fast 的 experiments 文件夹。 如果您遇到任何技术问题,请随时联系我们或提出问题。
在我们的下一篇文章中,我们很高兴分享使用 PyTorch 原生编写的 LLM 获得的类似性能提升!
致谢
我们要感谢 Meta 的 xFormers 团队,包括 Daniel Haziza 和 Francisco Massa,感谢他们编写 SDPA 内核并帮助我们设计自定义的一次性 Triton 内核。