• 文档 >
  • 使用 dynamo 后端编译 SAM2
快捷键

使用 dynamo 后端编译 SAM2

此示例说明了使用 Torch-TensorRT 优化的最先进模型 Segment Anything Model 2 (SAM2)

Segment Anything Model 2 是一个基础模型,旨在解决图像和视频中可提示的视觉分割问题。在编译之前安装以下依赖项

pip install -r requirements.txt

需要进行某些自定义修改以确保模型成功导出。要应用这些更改,请使用 以下分支 安装 SAM2(安装说明

在自定义 SAM2 分支中,已应用以下修改以消除图中断并提高延迟性能,从而确保更高效的 Torch-TRT 转换

  • 一致的数据类型: 保留输入张量的数据类型,消除强制 FP32 转换。

  • 掩码操作: 使用基于掩码的索引而不是直接选择数据,从而提高 Torch-TRT 兼容性。

  • 安全初始化: 有条件地初始化张量,而不是连接到空张量。

  • 标准函数: 避免特殊上下文和自定义 LayerNorm,依靠内置 PyTorch 函数以获得更好的稳定性。

导入以下库

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_tensorrt
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam_components import SAM2FullModel

matplotlib.use("Agg")

定义 SAM2 模型

使用 SAM2ImagePredictor 类加载 facebook/sam2-hiera-large 预训练模型。SAM2ImagePredictor 提供实用程序来预处理图像,存储图像特征(通过 set_image 函数)并预测掩码(通过 predict 函数)

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

为了确保我们成功导出整个模型(图像编码器和掩码预测器)组件,我们创建了一个独立的模块 SAM2FullModel,它使用 SAM2ImagePredictor 类中的这些实用程序。SAM2FullModel 执行特征提取和掩码预测,只需一步即可完成,而不是 SAM2ImagePredictor 的两步过程(set_image 和 predict 函数)

class SAM2FullModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.image_encoder = model.forward_image
        self._prepare_backbone_features = model._prepare_backbone_features
        self.directly_add_no_mem_embed = model.directly_add_no_mem_embed
        self.no_mem_embed = model.no_mem_embed
        self._features = None

        self.prompt_encoder = model.sam_prompt_encoder
        self.mask_decoder = model.sam_mask_decoder

        self._bb_feat_sizes = [(256, 256), (128, 128), (64, 64)]

    def forward(self, image, point_coords, point_labels):
        backbone_out = self.image_encoder(image)
        _, vision_feats, _, _ = self._prepare_backbone_features(backbone_out)

        if self.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}

        high_res_features = [
            feat_level[-1].unsqueeze(0) for feat_level in features["high_res_feats"]
        ]

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=(point_coords, point_labels), boxes=None, masks=None
        )

        low_res_masks, iou_predictions, _, _ = self.mask_decoder(
            image_embeddings=features["image_embed"][-1].unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=point_coords.shape[0] > 1,
            high_res_features=high_res_features,
        )

        out = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
        return out

使用预训练权重初始化 SAM2 模型

使用预训练权重初始化 SAM2FullModel。由于我们已经初始化了 SAM2ImagePredictor,我们可以直接使用其中的模型 (predictor.model)。我们将模型转换为 FP16 精度以获得更快的性能。

encoder = predictor.model.eval().cuda()
sam_model = SAM2FullModel(encoder.half()).eval().cuda()

加载存储库中提供的示例图像。

input_image = Image.open("./truck.jpg").convert("RGB")

加载输入图像

这是我们将要使用的输入图像

../../../_images/truck.jpg
input_image = Image.open("./truck.jpg").convert("RGB")

除了输入图像之外,我们还提供提示作为输入,用于预测掩码。提示可以是框、点以及来自先前预测迭代的掩码。在此演示中,我们使用点作为提示,类似于 SAM2 存储库中的原始笔记本

预处理组件

以下函数实现预处理组件,这些组件对输入图像应用转换并转换给定的点坐标。我们使用通过 SAM2ImagePredictor 类提供的 SAM2Transforms。要了解有关转换的更多信息,请参阅 https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py

def preprocess_inputs(image, predictor):
    w, h = image.size
    orig_hw = [(h, w)]
    input_image = predictor._transforms(np.array(image))[None, ...].to("cuda:0")

    point_coords = torch.tensor([[500, 375]], dtype=torch.float).to("cuda:0")
    point_labels = torch.tensor([1], dtype=torch.int).to("cuda:0")

    point_coords = torch.as_tensor(
        point_coords, dtype=torch.float, device=predictor.device
    )
    unnorm_coords = predictor._transforms.transform_coords(
        point_coords, normalize=True, orig_hw=orig_hw[0]
    )
    labels = torch.as_tensor(point_labels, dtype=torch.int, device=predictor.device)
    if len(unnorm_coords.shape) == 2:
        unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]

    input_image = input_image.half()
    unnorm_coords = unnorm_coords.half()

    return (input_image, unnorm_coords, labels)

后处理组件

以下函数实现后处理组件,包括绘制和可视化掩码和点。我们使用 SAM2Transforms 来后处理这些掩码,并通过置信度分数对它们进行排序。

def postprocess_masks(out, predictor, image):
    """Postprocess low-resolution masks and convert them for visualization."""
    orig_hw = (image.size[1], image.size[0])  # (height, width)
    masks = predictor._transforms.postprocess_masks(out["low_res_masks"], orig_hw)
    masks = (masks > 0.0).squeeze(0).cpu().numpy()
    scores = out["iou_predictions"].squeeze(0).cpu().numpy()
    sorted_indices = np.argsort(scores)[::-1]
    return masks[sorted_indices], scores[sorted_indices]


def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [
            cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
        ]
        mask_image = cv2.drawContours(
            mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
        )
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def visualize_masks(
    image, masks, scores, point_coords, point_labels, title_prefix="", save=True
):
    """Visualize and save masks overlaid on the original image."""
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(point_coords, point_labels, plt.gca())
        plt.title(f"{title_prefix} Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis("off")
        plt.savefig(f"{title_prefix}_output_mask_{i + 1}.png")
        plt.close()

预处理输入

预处理输入。在以下代码片段中,torchtrt_inputs 包含(input_image、unnormalized_coordinates 和 labels)。unnormalized_coordinates 是点的表示,label(= 1 在此演示中)表示前景点。

torchtrt_inputs = preprocess_inputs(input_image, predictor)

Torch-TensorRT 编译

在非严格模式下导出模型,并以 FP16 精度执行 Torch-TensorRT 编译。我们启用 FP32 矩阵乘法累加,使用 use_fp32_acc=True 以保持与原始 Pytorch 模型的精度。

exp_program = torch.export.export(sam_model, torchtrt_inputs, strict=False)
trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=torchtrt_inputs,
    min_block_size=1,
    enabled_precisions={torch.float16},
    use_fp32_acc=True,
)
trt_out = trt_model(*torchtrt_inputs)

输出可视化

后处理 Torch-TensorRT 的输出,并使用上面提供的后处理组件可视化掩码。输出应存储在当前目录中。

trt_masks, trt_scores = postprocess_masks(trt_out, predictor, input_image)
visualize_masks(
    input_image,
    trt_masks,
    trt_scores,
    torch.tensor([[500, 375]]),
    torch.tensor([1]),
    title_prefix="Torch-TRT",
)
预测的掩码如下所示
../../../_images/sam_mask1.png ../../../_images/sam_mask2.png ../../../_images/sam_mask3.png

参考文献

脚本总运行时间: ( 0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源