"""Define SAM architecture."""
from __future__ import annotations
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from tiatoolbox.models.models_abc import ModelABC
[docs]
class SAM(ModelABC):
    """Segment Anything Model (SAM) Architecture.

    Meta AI's zero-shot segmentation model.
    SAM is used for interactive general-purpose segmentation.

    Currently supports SAM.

    SAM accepts an RGB image patch along with a list of point and bounding
    box coordinates as prompts.

    Args:
        model_path (str):
            Path to the model (huggingface).
        device (str):
            Device to run inference on.

    Examples:
        >>> # instantiate SAM with a model path and a device
        >>> sam = SAM(
        ...     model_path="facebook/sam-vit-b",
        ...     device="cuda",
        ... )

    """

    def __init__(
        self: SAM,
        model_path: str = "facebook/sam-vit-huge",
        *,
        device: str = "cpu",
    ) -> None:
        """Initialize :class:`SAM`.

        Loads the pretrained model and its processor from `model_path` and
        moves the model to `device`.

        """
        super().__init__()
        self.net_name = "SAM"
        self.device = device
        self.model = SamModel.from_pretrained(model_path).to(device)
        self.processor = SamProcessor.from_pretrained(model_path)

    def _process_prompts(
        self: SAM,
        image: list,
        embeddings: torch.Tensor,
        orig_sizes: torch.Tensor,
        reshaped_sizes: torch.Tensor,
        points: list | None = None,
        boxes: list | None = None,
        point_labels: list | None = None,
    ) -> tuple[list, torch.Tensor]:
        """Process prompts and return masks and scores.

        Args:
            image (list):
                Image(s) handed to the processor together with the prompts.
            embeddings (torch.Tensor):
                Pre-computed image embeddings, used instead of pixel values.
            orig_sizes (torch.Tensor):
                Original image sizes as reported by the processor.
            reshaped_sizes (torch.Tensor):
                Resized input sizes as reported by the processor.
            points (list | None):
                Point prompts, already in the processor's nested-list format.
            boxes (list | None):
                Box prompts, already in the processor's nested-list format.
            point_labels (list | None):
                Labels matching `points`.

        Returns:
            tuple:
                Post-processed masks and the IoU scores (moved to the CPU).

        """
        inputs = self.processor(
            image,
            input_points=points,
            input_labels=point_labels,
            input_boxes=boxes,
            return_tensors="pt",
        ).to(self.device)

        # The image was already encoded: drop the pixel values and feed the
        # cached embeddings (plus the size info needed for post-processing).
        inputs.pop("pixel_values", None)
        inputs.update(
            {
                "image_embeddings": embeddings,
                "original_sizes": orig_sizes,
                "reshaped_input_sizes": reshaped_sizes,
            }
        )

        with torch.inference_mode():
            # Forward pass through the model, one mask per prompt
            outputs = self.model(**inputs, multimask_output=False)
            image_masks = self.processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
            )
            image_scores = outputs.iou_scores.cpu()

        return image_masks, image_scores

    def forward(  # skipcq: PYL-W0221
        self: SAM,
        imgs: list,
        point_coords: list | None = None,
        box_coords: list | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """PyTorch method. Defines forward pass on each image in the batch.

        Each image is encoded once; the resulting embeddings are reused for
        the box prompts and then for the point prompts of that image.

        Args:
            imgs (list):
                List of images to process, of the shape NHWC.
            point_coords (list):
                Point coordinates for each image. Entry `i` must be a 2-D
                array with one row per point.
            box_coords (list):
                Bounding box coordinates for each image. Entry `i` must be
                a 2-D array with one row per box.

        Returns:
            tuple[np.ndarray, np.ndarray]:
                Array of masks and scores. Results of every image and prompt
                type are concatenated along axis 2.

        Raises:
            ValueError:
                If neither `point_coords` nor `box_coords` is provided, or
                if `imgs` is empty.

        """
        # Fail early with a clear message instead of encoding every image
        # and then crashing in np.concatenate on an empty list.
        if point_coords is None and box_coords is None:
            msg = "At least one of point_coords or box_coords must be provided."
            raise ValueError(msg)

        masks, scores = [], []

        for i, img in enumerate(imgs):
            image = [Image.fromarray(img)]
            embeddings, orig_sizes, reshaped_sizes = self._encode_image(image)

            # One (points, boxes, point_labels) triple per prompt type;
            # boxes are processed before points.
            prompt_sets = []
            if box_coords is not None:
                # Convert box coordinates to list
                boxes = [box_coords[i][:, None, :].tolist()]
                prompt_sets.append((None, boxes, None))
            if point_coords is not None:
                image_points = point_coords[i]
                # Convert point coordinates to list; every point gets label 1
                point_labels = np.ones(
                    (1, len(image_points), 1),
                    dtype=int,
                ).tolist()
                points = [image_points[:, None, :].tolist()]
                prompt_sets.append((points, None, point_labels))

            for points, boxes, point_labels in prompt_sets:
                image_masks, image_scores = self._process_prompts(
                    image,
                    embeddings,
                    orig_sizes,
                    reshaped_sizes,
                    points,
                    boxes,
                    point_labels,
                )
                masks.append(np.array([image_masks]))
                scores.append(np.array([image_scores]))

            # Release cached GPU memory before the next image
            torch.cuda.empty_cache()

        if not masks:
            msg = "imgs must contain at least one image."
            raise ValueError(msg)

        return np.concatenate(masks, axis=2), np.concatenate(scores, axis=2)

    @staticmethod
    def infer_batch(
        model: torch.nn.Module,
        batch_data: list,
        point_coords: np.ndarray | None = None,
        box_coords: np.ndarray | None = None,
        *,
        device: str = "cpu",
    ) -> tuple[np.ndarray, np.ndarray]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as I/O aggregation.
        At least one kind of prompt (points or boxes) is required.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (list):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            point_coords (np.ndarray | None):
                Point coordinates for each image in the batch.
            box_coords (np.ndarray | None):
                Bounding box coordinates for each image in the batch.
            device (str):
                Device to run inference on.

        Returns:
            pred_info (tuple[np.ndarray, np.ndarray]):
                Tuple of masks and scores for each image in the batch.

        Raises:
            ValueError:
                If neither `point_coords` nor `box_coords` is provided.

        """
        # Validate before touching the model so a bad call is cheap.
        if point_coords is None and box_coords is None:
            msg = "At least one of point_coords or box_coords must be provided."
            raise ValueError(msg)

        model.eval().to(device)

        with torch.inference_mode():
            masks, scores = model(batch_data, point_coords, box_coords)

        return masks, scores

    def _encode_image(
        self: SAM,
        image: list,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encodes image and stores size info for later mask post-processing.

        Args:
            image (list):
                Image(s) to run through the processor and image encoder.

        Returns:
            tuple:
                Image embeddings, original sizes and reshaped input sizes.
                The two size tensors are captured before the processed batch
                is moved to the model's device.

        """
        processed = self.processor(image, return_tensors="pt")
        original_sizes = processed["original_sizes"]
        reshaped_sizes = processed["reshaped_input_sizes"]
        inputs = processed.to(self.device)
        embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
        return embeddings, original_sizes, reshaped_sizes

    @staticmethod
    def preproc(image: np.ndarray | torch.Tensor) -> np.ndarray:
        """Pre-processes an image - Converts it into a format accepted by SAM (HWC).

        Args:
            image (np.ndarray | torch.Tensor):
                Input image. A tensor is permuted with (1, 2, 0), i.e. it is
                assumed to be channel-first (CHW), and converted to numpy.

        Returns:
            np.ndarray:
                The image restricted to its first three channels.

        """
        # Move the tensor to the CPU if it's a PyTorch tensor
        if isinstance(image, torch.Tensor):
            image = image.permute(1, 2, 0).cpu().numpy()

        return image[..., :3]  # Remove alpha channel if present

    def to(
        self: SAM,
        device: str = "cpu",
        dtype: torch.dtype | None = None,
        *,
        non_blocking: bool = False,
    ) -> ModelABC | torch.nn.DataParallel[ModelABC]:
        """Moves the model to the specified device.

        Args:
            device (str):
                Target device.
            dtype (torch.dtype | None):
                Optional dtype forwarded to the parent implementation.
            non_blocking (bool):
                Forwarded to the parent implementation.

        Returns:
            This instance; the value returned by the parent `to` is not used.

        """
        super().to(device, dtype=dtype, non_blocking=non_blocking)
        # Keep the stored device in sync: it is used to place processor
        # outputs on the same device as the model.
        self.device = device
        self.model.to(device)
        return self