# Source code for tiatoolbox.models.architecture.sam

"""Define SAM architecture."""

from __future__ import annotations

import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

from tiatoolbox.models.models_abc import ModelABC


class SAM(ModelABC):
    """Segment Anything Model (SAM) Architecture.

    Meta AI's zero-shot segmentation model.
    SAM is used for interactive general-purpose segmentation.

    Currently supports SAM.

    SAM accepts an RGB image patch along with a list of point and bounding
    box coordinates as prompts.

    Args:
        model_path (str):
            Path to the model (huggingface).
        device (str):
            Device to run inference on.

    Examples:
        >>> # instantiate SAM with a model path and a device
        >>> sam = SAM(
        ...     model_path="facebook/sam-vit-b",
        ...     device="cuda",
        ... )

    """

    def __init__(
        self: SAM,
        model_path: str = "facebook/sam-vit-huge",
        *,
        device: str = "cpu",
    ) -> None:
        """Initialize :class:`SAM`.

        Loads the pretrained model and its processor from `model_path` and
        moves the model to `device`.

        """
        super().__init__()
        self.net_name = "SAM"
        self.device = device

        self.model = SamModel.from_pretrained(model_path).to(device)
        self.processor = SamProcessor.from_pretrained(model_path)

    def _process_prompts(
        self: SAM,
        image: list,
        embeddings: torch.Tensor,
        orig_sizes: torch.Tensor,
        reshaped_sizes: torch.Tensor,
        points: list | None = None,
        boxes: list | None = None,
        point_labels: list | None = None,
    ) -> tuple[list, torch.Tensor]:
        """Process prompts and return masks and scores.

        Args:
            image (list):
                Image(s) handed to the processor together with the prompts.
                `forward` passes a single-element list holding a PIL image.
            embeddings (torch.Tensor):
                Image embeddings from :meth:`_encode_image`; passed to the
                model in place of the pixel values.
            orig_sizes (torch.Tensor):
                Original image sizes reported by the processor.
            reshaped_sizes (torch.Tensor):
                Reshaped input sizes reported by the processor.
            points (list | None):
                Point prompts, forwarded as `input_points`.
            boxes (list | None):
                Box prompts, forwarded as `input_boxes`.
            point_labels (list | None):
                Labels for `points`, forwarded as `input_labels`.

        Returns:
            tuple[list, torch.Tensor]:
                Post-processed masks as returned by the processor's
                `post_process_masks`, and the IoU scores moved to the CPU.

        """
        inputs = self.processor(
            image,
            input_points=points,
            input_labels=point_labels,
            input_boxes=boxes,
            return_tensors="pt",
        ).to(self.device)

        # Replaces pixel_values with image embeddings
        inputs.pop("pixel_values", None)
        inputs.update(
            {
                "image_embeddings": embeddings,
                "original_sizes": orig_sizes,
                "reshaped_input_sizes": reshaped_sizes,
            }
        )

        with torch.inference_mode():
            # Forward pass through the model
            outputs = self.model(**inputs, multimask_output=False)
            image_masks = self.processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
            )
            image_scores = outputs.iou_scores.cpu()

        return image_masks, image_scores

    def forward(  # skipcq: PYL-W0221
        self: SAM,
        imgs: list,
        point_coords: list | None = None,
        box_coords: list | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """PyTorch method. Defines forward pass on each image in the batch.

        Note:
            Each image is encoded once; the resulting embeddings are reused
            for the box prompts and then for the point prompts.

        Args:
            imgs (list):
                List of images to process, of the shape NHWC.
            point_coords (list):
                List of point coordinates for each image.
            box_coords (list):
                Bounding box coordinates for each image.

        Returns:
            tuple[np.ndarray, np.ndarray]:
                Array of masks and scores. The per-image, per-prompt-type
                results are concatenated along axis 2.

        Raises:
            ValueError:
                If neither `point_coords` nor `box_coords` is provided.

        """
        # Fail before any image is encoded; without prompts there would be
        # nothing to concatenate at the end.
        if point_coords is None and box_coords is None:
            msg = "At least one of point_coords or box_coords must be provided."
            raise ValueError(msg)

        masks, scores = [], []
        for i, img in enumerate(imgs):
            image = [Image.fromarray(img)]
            embeddings, orig_sizes, reshaped_sizes = self._encode_image(image)

            # (points, boxes, point_labels) triples to run for this image;
            # boxes are processed before points.
            prompt_sets = []

            if box_coords is not None:
                # Convert box coordinates to list
                boxes = [box_coords[i][:, None, :].tolist()]
                prompt_sets.append((None, boxes, None))

            if point_coords is not None:
                image_points = point_coords[i]
                # Convert point coordinates to list; every point is given
                # the label 1.
                point_labels = np.ones(
                    (1, len(image_points), 1),
                    dtype=int,
                ).tolist()
                points = [image_points[:, None, :].tolist()]
                prompt_sets.append((points, None, point_labels))

            for points, boxes, point_labels in prompt_sets:
                image_masks, image_scores = self._process_prompts(
                    image,
                    embeddings,
                    orig_sizes,
                    reshaped_sizes,
                    points,
                    boxes,
                    point_labels,
                )
                masks.append(np.array([image_masks]))
                scores.append(np.array([image_scores]))

            torch.cuda.empty_cache()

        return np.concatenate(masks, axis=2), np.concatenate(scores, axis=2)

    @staticmethod
    def infer_batch(
        model: torch.nn.Module,
        batch_data: list,
        point_coords: np.ndarray | None = None,
        box_coords: np.ndarray | None = None,
        *,
        device: str = "cpu",
    ) -> tuple[np.ndarray, np.ndarray]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as I/O aggregation.
        SAM accepts a list of points and a single bounding box per image.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (list):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            point_coords (np.ndarray | None):
                Point coordinates for each image in the batch.
            box_coords (np.ndarray | None):
                Bounding box coordinates for each image in the batch.
            device (str):
                Device to run inference on.

        Returns:
            pred_info (tuple[np.ndarray, np.ndarray]):
                Tuple of masks and scores for each image in the batch.

        Raises:
            ValueError:
                If neither `point_coords` nor `box_coords` is provided.

        """
        # Validate the arguments before touching the model so that an invalid
        # call has no side effects.
        if point_coords is None and box_coords is None:
            msg = "At least one of point_coords or box_coords must be provided."
            raise ValueError(msg)

        model.eval().to(device)

        with torch.inference_mode():
            masks, scores = model(batch_data, point_coords, box_coords)

        return masks, scores

    def _encode_image(
        self: SAM,
        image: list,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encodes image and stores size info for later mask post-processing.

        Args:
            image (list):
                Image(s) handed to the processor. `forward` passes a
                single-element list holding a PIL image.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                The image embeddings, the processor's `original_sizes` and
                the processor's `reshaped_input_sizes`.

        """
        processed = self.processor(image, return_tensors="pt")
        # The size entries are read before the batch is moved to the device.
        original_sizes = processed["original_sizes"]
        reshaped_sizes = processed["reshaped_input_sizes"]

        inputs = processed.to(self.device)
        embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
        return embeddings, original_sizes, reshaped_sizes

    @staticmethod
    def preproc(image: np.ndarray | torch.Tensor) -> np.ndarray:
        """Pre-processes an image - Converts it into a format accepted by SAM (HWC).

        Args:
            image (np.ndarray | torch.Tensor):
                Input image. A tensor has its axes permuted with
                `(1, 2, 0)` and is converted to a NumPy array on the CPU;
                any other input is used as given.

        Returns:
            np.ndarray:
                The image restricted to its first three channels along the
                last axis.

        """
        # Move the tensor to the CPU if it's a PyTorch tensor
        if isinstance(image, torch.Tensor):
            image = image.permute(1, 2, 0).cpu().numpy()

        return image[..., :3]  # Remove alpha channel if present

    def to(
        self: SAM,
        device: str = "cpu",
        dtype: torch.dtype | None = None,
        *,
        non_blocking: bool = False,
    ) -> ModelABC | torch.nn.DataParallel[ModelABC]:
        """Moves the model to the specified device.

        Args:
            device (str):
                Target device. Also stored on `self.device`, which
                :meth:`_process_prompts` and :meth:`_encode_image` use when
                moving their inputs.
            dtype (torch.dtype | None):
                Forwarded to the parent class's `to`.
            non_blocking (bool):
                Forwarded to the parent class's `to`.

        Returns:
            This instance. The value returned by the parent class's `to` is
            not used.

        """
        super().to(device, dtype=dtype, non_blocking=non_blocking)
        self.device = device
        self.model.to(device)
        return self