跳转到主要内容
博客

使用 PyTorch 加速生成式 AI:Segment Anything 2 – 以低延迟和快速冷启动实现极速推理

作者: 2025 年 2 月 26 日2025 年 5 月 7 日暂无评论

本文是关于“如何使用纯原生 PyTorch 加速生成式 AI 模型,并侧重于延迟和弹性扩展”的多系列博客的第一篇文章的后续。我们使用 torch.compile 和 torch.export 来创建高度优化的低延迟 SAM2 版本,这些版本可以在新实例上快速扩展。

通过利用 AOTInductor (AOTI) 通过 torch.export 进行的提前编译、降低精度、批处理提示和 GPU 预处理,我们观察到与常规急切模式 PyTorch 相比,p90 执行延迟排队时间提高了 13 倍

我们计算了最终结果,并演示了在来自 Modal 的自动扩展云基础设施上实际部署的改进。

  p50 执行延迟
(毫秒 / 改进)
p90 执行延迟
(毫秒 / 改进)
  急切模式 float32 AOTI float16 急切模式 float32 AOTI float16
AMG 741 112 (6.6x) 1140 176 (6.5x)
SPS 98 20 (4.9x) 130 28 (4.6x)
MPS 269 38 (7.1x) 714 52 (13.7x)
  p50 排队时间(毫秒 / 改进) p90 排队时间(毫秒 / 改进)
  急切模式 float32 AOTI float16 急切模式 float32 AOTI float16
AMG 201 41 (4.9x) 815 327 (2.6x)
SPS 31 33 (0.9x) 441 49 (9.0x)
MPS 40 37 (1.1x) 942 75 (12.6x)

任务

第一篇文章侧重于每张图像处理少量不同的提示(兴趣点)。这些点代表了真实掩码的中心点。对于这篇文章,我们将重点关注更广泛的任务集:单提示分割 (SPS)、多提示分割 (MPS)、自动掩码生成 (AMG),它为输入图像生成全套掩码,而无需给定提示集。第一篇文章仅侧重于 MPS。

comparison of 3 images

图像中的小星星表示用户提示。对于 AMG,没有提示,并且掩码是从密集的初始候选提示(猜测)网格中启发式过滤出来的。对于 SPS 和 MPS,用户提示来源于 AMG 掩码的中心点。对于 SPS,我们选择面积最大的掩码。

请注意,SAM2 使用与 SAM1 不同的骨干网络。特别是,本博客中我们只考虑最大且最准确的 sam2.1_hiera_large 骨干网络。

我们将重现结果所需的脚本聚合在 torchao 的示例文件夹中,并逐步将 torchao 中 SAM2 模型的更稳定部分更改上游到主 SAM2 仓库。因此,如果您有兴趣查看最新变体或希望贡献实验性功能,请随时联系 torchao 仓库和团队。有关更稳定和最新的模型版本,请直接前往 SAM2。

概述

我们将此处介绍的更改分为两类:快速更改仅限于不影响模型准确性的技术。狂暴更改则通过使用低精度数据类型等近似值,牺牲部分数值精度以换取额外速度。

近似值可能会稍微降低精度指标,以显著提高性能,同时仍能通过基于平均交并比 (mIoU) 的端到端检查。

为了衡量性能改进,我们处理了 1000 张图像,这些图像是从 SAM2 验证数据集中随机选择的。我们查看每张图像的 p50 和 p90 延迟。为了衡量准确性,我们考虑 mIoU。最值得注意的是,对于 AMG 任务,我们还定义了一个失败计数指标。如果掩码数量不同,我们认为比较失败。这被证明是一个相当不稳定的量,我们可以看到其他任务不如 AMG 对小的数值变化敏感。

设置

我们正在一台常规的 H100 开发服务器上运行离线实验,这是一台相当强大且性能良好的机器。

但是,我们尝试在实际约束下审视这些任务。特别是,我们希望模拟服务器端推理环境。这意味着我们不使用 DataLoader 来隐藏图像预处理或解码例程的延迟。

对于延迟计算,我们包括解码、分割以及将掩码转换为运行长度编码掩码字典。或者换句话说,我们不包括将图像加载到内存中的主机字节数组以及将结果字典存储为磁盘上的 json 文件。这旨在模拟更真实的设置。

更具体地说,考虑下面我们测量中包含的例程的代码。对于任何任务,`gen_masks` 都生成一个批处理的布尔张量位掩码,表示相应的对象掩码。然后,我们将此位掩码压缩为运行长度编码 (rle) 格式,该格式可以更有效地将结果从远程服务器传回。

image_tensors = decode_img_bytes(...)
masks = gen_masks(image_tensors, ...)
rle_dicts = [rle_dict_from_masks(m) for m in masks]

优化

ao:即时代码优化

这项工作最有效的工具是 PyTorch autograd 性能分析器与 `record_function` 相结合。为了构建这个软件,我们反复使用性能分析器来观察程序并确认任何更改的有效性。同样重要的是要记住,性能分析器本身也有开销。收集的数据越多,例如堆栈跟踪,引入的开销就越大,这可能会扭曲收集到的跟踪。但它非常适合查找同步点、内核之间的空间以及耗时长的 GPU 内核。

GPU 跟踪有助于您理解不一定容易通过编译解决的瓶颈。我们发现,特别是 AutomaticMaskGeneration 受到用于存储掩码的数据结构以及用于将掩码转换为运行长度编码压缩格式的例程的支配。我们还发现 AMG 性能的很大一部分是由作为单个批次创建的大量掩码所支配的。有时,可以通过重新排序操作,在后处理阶段更早地将候选掩码过滤为更少的候选掩码。这反过来又显著加快了后续操作的速度。

为了确认我们实现的准确性,我们首先在不改变设置且使用 float32 精度的情况下进行比较。我们发现,在使用完全相同的设置时,mIoU 保持不变,并且掩码完美匹配。这意味着这些即时模式更改并未影响这些任务的准确性。

AMG

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU / 失败计数
基准 864 1144 4350 参考
AO 693 786 4010 1 / 0

ao:批处理提示

我们能够应用的另一个无损性能优化是批处理用户输入提示计算。当在 H100 等服务器级 GPU 上优化批处理大小为 1 的延迟时,我们通常会剩下大量空闲内存。我们可以通过一次处理更多兴趣点(也称为用户提示)轻松地用这些内存换取更高的性能。请记住,SAM2 分为两部分:首先是骨干网络(图像编码器),其次是基于一组用户提示/兴趣点的掩码预测和解码。我们可能会在第二部分中预期到更大或甚至变化的输入数量,并且我们在此第二部分中应用批处理。

这导致内存大幅增加,但也大大降低了延迟。基线在循环中为每个提示生成一个掩码。对于 AMG,基线一次处理 64 个提示,只需将其更改为 1024,即生成的候选提示的数量。对于 SPS,我们一次处理一个提示,但为了完整性,它仍然包含在下面。

AMG

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU / 失败计数
基准 864 1144 4350 参考
AO + 批处理 613 706 33786 0.9999995 / 0

SPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU
基准 116 181 1337 参考
AO 110 170 1339 1

MPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU
基准 276 681 1337 参考
AO + 批处理 126 225 8021 0.9999992

一个技术性的旁注:最值得注意的是,为了实现 MPS 的批处理,并避免对代码库进行大量手动重写以支持同时处理多个提示,我们使用了一个名为 MapTensor 的 Tensor 子类。MapTensor 允许我们传递一个包含 N 个提示的批次,但让它报告批次大小为 1。然后,任何操作都会自动广播到包装的 Tensor,并传播到模型的预测部分。这之所以有效,是因为单个提示预测彼此独立。这与 torch.vmap 非常相似。

center_points_torch = to_map_tensor(center_points_torch)
center_points_label_torch = to_map_tensor(center_points_label_torch)
masks, scores, _ = mask_generator.predictor.predict(
    point_coords=center_points_torch,
    point_labels=center_points_label_torch,
    multimask_output=True,
    return_logits=False,
    return_type="torch",
)
# Unwrapping MapTensor
masks = masks.elems
scores = scores.elems

快速:全图编译

正如我们在第一篇文章中所述,我们首先移除 GPU 同步和图中断,以便在适当情况下使用带有 max-autotune 内核的全图编译模型代码。经过一些重写后,我们能够编译图像编码器和掩码预测。

我们运行两次实验以了解编译带来的开销。我们首先在 TORCHINDUCTOR_CACHE_DIR 为空的环境中运行一次,然后在摄取之前运行的工件时再次运行。特别是,自动调优可能需要很长时间,并且在全新环境中的第一次调用时发生。我们称第二次运行为“暖”。通常,第一次迭代预计会因为各种其他相关的初始化过程而变慢,但即使使用现有缓存并再次输入完全相同的形状,编译也会显著增加其速度。话虽如此,在暖环境下,第一次调用时几秒钟的开销通常仍然可以接受。

这些缺点大部分都可以缓解,并且编译会导致延迟显著改善和内存减少。

AMG

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU /
失败计数
第一次迭代
(ms)
AO + 批处理 613 706 33786 0.9999995 / 0 1125
+ 编译(冷) 423 513 29349 跳过 404866
+ 编译(暖) 439 530 29349 0.994 / 190 8544

在使用自动掩码分割时,每个掩码生成的掩码数量可能会略有不同。模型可能为每个对象生成的掩码数量存在歧义。例如,一辆汽车可能被细分为车架、车窗和车门,或被视为一个整体。当修改导致掩码数量发生变化时,我们认为比较失败,并且我们仅在精确匹配的掩码上计算 mIoU。这不适用于其他任务。我们发现生成的掩码数量对微小的数值变化非常敏感。其他任务使用相同的代码,特别是 MPS 可以帮助我们进一步验证正确性。

SPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 第一次迭代
(ms)
AO 110 170 1339 1 562
+ 编译(冷) 102 158 1343 跳过 319954
+ 编译(暖) 100 160 1302 0.9999 8947

MPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 第一次迭代
(ms)
AO + 批处理 126 225 8021 0.9999992 504
+ 编译(冷) 129 215 8021 跳过 333308
+ 编译(暖) 113 213 8021 0.998 8617

狂暴:TF32、float16 和 GPU 预处理

我们发现对于模型的几个重要子组件来说,使用 float16 是合适的精度水平。特别是,图像编码器和掩码解码器权重可以完全转换为 float16。我们还可以对剩余的 float32 矩阵操作使用 TensorFloat32 精度。应该可以进一步降低精度,我们可能会在未来的文章中解决这个问题。我们还在狂暴模式下将图像预处理(如图像归一化)移至 GPU。我们不能使用 GPU 解码 (nvJPEG) 例程,因为差异过大,模型在 mIoU 方面遭受显著退化,因此图像解码仍发生在 CPU 上。

AMG

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU /
失败计数
AO
+ 批处理
+ 编译(暖)
439 530 29349 0.994 / 190
+ 狂暴 165 240 28335 0.978 / 306

这导致 AMG 任务的 mIoU 显著下降,但未影响其他任务。经过深入调查,我们仍将其归因于数值不稳定性和操作重新排序。需要更多工作进一步调查,并且在较低精度下运行 AMG 任务可能并不有趣。然而,其他任务在延迟方面获得了显著的收益,同时 mIoU 变化最小。

SPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU
AO
+ 编译(暖)
100 160 1302 0.9999
+ 狂暴 32 63 861 0.9997

MPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU
AO
+ 批处理
+ 编译(暖)
113 213 8021 0.998
+ 狂暴 36 64 4222 0.997

AOTInductor (AOTI) 通过 torch.export 进行的预编译

在弹性扩展时,通常无法适应长时间的启动。这意味着第一次迭代不能太慢,我们必须快速交付结果。此时,torch.compile 当前的编译开销可能会成为阻碍。为了解决这个问题,我们可以使用 AOTInductor (AOTI) 通过 torch.export 进行预编译。AOTI 允许我们根据代表性输入编译模型,并将生成的代码存储在可以快速加载和运行的二进制文件中。

通过 torch.export 进行 AOTI 是一项新功能,我们目前无法导出所有可编译的内容。我们已经能够导出所有任务的图像编码器,但由于提示的不同,我们只能导出 AMG 和 SPS 任务的掩码预测。torch.export 也支持动态形状,但我们需要投入更多时间来为其准备代码。

AMG:AO + 批处理 + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU /
失败计数
第一次迭代
(ms)
+ 编译(暖) 165 240 28335 0.978 / 306 10341
+ 加载导出
(冷)
162 233 27927 0.974 / 308 906

SPS:AO + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 第一次迭代
(ms)
+ 编译(暖) 32 63 861 0.9997 7989
+ 加载导出
(冷)
35 66 1686 0.9997 763

请注意,加载导出的模型会显著增加内存。这可能只会增加峰值内存利用率,因为初始化确实需要延迟到加载导出的模型之后,以避免内存中同时存在两倍的权重。这是我们可以解决的问题,但内存消耗远未达到限制。我们在其他任务中没有看到增加,因为 AMG 和 MPS 的峰值内存主要由批处理掩码的开销所主导。减少这种情况的一种方法是尽早以 rle 格式(或其他稀疏格式)操作掩码,但目前,鉴于当前的内存消耗和对延迟的关注,没有理由这样做。

MPS:AO + 批处理 + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 第一次迭代
(ms)
+ 编译(暖) 36 64 4222 0.997 9626
+ 加载导出
(冷)
43 72 3813 0.997 747

单独使用导出似乎不需要大量的预热,并且可以在全新的 Inductor 缓存目录中运行。但是,我们不会逐出 CUDA 缓存或其他缓存。在 Modal 部分,我们正在全新的环境中运行一些这些实验。

当在一个新进程中只处理 1000 张图像时,使用导出确实值得,可以节省编译和其他冷启动开销。

额外:更多 GPU 预处理

至此,延迟已经相当低了。特别是,对于 SPS 和 MPS 任务,我们正在以大约 30 毫秒到 40 毫秒的速度处理。让我们再次回顾设置部分中的伪代码。

image_tensors = decode_img_bytes(...)
masks = gen_masks(image_tensors, ...)
rle_dicts = [rle_dict_from_masks(m) for m in masks]

进一步的性能分析表明,此时 `decode_img_bytes` 大约需要 10 毫秒。特别是,它使用 torchvision 的 ToTensor 转换将 numpy Tensor 转换为缩放后的 float32 torch.Tensor。传递给 ToTensor 的字节已经被解码并转换为 numpy ndarray。通过稍微重写 ToTensor,使用 torchvision 的 v2 API,并在缩放之前将 uint8 解码后的较小整数 Tensor 首先移动到 GPU,我们可以再获得 10 毫秒的延迟。如果不在我们的分析中包含 `decode_img_bytes`,我们就会错过这个对服务器端推理具有实际影响的机会。

image_tensor = torch.from_numpy(image_tensor)
image_tensor = image_tensor.permute((2, 0, 1))
image_tensor = image_tensor.cuda()
image_tensor = v2.ToDtype(torch.float32, scale=True)( image_tensor)

特别值得注意的是,使用固定内存进行异步数据传输不适用,因为将 Tensor 移动到固定内存所需的时间不值得这种数据移动的异步性增益。对于未来的工作,我们可能希望通过使用更高级的直接内存传输技术来进一步探索这里的改进。

AMG:AO + 批处理 + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU /
失败计数
第一次迭代
(ms)
+ 加载导出
(冷)
162 233 27927 0.974 / 308 906
+ 加载导出(暖) 157 230 27927 0.974 / 308 799
+ 加载导出(暖)
+ 预处理
136 208 27950 0.977 / 311 908

SPS:AO + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 第一次迭代
(ms)
+ 加载导出
(冷)
35 66 1686 0.9997 763
+ 加载导出(暖) 31 63 1686 0.9997 683
+ 加载导出(暖)
+ 预处理
19 25 1711 0.9997 658

MPS:AO + 批处理 + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 第一次迭代
(ms)
+ 加载导出
(冷)
43 72 3813 0.997 747
+ 加载导出(暖) 53 81 3813 0.997 807
+ 加载导出(暖)
+ 预处理
31 41 3837 0.997 671

这个小小的改动对 SPS 和 MPS 任务产生了显著影响。

在 Modal 上部署

最后,我们将优化的推理部署到无服务器基础设施提供商 Modal 上,以证明这些优化在更真实的部署环境中也能实现其优势。

特别是,通过 torch.export 进行编译和 AOTI 需要额外的工作。在简单部署中,这些工作可能会添加到每一次推理执行中,从而增加延迟,使模型加速带来的任何改进都相形见绌。这对于弹性或自动扩缩基础设施尤其具有挑战性,因为推理服务的副本需要定期自动创建和销毁。

我们在 torchao 仓库中分享了一个部署脚本 (cli_on_modal.py) 来演示弹性部署的一种模式。我们提前构建导出的模型,然后将它们上传到分布式存储。相对于急切执行,这在副本启动时增加了一些额外工作,因为它们需要通过网络读取这些数据,但这比编译或导出成本低得多。

我们使用大规模批量推理工作负载对本次部署进行了基准测试:发送 1000 张图像进行并发处理。部署在高峰期可扩展到 10 个 GPU 上的 10 个副本,在非活动时可缩减到 0 个 GPU。

首先,让我们看看执行延迟。

  p50 执行延迟
(毫秒 / 改进)
p90 执行延迟
(毫秒 / 改进)
  急切模式 float32 AOTI float16 急切模式 float32 AOTI float16
    Modal 离线   Modal 离线
AMG 741 112 (6.6x) 136 (5.4x) 1140 176 (6.5x) 208 (5.5x)
SPS 98 20 (4.9x) 19 (5.2x) 130 28 (4.6x) 25 (5.2x)
MPS 269 38 (7.1x) 31 (8.7x) 714 52 (13.7x) 41 (17.4x)

我们注意到 Modal 和离线上的执行延迟相当接近,特别是相对于基线而言,这表明离线优化部署是直接优化部署的合理替代方法。

除了执行延迟,我们的批量工作负载还有排队时间,因为副本数量少于输入数量,因此一些输入必须排队等待。

  p50 排队时间 (ms) p90 排队时间 (ms)
  急切模式 float32 AOTI float16 急切模式 float32 AOTI float16
AMG 201 41 (4.9x) 815 327 (2.6x)
SPS 31 33 (0.9x) 441 49 (9.0x)
MPS 40 37 (1.1x) 942 75 (12.6x)

尽管基础设施提供的排队系统未作改变,但当我们使用优化的模型时,排队延迟也随之降低——在 p90 情况下降低了 2 到 12 倍。这是因为当我们更快地完成之前的输入(由于执行延迟降低)时,我们可以更快地获取下一个输入(从而减少它们的排队时间)。

如果您有兴趣进一步优化 SAM2 推理或部署,请随时通过 torchao 仓库联系我们!

结论

我们用纯 PyTorch 重写了 Meta 的原始 SAM2,在几乎不损失准确性的情况下,极大地关注了延迟。我们优化后的推理部署在无服务器基础设施提供商 Modal 上,以证明这些优化可以在更真实的部署环境中实现其优势。

通过利用 AOTInductor (AOTI) 结合 torch.export 进行提前编译、降低精度、批处理提示和 GPU 预处理,我们观察到与常规急切模式 PyTorch 相比,p90 执行延迟和排队时间提高了 13 倍。

在弹性或自动扩缩基础设施中,推理服务的副本需要定期自动创建和销毁,而 torch.compile 的简单部署会增加推理执行的工作量,使更快的模型带来的任何改进都相形见绌。通过利用 AOTInductor (AOTI) 结合 torch.export 进行提前编译,我们能够提前上传导出的模型并通过网络读取这些数据,这使我们能够在不显著增加工作量的情况下获得编译的好处。

有关如何重现此博文中数据的更多详细信息,请查看 torchao 的 experiments 文件夹。如果您遇到任何技术问题,请随时联系我们或提出问题