• 文件 >
  • 使用 dynamo 後端編譯 SAM2
快捷方式

使用 dynamo 後端編譯 SAM2

此示例展示了使用 Torch-TensorRT 最佳化的最先進模型 Segment Anything Model 2 (SAM2)

Segment Anything Model 2 是一個基礎模型,旨在解決影像和影片中的可提示視覺分割問題。在編譯之前安裝以下依賴項

pip install -r requirements.txt

需要進行某些自定義修改以確保模型成功匯出。要應用這些更改,請使用 以下分支 安裝 SAM2 (安裝說明)

在自定義的 SAM2 分支中,已應用以下修改來移除圖中斷並增強延遲效能,從而確保更高效的 Torch-TRT 轉換

  • 資料型別一致性: 保留輸入張量 dtype,移除強制 FP32 轉換。

  • 掩碼操作: 使用基於掩碼的索引而不是直接選擇資料,提高了 Torch-TRT 的相容性。

  • 安全初始化: 有條件地初始化張量,而不是連線到空張量。

  • 標準函式: 避免使用特殊上下文和自定義 LayerNorm,依賴於 PyTorch 內建函式以獲得更好的穩定性。

匯入以下庫

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_tensorrt
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam_components import SAM2FullModel

matplotlib.use("Agg")

定義 SAM2 模型

使用 SAM2ImagePredictor 類載入 facebook/sam2-hiera-large 預訓練模型。SAM2ImagePredictor 提供用於預處理影像、儲存影像特徵(透過 set_image 函式)和預測掩碼(透過 predict 函式)的實用工具

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

為確保成功匯出整個模型(影像編碼器和掩碼預測器)元件,我們建立了一個獨立的模組 SAM2FullModel,該模組使用來自 SAM2ImagePredictor 類的這些實用工具。SAM2FullModel 在一個步驟中執行特徵提取和掩碼預測,而不是 SAM2ImagePredictor 的兩步過程(set_image 和 predict 函式)

class SAM2FullModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.image_encoder = model.forward_image
        self._prepare_backbone_features = model._prepare_backbone_features
        self.directly_add_no_mem_embed = model.directly_add_no_mem_embed
        self.no_mem_embed = model.no_mem_embed
        self._features = None

        self.prompt_encoder = model.sam_prompt_encoder
        self.mask_decoder = model.sam_mask_decoder

        self._bb_feat_sizes = [(256, 256), (128, 128), (64, 64)]

    def forward(self, image, point_coords, point_labels):
        backbone_out = self.image_encoder(image)
        _, vision_feats, _, _ = self._prepare_backbone_features(backbone_out)

        if self.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}

        high_res_features = [
            feat_level[-1].unsqueeze(0) for feat_level in features["high_res_feats"]
        ]

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=(point_coords, point_labels), boxes=None, masks=None
        )

        low_res_masks, iou_predictions, _, _ = self.mask_decoder(
            image_embeddings=features["image_embed"][-1].unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=point_coords.shape[0] > 1,
            high_res_features=high_res_features,
        )

        out = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
        return out

使用預訓練權重初始化 SAM2 模型

使用預訓練權重初始化 SAM2FullModel。由於我們已經初始化了 SAM2ImagePredictor,我們可以直接使用其中的模型 (predictor.model)。我們將模型轉換為 FP16 精度以獲得更快的效能。

encoder = predictor.model.eval().cuda()
sam_model = SAM2FullModel(encoder.half()).eval().cuda()

載入倉庫中提供的示例影像。

input_image = Image.open("./truck.jpg").convert("RGB")

載入輸入影像

這是我們將要使用的輸入影像

../../../_images/truck.jpg
input_image = Image.open("./truck.jpg").convert("RGB")

除了輸入影像,我們還提供用作預測掩碼的提示。提示可以是框、點以及前一次預測迭代中的掩碼。在此演示中,我們使用點作為提示,這與 SAM2 倉庫中的原始 notebook 類似

預處理元件

以下函式實現了預處理元件,這些元件對輸入影像應用變換並轉換給定點座標。我們使用透過 SAM2ImagePredictor 類提供的 SAM2Transforms。要了解更多關於變換的資訊,請參閱 https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py

def preprocess_inputs(image, predictor):
    w, h = image.size
    orig_hw = [(h, w)]
    input_image = predictor._transforms(np.array(image))[None, ...].to("cuda:0")

    point_coords = torch.tensor([[500, 375]], dtype=torch.float).to("cuda:0")
    point_labels = torch.tensor([1], dtype=torch.int).to("cuda:0")

    point_coords = torch.as_tensor(
        point_coords, dtype=torch.float, device=predictor.device
    )
    unnorm_coords = predictor._transforms.transform_coords(
        point_coords, normalize=True, orig_hw=orig_hw[0]
    )
    labels = torch.as_tensor(point_labels, dtype=torch.int, device=predictor.device)
    if len(unnorm_coords.shape) == 2:
        unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]

    input_image = input_image.half()
    unnorm_coords = unnorm_coords.half()

    return (input_image, unnorm_coords, labels)

後處理元件

以下函式實現了後處理元件,包括繪製和視覺化掩碼及點。我們使用 SAM2Transforms 對這些掩碼進行後處理,並根據置信度分數進行排序。

def postprocess_masks(out, predictor, image):
    """Postprocess low-resolution masks and convert them for visualization."""
    orig_hw = (image.size[1], image.size[0])  # (height, width)
    masks = predictor._transforms.postprocess_masks(out["low_res_masks"], orig_hw)
    masks = (masks > 0.0).squeeze(0).cpu().numpy()
    scores = out["iou_predictions"].squeeze(0).cpu().numpy()
    sorted_indices = np.argsort(scores)[::-1]
    return masks[sorted_indices], scores[sorted_indices]


def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [
            cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
        ]
        mask_image = cv2.drawContours(
            mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
        )
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def visualize_masks(
    image, masks, scores, point_coords, point_labels, title_prefix="", save=True
):
    """Visualize and save masks overlaid on the original image."""
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(point_coords, point_labels, plt.gca())
        plt.title(f"{title_prefix} Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis("off")
        plt.savefig(f"{title_prefix}_output_mask_{i + 1}.png")
        plt.close()

預處理輸入

預處理輸入。在以下程式碼片段中,torchtrt_inputs 包含 (input_image, unnormalized_coordinates 和 labels)。unnormalized_coordinates 是點的表示,而 label(在此演示中為 1)表示前景點。

torchtrt_inputs = preprocess_inputs(input_image, predictor)

Torch-TensorRT 編譯

以非嚴格模式匯出模型,並以 FP16 精度執行 Torch-TensorRT 編譯。我們透過啟用 use_fp32_acc=True 來使用 FP32 矩陣乘法累加,以保持與原始 PyTorch 模型相同的精度。

exp_program = torch.export.export(sam_model, torchtrt_inputs, strict=False)
trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=torchtrt_inputs,
    min_block_size=1,
    enabled_precisions={torch.float16},
    use_fp32_acc=True,
)
trt_out = trt_model(*torchtrt_inputs)

輸出視覺化

對 Torch-TensorRT 的輸出進行後處理,並使用上面提供的後處理元件視覺化掩碼。輸出應儲存在當前目錄中。

trt_masks, trt_scores = postprocess_masks(trt_out, predictor, input_image)
visualize_masks(
    input_image,
    trt_masks,
    trt_scores,
    torch.tensor([[500, 375]]),
    torch.tensor([1]),
    title_prefix="Torch-TRT",
)
預測的掩碼如下所示
../../../_images/sam_mask1.png ../../../_images/sam_mask2.png ../../../_images/sam_mask3.png

參考資料

指令碼總執行時間: ( 0 分 0.000 秒)

由 Sphinx-Gallery 生成的相簿

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源