Source code for tiatoolbox.models.engine.prompt_segmentor

"""This module enables interactive segmentation."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np

from tiatoolbox.models.architecture.sam import SAM
from tiatoolbox.utils.misc import dict_to_store_semantic_segmentor

if TYPE_CHECKING:  # pragma: no cover
    import torch

    from tiatoolbox.type_hints import IntPair


[docs] class PromptSegmentor: """Engine for prompt-based segmentation of WSIs. This class is designed to work with the SAM model architecture. It allows for interactive segmentation by providing point and bounding box coordinates as prompts. The model is intended to be used with image tiles selected interactively in some way and provided as np.arrays. At least one of either point_coords or box_coords must be provided to guide segmentation. Args: model (SAM): Model architecture to use. If None, defaults to SAM. """ def __init__( self, model: torch.nn.Module = None, ) -> None: """Initializes the PromptSegmentor.""" model = SAM() if model is None else model self.model = model self.scale = 1.0 self.offset = np.array([0, 0])
[docs] def run( # skipcq: PYL-W0221 self, images: list, point_coords: np.ndarray | None = None, box_coords: np.ndarray | None = None, save_dir: str | Path | None = None, device: str = "cpu", ) -> list[Path]: """Run inference on image patches with prompts. Args: images (list): List of image patch arrays to run inference on. point_coords (np.ndarray): N_im x N_points x 2 array of point coordinates for each image patch. box_coords (np.ndarray): N_im x N_boxes x 4 array of bounding box coordinates for each image patch. save_dir (str or Path): Directory to save the output databases. device (str): Device to run inference on. Returns: list[Path]: Paths to the saved output databases. """ paths = [] masks, _ = self.model.infer_batch( self.model, images, point_coords=point_coords, box_coords=box_coords, device=device, ) save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) for i, _mask in enumerate(masks): mask = np.any(_mask[0], axis=0, keepdims=False) dict_to_store_semantic_segmentor( patch_output={"predictions": mask[0]}, scale_factor=(self.scale, self.scale), offset=self.offset, save_path=Path(f"{save_dir}/{i}.db"), output_type="annotationstore", ignore_index=0, ) paths.append(Path(f"{save_dir}/{i}.db")) return paths
[docs] def calc_mpp( self, area_dims: IntPair, base_mpp: float, fixed_size: int = 1500 ) -> tuple[float, float]: """Calculates the microns per pixel for a fixed area of an image. Args: area_dims (tuple): Dimensions of the area to be scaled. base_mpp (float): Microns per pixel of the base image. fixed_size (int): Fixed size of the area. Returns: tuple[float, float]: Tuple of the scaled mpp and the scale factor. """ scale = max(area_dims) / fixed_size if max(area_dims) > fixed_size else 1.0 self.scale = scale return base_mpp * scale, scale