注意
转到末尾 下载完整的示例代码
使用 dynamo 后端编译 SAM2¶
此示例说明了使用 Torch-TensorRT 优化的最先进模型 Segment Anything Model 2 (SAM2)。
Segment Anything Model 2 是一个基础模型,旨在解决图像和视频中可提示的视觉分割问题。在编译之前安装以下依赖项
pip install -r requirements.txt
需要进行某些自定义修改以确保模型成功导出。要应用这些更改,请使用 以下分支 安装 SAM2(安装说明)
在自定义 SAM2 分支中,已应用以下修改以消除图中断并提高延迟性能,从而确保更高效的 Torch-TRT 转换
一致的数据类型: 保留输入张量的数据类型,消除强制 FP32 转换。
掩码操作: 使用基于掩码的索引而不是直接选择数据,从而提高 Torch-TRT 兼容性。
安全初始化: 有条件地初始化张量,而不是连接到空张量。
标准函数: 避免特殊上下文和自定义 LayerNorm,依靠内置 PyTorch 函数以获得更好的稳定性。
导入以下库¶
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_tensorrt
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam_components import SAM2FullModel
matplotlib.use("Agg")
定义 SAM2 模型¶
使用 SAM2ImagePredictor
类加载 facebook/sam2-hiera-large
预训练模型。SAM2ImagePredictor
提供实用程序来预处理图像,存储图像特征(通过 set_image
函数)并预测掩码(通过 predict
函数)
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
为了确保我们成功导出整个模型(图像编码器和掩码预测器)组件,我们创建了一个独立的模块 SAM2FullModel
,它使用 SAM2ImagePredictor
类中的这些实用程序。SAM2FullModel
执行特征提取和掩码预测,只需一步即可完成,而不是 SAM2ImagePredictor
的两步过程(set_image 和 predict 函数)
class SAM2FullModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.image_encoder = model.forward_image
self._prepare_backbone_features = model._prepare_backbone_features
self.directly_add_no_mem_embed = model.directly_add_no_mem_embed
self.no_mem_embed = model.no_mem_embed
self._features = None
self.prompt_encoder = model.sam_prompt_encoder
self.mask_decoder = model.sam_mask_decoder
self._bb_feat_sizes = [(256, 256), (128, 128), (64, 64)]
def forward(self, image, point_coords, point_labels):
backbone_out = self.image_encoder(image)
_, vision_feats, _, _ = self._prepare_backbone_features(backbone_out)
if self.directly_add_no_mem_embed:
vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
feats = [
feat.permute(1, 2, 0).view(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
high_res_features = [
feat_level[-1].unsqueeze(0) for feat_level in features["high_res_feats"]
]
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=(point_coords, point_labels), boxes=None, masks=None
)
low_res_masks, iou_predictions, _, _ = self.mask_decoder(
image_embeddings=features["image_embed"][-1].unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=True,
repeat_image=point_coords.shape[0] > 1,
high_res_features=high_res_features,
)
out = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
return out
使用预训练权重初始化 SAM2 模型¶
使用预训练权重初始化 SAM2FullModel
。由于我们已经初始化了 SAM2ImagePredictor
,我们可以直接使用其中的模型 (predictor.model
)。我们将模型转换为 FP16 精度以获得更快的性能。
encoder = predictor.model.eval().cuda()
sam_model = SAM2FullModel(encoder.half()).eval().cuda()
加载存储库中提供的示例图像。
input_image = Image.open("./truck.jpg").convert("RGB")
加载输入图像¶
这是我们将要使用的输入图像

input_image = Image.open("./truck.jpg").convert("RGB")
除了输入图像之外,我们还提供提示作为输入,用于预测掩码。提示可以是框、点以及来自先前预测迭代的掩码。在此演示中,我们使用点作为提示,类似于 SAM2 存储库中的原始笔记本
预处理组件¶
以下函数实现预处理组件,这些组件对输入图像应用转换并转换给定的点坐标。我们使用通过 SAM2ImagePredictor 类提供的 SAM2Transforms。要了解有关转换的更多信息,请参阅 https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py
def preprocess_inputs(image, predictor):
w, h = image.size
orig_hw = [(h, w)]
input_image = predictor._transforms(np.array(image))[None, ...].to("cuda:0")
point_coords = torch.tensor([[500, 375]], dtype=torch.float).to("cuda:0")
point_labels = torch.tensor([1], dtype=torch.int).to("cuda:0")
point_coords = torch.as_tensor(
point_coords, dtype=torch.float, device=predictor.device
)
unnorm_coords = predictor._transforms.transform_coords(
point_coords, normalize=True, orig_hw=orig_hw[0]
)
labels = torch.as_tensor(point_labels, dtype=torch.int, device=predictor.device)
if len(unnorm_coords.shape) == 2:
unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
input_image = input_image.half()
unnorm_coords = unnorm_coords.half()
return (input_image, unnorm_coords, labels)
后处理组件¶
以下函数实现后处理组件,包括绘制和可视化掩码和点。我们使用 SAM2Transforms 来后处理这些掩码,并通过置信度分数对它们进行排序。
def postprocess_masks(out, predictor, image):
"""Postprocess low-resolution masks and convert them for visualization."""
orig_hw = (image.size[1], image.size[0]) # (height, width)
masks = predictor._transforms.postprocess_masks(out["low_res_masks"], orig_hw)
masks = (masks > 0.0).squeeze(0).cpu().numpy()
scores = out["iou_predictions"].squeeze(0).cpu().numpy()
sorted_indices = np.argsort(scores)[::-1]
return masks[sorted_indices], scores[sorted_indices]
def show_mask(mask, ax, random_color=False, borders=True):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
if borders:
import cv2
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# Try to smooth contours
contours = [
cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
]
mask_image = cv2.drawContours(
mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
def visualize_masks(
image, masks, scores, point_coords, point_labels, title_prefix="", save=True
):
"""Visualize and save masks overlaid on the original image."""
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_points(point_coords, point_labels, plt.gca())
plt.title(f"{title_prefix} Mask {i + 1}, Score: {score:.3f}", fontsize=18)
plt.axis("off")
plt.savefig(f"{title_prefix}_output_mask_{i + 1}.png")
plt.close()
预处理输入¶
预处理输入。在以下代码片段中,torchtrt_inputs
包含(input_image、unnormalized_coordinates 和 labels)。unnormalized_coordinates 是点的表示,label(= 1 在此演示中)表示前景点。
torchtrt_inputs = preprocess_inputs(input_image, predictor)
Torch-TensorRT 编译¶
在非严格模式下导出模型,并以 FP16 精度执行 Torch-TensorRT 编译。我们启用 FP32 矩阵乘法累加,使用 use_fp32_acc=True
以保持与原始 Pytorch 模型的精度。
exp_program = torch.export.export(sam_model, torchtrt_inputs, strict=False)
trt_model = torch_tensorrt.dynamo.compile(
exp_program,
inputs=torchtrt_inputs,
min_block_size=1,
enabled_precisions={torch.float16},
use_fp32_acc=True,
)
trt_out = trt_model(*torchtrt_inputs)
输出可视化¶
后处理 Torch-TensorRT 的输出,并使用上面提供的后处理组件可视化掩码。输出应存储在当前目录中。
trt_masks, trt_scores = postprocess_masks(trt_out, predictor, input_image)
visualize_masks(
input_image,
trt_masks,
trt_scores,
torch.tensor([[500, 375]]),
torch.tensor([1]),
title_prefix="Torch-TRT",
)
参考文献¶
脚本总运行时间: ( 0 分钟 0.000 秒)