博客

使用 PyTorch 加速生成式 AI:Segment Anything 2 – 以低延迟和快速冷启动实现极速推理

作者: 2025年2月26日2025年5月7日暂无评论

本文是我们关于如何使用纯原生 PyTorch 加速生成式 AI 模型的多系列博客中的第二篇,重点关注延迟和弹性扩展。我们使用 torch.compile 和 torch.export 创建了高度优化的 SAM2 低延迟版本,能够在新的实例上快速扩展。

通过利用 AOTInductor (AOTI) 进行 torch.export 预编译、降低精度、批量提示词处理以及 GPU 预处理,我们观察到 p90 执行延迟和队列时间相比常规即时(eager)模式 PyTorch 提升了高达 13 倍

我们通过 Modal 的自动伸缩云基础设施上的真实部署,计算了最终结果并展示了这些改进。

  p50 执行延迟
(ms / 提升倍数)
p90 执行延迟
(ms / 提升倍数)
  Eager float32 AOTI float16 Eager float32 AOTI float16
AMG 741 112 (6.6x) 1140 176 (6.5x)
SPS 98 20 (4.9x) 130 28 (4.6x)
MPS 269 38 (7.1x) 714 52 (13.7x)
  p50 队列时间 (ms / 提升倍数) p90 队列时间 (ms / 提升倍数)
  Eager float32 AOTI float16 Eager float32 AOTI float16
AMG 201 41 (4.9x) 815 327 (2.6x)
SPS 31 33 (0.9x) 441 49 (9.0x)
MPS 40 37 (1.1x) 942 75 (12.6x)

任务说明

第一篇文章侧重于每张图像处理少量不同的提示词(兴趣点)。这些点代表了真值遮罩的中心点。对于本文,我们将重点关注更广泛的任务:单提示词分割 (SPS)、多提示词分割 (MPS)、以及无需给定提示词集即可为输入图像生成全套遮罩的自动遮罩生成 (AMG)。第一篇文章仅关注了 MPS。

comparison of 3 images

图中的小星星代表用户提示词。对于 AMG,没有提示词,遮罩是从初始候选提示词(猜测)的密集网格中通过启发式方法过滤得到的。对于 SPS 和 MPS,用户提示词源自 AMG 遮罩的中心点。对于 SPS,我们选择面积最大的遮罩。

请注意,SAM2 使用了与 SAM1 不同的骨干网络。在本博客中,我们仅考虑最大且最准确的 sam2.1_hiera_large 骨干网络。

我们在 torchao 的示例文件夹中整理了重现结果所需的脚本,并逐步将 torchao 中 SAM2 模型的更改中更稳定的部分合并到主要的 SAM2 代码库中。因此,如果您有兴趣查看前沿变体或希望贡献实验性功能,请随时联系 torchao 仓库和团队。对于更稳定和最新的模型版本,请直接前往 SAM2 仓库。

概述

我们将此处展示的更改分为两类。“极速”(Fast)更改仅限于那些不会影响模型准确性的技术。“狂暴”(Furious)更改则通过利用低精度数据类型等近似值,以牺牲部分数值精度来换取额外的速度。

近似处理可能会略微降低精度指标,但性能却能显著提升,同时仍能通过基于平均交并比 (mIoU) 的端到端测试。

为了衡量性能提升,我们处理了从 SAM2 验证数据集中随机选择的 1000 张图像。我们查看了每张图像的 p50 和 p90 延迟。为了衡量准确性,我们考察了 mIoU。最值得注意的是,对于 AMG 任务,我们还定义了失败计数指标。如果遮罩数量不同,我们认为比较失败。事实证明,这是一个相当不稳定的量,我们可以看出其他任务对微小的数值变化不像 AMG 那样敏感。

实验设置

我们在常规的 H100 开发服务器上运行离线实验,这是一台性能相当强劲的机器。

然而,我们试图在真实的约束条件下观察这些任务。具体来说,我们希望模拟服务器端的推理环境。这意味着我们不会使用 DataLoader 来掩盖图像预处理或解码程序的延迟。

在延迟计算中,我们包含了解码、分割以及将遮罩转换为运行长度编码 (RLE) 遮罩字典的过程。换句话说,我们排除了将图像加载到内存主机字节数组以及将结果字典作为 json 文件存储在磁盘上的过程。这是为了模拟更真实的场景。

更具体地说,考虑下面我们测量程序中包含的例程代码。对于任何任务,gen_masks 都会生成一个批量布尔张量位掩码,代表相应的对象遮罩。然后我们将此位掩码压缩为运行长度编码 (rle) 格式,以便更有效地从远程服务器传回结果。

image_tensors = decode_img_bytes(...)
masks = gen_masks(image_tensors, ...)
rle_dicts = [rle_dict_from_masks(m) for m in masks]

优化

ao:即时代码优化

这项工作最有效的工具是 PyTorch autograd 分析器结合 record_function。为了构建此软件,我们反复使用分析器来观察程序并确认任何更改的有效性。还要记住,分析器本身也有开销。收集的数据越多(例如堆栈跟踪),引入的开销就越大,这可能会扭曲收集的跟踪信息。但它是发现同步点、内核之间的间隙以及耗时较长的 GPU 内核的绝佳工具。

GPU 跟踪有助于理解那些不一定能轻易通过编译解决的瓶颈。我们发现 AutomaticMaskGeneration 尤其受到用于存储遮罩的数据结构以及用于将遮罩转换为运行长度编码压缩格式的例程的主导。我们还发现 AMG 性能的很大一部分受到作为单个批次创建的大量遮罩的影响。有时,通过重排序操作,可以在后处理阶段更早地将候选遮罩过滤为较少的候选者。这反过来又显著加快了后续操作的速度。

为了确认我们实现的准确性,我们首先在不更改任何设置且使用 float32 精度的情况下进行比较。我们发现 mIoU 没有变化,并且当使用完全相同的设置时,遮罩完美匹配。这意味着这些即时模式的更改没有影响这些任务的准确性。

AMG

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU / 失败计数
基准 864 1144 4350 参考
AO 693 786 4010 1 / 0

ao:批量处理提示词

我们能够应用的另一个无损性能优化是批量处理用户输入提示词计算。当在 H100 等服务器级 GPU 上以批大小 1 优化延迟时,我们通常会剩下大量空闲内存。我们可以很容易地用这些内存换取性能,即一次处理更多的兴趣点(也称为用户提示词)。请记住,SAM2 分为两部分:第一部分是骨干网络(图像编码器),第二部分是基于一组用户提示词/兴趣点进行遮罩的预测和解码。在第二部分中,我们可能会遇到更多甚至数量可变的输入,正是这第二部分,我们应用了批量处理。

这会导致内存大幅增加,但延迟也会好得多。基准测试在循环中为每个提示词生成一个遮罩。对于 AMG,基准测试一次处理 64 个提示词,所需要的只是将其更改为 1024(生成的候选提示词数量)。对于 SPS,我们一次处理一个提示词,但为了完整起见,下面也包含了它。

AMG

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU / 失败计数
基准 864 1144 4350 参考
AO + 批量处理 613 706 33786 0.9999995 / 0

SPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU
基准 116 181 1337 参考
AO 110 170 1339 1

MPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU
基准 276 681 1337 参考
AO + 批量处理 126 225 8021 0.9999992

技术附注:最值得注意的是,为了支持 MPS 的批量处理,并避免对代码库进行繁重的手动重写以支持同时处理多个提示词,我们使用了一个名为 MapTensor 的张量子类。MapTensor 允许我们传递一批 N 个提示词,但声明批大小为 1。然后,任何操作都会自动广播到封装的张量,并在模型的预测部分传播。这之所以有效,是因为单个提示词预测彼此独立。这与 torch.vmap 非常相似。

center_points_torch = to_map_tensor(center_points_torch)
center_points_label_torch = to_map_tensor(center_points_label_torch)
masks, scores, _ = mask_generator.predictor.predict(
    point_coords=center_points_torch,
    point_labels=center_points_label_torch,
    multimask_output=True,
    return_logits=False,
    return_type="torch",
)
# Unwrapping MapTensor
masks = masks.elems
scores = scores.elems

fast:全图编译

正如我们的第一篇文章一样,我们首先移除 GPU 同步和图中断,以便在适当的情况下使用带有 max-autotune 内核的全图编译模型代码。经过一些重写,我们能够编译图像编码器和遮罩预测部分。

我们运行两次实验以了解编译带来的开销。我们首先在 TORCHINDUCTOR_CACHE_DIR 为空的环境中运行一次,然后在摄取上一次运行的工件时再次运行。特别是,自动调优可能需要很长时间,并且发生在原始环境的第一次调用中。我们将第二次运行称为“预热”。由于各种其他相关的初始化过程,第一次迭代通常预期会很慢,但即使使用现有的缓存并且输入完全相同的形状,编译也会显著增加开销。话虽如此,在第一次调用时,预热环境中的几秒钟开销通常还是可以承受的。

这些缺点中的大多数都可以缓解,并且编译会导致延迟的显著改善和内存的减少。

AMG

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU /
失败计数
首次迭代
(ms)
AO + 批量处理 613 706 33786 0.9999995 / 0 1125
+ 编译 (冷启动) 423 513 29349 跳过 404866
+ 编译 (预热) 439 530 29349 0.994 / 190 8544

使用自动遮罩分割时,每个遮罩产生的遮罩数量可能会略有不同。模型可能产生的每个对象的遮罩数量存在歧义。例如,一辆汽车可以细分为框架、窗户和门,也可以视为一个整体。当修改导致遮罩数量发生变化时,我们认为比较失败,并且我们仅计算精确匹配遮罩的 mIoU。这不适用于其他任务。我们发现生成的遮罩数量对小的数值变化非常敏感。其他任务使用相同的代码,MPS 特别有助于我们进一步验证正确性。

SPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 首次迭代
(ms)
AO 110 170 1339 1 562
+ 编译 (冷启动) 102 158 1343 跳过 319954
+ 编译 (预热) 100 160 1302 0.9999 8947

MPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 首次迭代
(ms)
AO + 批量处理 126 225 8021 0.9999992 504
+ 编译 (冷启动) 129 215 8021 跳过 333308
+ 编译 (预热) 113 213 8021 0.998 8617

furious:TF32, float16 和 GPU 预处理

我们发现,对于模型的几个重要子组件,使用 float16 是正确的精度级别。特别是,图像编码器和遮罩解码器权重可以完全转换为 float16。我们也可以对剩余的 float32 矩阵运算使用 TensorFloat32 精度。进一步降低精度应该是可能的,我们可能会在以后的文章中讨论这个问题。我们还将图像归一化等图像预处理工作在“狂暴”模式下移至 GPU。由于差异太大且模型会出现明显的 mIoU 衰减,我们无法使用 GPU 解码 (nvJPEG) 例程,因此图像解码仍然在 CPU 上进行。

AMG

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU /
失败计数
AO
+ 批量处理
+ 编译 (预热)
439 530 29349 0.994 / 190
+ 狂暴 (furious) 165 240 28335 0.978 / 306

这导致 AMG 任务的 mIoU 显著下降,但不会影响其他任务。经过深入调查,我们仍将其归因于数值不稳定和操作重排序。需要更多工作来进一步研究这一点,并且在较低精度下运行 AMG 任务可能并不划算。然而,其他任务在延迟方面受益匪浅,且 mIoU 变化极小。

SPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU
AO
+ 编译 (预热)
100 160 1302 0.9999
+ 狂暴 (furious) 32 63 861 0.9997

MPS

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU
AO
+ 批量处理
+ 编译 (预热)
113 213 8021 0.998
+ 狂暴 (furious) 36 64 4222 0.997

AOTInductor (AOTI) 通过 torch.export 进行提前编译

在弹性扩展时,通常无法容忍较长的启动时间。这意味着第一次迭代不能慢,我们必须快速交付结果。这就是 torch.compile 当前的编译开销可能会阻碍进度的地方。为了解决这个问题,我们可以使用 AOTInductor (AOTI) 通过 torch.export 进行提前编译。AOTI 允许我们在代表性输入上编译模型,并将生成的代码存储在二进制文件中,该文件加载和运行速度都很快。

AOTI 通过 torch.export 是一个新功能,目前我们无法导出所有可编译的内容。我们已经能够导出所有任务的图像编码器,但由于提示词的不同,仅能够导出 AMG 和 SPS 任务的遮罩预测。torch.export 也支持动态形状,但我们需要投入更多时间来为此准备代码。

AMG: AO + 批量处理 + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU /
失败计数
首次迭代
(ms)
+ 编译 (预热) 165 240 28335 0.978 / 306 10341
+ 加载导出模型
(冷启动)
162 233 27927 0.974 / 308 906

SPS: AO + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 首次迭代
(ms)
+ 编译 (预热) 32 63 861 0.9997 7989
+ 加载导出模型
(冷启动)
35 66 1686 0.9997 763

请注意,加载导出的模型会显著增加内存占用。这可能只会增加峰值内存利用率,因为为了避免同时在内存中拥有两份权重,真正的初始化需要在加载导出模型之前延迟进行。这是我们可以解决的问题,但内存消耗远未达到上限。我们在其他任务中没有看到增加,因为 AMG 和 MPS 的峰值内存主要由处理遮罩批次决定。减少这种情况的一种方法可能是更早地以 rle 格式(或其他稀疏格式)对遮罩进行操作,但目前考虑到当前的内存消耗和对延迟的关注,没有理由这样做。

MPS: AO + 批量处理 + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 首次迭代
(ms)
+ 编译 (预热) 36 64 4222 0.997 9626
+ 加载导出模型
(冷启动)
43 72 3813 0.997 747

仅使用导出模型似乎无法从广泛的预热中获益,可以在全新的 inductor 缓存目录中运行。但同样,我们不会清除 CUDA 缓存或其他缓存。在关于 Modal 的部分中,我们就是在原始环境中运行其中一些实验的。

当在一个新进程中仅处理 1000 张图像时,使用导出模型确实非常值得,可以节省编译和其他冷启动开销。

加餐:更多的 GPU 预处理

此时,延迟已经相当低了。特别是对于 SPS 和 MPS 任务,我们的处理速度在 30ms 到 40ms 左右。让我们再次调出设置部分中的伪代码。

image_tensors = decode_img_bytes(...)
masks = gen_masks(image_tensors, ...)
rle_dicts = [rle_dict_from_masks(m) for m in masks]

进一步的分析表明,此时 decode_img_bytes 大约需要 10ms。特别是,它使用了 torchvision 的 ToTensor 转换来将 numpy 张量转换为缩放后的 float32 torch.Tensor。传递给 ToTensor 的字节数据已经被解码并转换为 numpy ndarray。通过稍微重写 ToTensor,使用 torchvision 的 v2 API,并先将 uint8 解码后的较小整数张量移动到 GPU 再进行缩放,我们可以再获得 10ms 的延迟提升。如果不将 decode_img_bytes 包含在我们的分析中,我们就会错过这个对服务器端推理产生现实影响的机会。

image_tensor = torch.from_numpy(image_tensor)
image_tensor = image_tensor.permute((2, 0, 1))
image_tensor = image_tensor.cuda()
image_tensor = v2.ToDtype(torch.float32, scale=True)( image_tensor)

特别注意,使用固定内存 (pinned memory) 执行异步数据传输并不适用,因为将张量移动到固定内存所花费的时间对于这种数据移动来说并不值得异步性带来的增益。对于未来的工作,我们可能希望通过使用更高级的直接内存传输技术来探索此处的进一步改进。

AMG: AO + 批量处理 + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU /
失败计数
首次迭代
(ms)
+ 加载导出模型
(冷启动)
162 233 27927 0.974 / 308 906
+ 加载导出模型 (预热) 157 230 27927 0.974 / 308 799
+ 加载导出模型 (预热)
+ 预处理
136 208 27950 0.977 / 311 908

SPS: AO + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 首次迭代
(ms)
+ 加载导出模型
(冷启动)
35 66 1686 0.9997 763
+ 加载导出模型 (预热) 31 63 1686 0.9997 683
+ 加载导出模型 (预热)
+ 预处理
19 25 1711 0.9997 658

MPS: AO + 批量处理 + 狂暴

  p50 延迟 (ms) p90 延迟 (ms) 内存 (MiB) mIoU 首次迭代
(ms)
+ 加载导出模型
(冷启动)
43 72 3813 0.997 747
+ 加载导出模型 (预热) 53 81 3813 0.997 807
+ 加载导出模型 (预热)
+ 预处理
31 41 3837 0.997 671

这一小小的更改对 SPS 和 MPS 任务产生了显著影响。

在 Modal 上部署

最后,我们将优化后的推理部署到了无服务器基础设施提供商 Modal 上,以证明这些优化的好处可以在更现实的部署环境中实现。

特别是,通过 torch.export 进行编译和 AOTI 需要额外的工作。在天真的部署中,这些工作可能会被添加到每一次推理执行中,从而增加延迟,使其超过更快模型带来的任何改进。这在弹性或自动伸缩的基础设施中尤其具有挑战性,因为我们的推理服务副本需要定期自动创建和销毁。

我们在 torchao 仓库中共享了一个部署脚本 (cli_on_modal.py),以展示弹性部署的一种模式。我们提前构建导出的模型,然后将它们上传到 分布式存储。相对于即时执行,当副本启动时,这会增加一点额外的工作,因为它们需要通过网络读取这些数据,但这比编译或导出要便宜得多。

我们通过大规模批量推理工作负载对此部署进行了基准测试:发送 1000 张图像进行并发处理。该部署在峰值时扩展到 10 个 GPU 上的 10 个副本,并在不活动时缩减到 0 个 GPU。

首先,让我们看看执行延迟。

  p50 执行延迟
(ms / 提升倍数)
p90 执行延迟
(ms / 提升倍数)
  Eager float32 AOTI float16 Eager float32 AOTI float16
    Modal 离线   Modal 离线
AMG 741 112 (6.6x) 136 (5.4x) 1140 176 (6.5x) 208 (5.5x)
SPS 98 20 (4.9x) 19 (5.2x) 130 28 (4.6x) 25 (5.2x)
MPS 269 38 (7.1x) 31 (8.7x) 714 52 (13.7x) 41 (17.4x)

我们注意到 Modal 和离线环境的执行延迟相当接近,特别是相对于基准线而言,这表明离线优化部署是直接优化部署的合理代理。

除了执行延迟外,我们的批量工作负载还有队列时间,因为副本比输入少,所以一些输入必须排队等待。

  p50 队列时间 (ms) p90 队列时间 (ms)
  Eager float32 AOTI float16 Eager float32 AOTI float16
AMG 201 41 (4.9x) 815 327 (2.6x)
SPS 31 33 (0.9x) 441 49 (9.0x)
MPS 40 37 (1.1x) 942 75 (12.6x)

尽管基础设施提供的排队系统没有变化,但当我们使用优化后的模型时,队列延迟也会降低——在 p90 情况下降低了 2 到 12 倍。这是因为当我们更快地完成先前的输入(由于执行延迟减少)时,我们可以更快地拉取下一个输入(从而减少它们的排队时间)。

如果您有兴趣进一步优化 SAM2 推理或部署,请随时通过 torchao 仓库联系我们!

结论

我们用纯 PyTorch 重写了 Meta 原版的 SAM2,几乎没有损失准确性,并高度关注延迟。我们将优化后的推理部署到了无服务器基础设施提供商 Modal 上,以证明这些优化的好处可以在更现实的部署环境中实现。

通过利用 AOTInductor (AOTI) 通过 torch.export 进行提前编译、降低精度、批量提示词和 GPU 预处理,我们观察到 p90 执行延迟和队列时间相比常规即时模式 PyTorch 提升了高达 13 倍。

在弹性或自动伸缩的基础设施中,我们的推理服务副本需要定期自动创建和销毁。此时,天真地使用 torch.compile 可能会在推理执行中增加额外工作,从而掩盖更快模型带来的任何改进。通过利用 AOTInductor (AOTI) 通过 torch.export 进行提前编译,我们能够提前上传导出的模型并通过网络读取这些数据,这使我们能够在不显著增加工作量的情况下获得编译带来的好处。

有关如何重现本博客文章数据的更多详细信息,请查看 torchao 的实验文件夹。如果您遇到任何技术问题,请随时联系我们或提交 issue