SAM
tiatoolbox.models.architecture.sam.SAM
- class SAM(model_path='facebook/sam-vit-huge', *, device='cpu')
Segment Anything Model (SAM) Architecture.
Meta AI’s zero-shot segmentation model, used for interactive, general-purpose segmentation.
Currently supports SAM only.
SAM accepts an RGB image patch along with a list of point and bounding box coordinates as prompts.
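- Parameters:
model_path (str) – Path or Hugging Face model identifier of the SAM checkpoint to load. Defaults to 'facebook/sam-vit-huge'.
device (str) – Device to load the model on, e.g. 'cpu' or 'cuda'. Defaults to 'cpu'.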
Examples
>>> # Instantiate SAM from a Hugging Face checkpoint and run it on the GPU
>>> from tiatoolbox.models.architecture.sam import SAM
>>> sam = SAM(
...     model_path="facebook/sam-vit-base",
...     device="cuda",
... )
Initialize SAM.
Methods
PyTorch method.
Run inference on an input batch.
Pre-processes an image: converts it into the format accepted by SAM (HWC).
Moves the model to the specified device.
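The pre-processing step summarized above converts an image into the channels-last (HWC) layout that SAM accepts. A minimal NumPy sketch of just that layout conversion follows, under stated assumptions: `to_hwc` is a hypothetical name, it only handles 3-channel images given as CHW or HWC, and it does not reproduce any resizing, normalization, or other processing that tiatoolbox may apply.

```python
import numpy as np


def to_hwc(image: np.ndarray) -> np.ndarray:
    """Return a 3-channel image in HWC layout (hypothetical sketch).

    Assumes the input is either CHW (channels first) or already HWC.
    An ambiguous (3, H, 3) array is treated as HWC and returned unchanged.
    """
    if image.ndim != 3:
        raise ValueError("expected a 3-D image array")
    if image.shape[-1] == 3:  # already channels-last
        return image
    if image.shape[0] == 3:  # channels-first -> channels-last
        return np.transpose(image, (1, 2, 0))
    raise ValueError("expected 3 channels in the first or last axis")


chw = np.zeros((3, 256, 256), dtype=np.uint8)  # e.g. a CHW tensor converted to NumPy
hwc = to_hwc(chw)  # shape (256, 256, 3)
```

Checking the last axis first means an image that is already HWC passes through without a copy.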
Attributes
training
- forward(imgs, point_coords=None, box_coords=None)
PyTorch method. Defines the forward pass for each image in the batch.
Note: This architecture only uses a single layer, so only one forward pass is needed.
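- Parameters:
imgs – Batch of input images.
point_coords (np.ndarray | None) – Point coordinates for each image in the batch.
box_coords (np.ndarray | None) – Bounding box coordinates for each image in the batch.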
- Returns:
Tuple of mask and score arrays for each image.
- Return type:
tuple[np.ndarray, np.ndarray]
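The forward pass runs on each image in the batch and returns a tuple of masks and scores. The sketch below illustrates that per-image loop and aggregation only; it is not the tiatoolbox implementation. `forward_per_image` and `toy_box_segmenter` are hypothetical names, the toy predictor merely stands in for SAM, boxes are assumed to be (x_min, y_min, x_max, y_max) pixel coordinates, and every image is assumed to produce the same number of masks so the results can be stacked into arrays.

```python
import numpy as np


def forward_per_image(segment_fn, imgs, point_coords=None, box_coords=None):
    """Run `segment_fn` on each image and stack the results.

    `segment_fn(img, points, box) -> (masks, scores)` is a hypothetical
    stand-in for the underlying SAM predictor, not a tiatoolbox API.
    """
    all_masks, all_scores = [], []
    for i, img in enumerate(imgs):
        # Pick this image's prompts, if any were given.
        points = None if point_coords is None else point_coords[i]
        box = None if box_coords is None else box_coords[i]
        masks, scores = segment_fn(img, points, box)
        all_masks.append(masks)
        all_scores.append(scores)
    # np.stack assumes every image yields the same number of masks.
    return np.stack(all_masks), np.stack(all_scores)


def toy_box_segmenter(img, points, box):
    """Toy predictor: one mask filling the prompt box, with score 1.0."""
    h, w = img.shape[:2]
    mask = np.zeros((1, h, w), dtype=bool)
    if box is not None:
        x_min, y_min, x_max, y_max = (int(v) for v in box)
        mask[0, y_min:y_max, x_min:x_max] = True
    return mask, np.array([1.0])


imgs = np.zeros((2, 8, 8, 3), dtype=np.uint8)  # two 8x8 RGB images (HWC)
boxes = np.array([[0, 0, 4, 4], [2, 2, 8, 8]])  # one box per image
masks, scores = forward_per_image(toy_box_segmenter, imgs, box_coords=boxes)
```

If images can yield different numbers of masks, returning Python lists (or object arrays) instead of stacked arrays would be the safer choice.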
- static infer_batch(model, batch_data, point_coords=None, box_coords=None, *, device='cpu')
Run inference on an input batch.
Contains the logic for the forward operation as well as I/O aggregation. SAM accepts a list of points and a single bounding box per image.
- Parameters:
model (nn.Module) – PyTorch defined model.
batch_data (list) – A batch of data generated by torch.utils.data.DataLoader.
point_coords (np.ndarray | None) – Point coordinates for each image in the batch.
box_coords (np.ndarray | None) – Bounding box coordinates for each image in the batch.
device (str) – Device to run inference on.
- Returns:
Tuple of masks and scores for each image in the batch.
- Return type:
pred_info (tuple[np.ndarray, np.ndarray])
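The point_coords and box_coords parameters carry, for each image in the batch, a list of points and a single bounding box. One plausible way to assemble them is sketched below. The helper name `pack_prompts`, the (x, y) point order, the (x_min, y_min, x_max, y_max) box order, and the (B, N, 2) and (B, 4) shapes are all assumptions not stated in this reference, so check them against the tiatoolbox source before passing the arrays to `infer_batch`.

```python
import numpy as np


def pack_prompts(points_per_image, boxes_per_image):
    """Stack per-image prompts into batch arrays (hypothetical helper).

    Assumed layout, not confirmed by the tiatoolbox docs:
      points: (x, y) pixel coordinates, same number of points per image
      boxes:  exactly one (x_min, y_min, x_max, y_max) box per image
    """
    # Ragged point lists make np.asarray raise ValueError.
    point_coords = np.asarray(points_per_image, dtype=np.float32)  # (B, N, 2)
    box_coords = np.asarray(boxes_per_image, dtype=np.float32)  # (B, 4)
    if point_coords.ndim != 3 or point_coords.shape[-1] != 2:
        raise ValueError("expected points with shape (B, N, 2)")
    if box_coords.shape != (point_coords.shape[0], 4):
        raise ValueError("expected one 4-value box per image")
    # Each box's max corner must lie strictly beyond its min corner.
    if np.any(box_coords[:, 2:] <= box_coords[:, :2]):
        raise ValueError("box must satisfy x_min < x_max and y_min < y_max")
    return point_coords, box_coords


point_coords, box_coords = pack_prompts(
    points_per_image=[[[10, 12], [30, 25]], [[5, 5], [40, 44]]],
    boxes_per_image=[[0, 0, 50, 50], [2, 2, 48, 48]],
)
```

Validating shapes up front surfaces prompt mistakes before any model call, which is cheaper to debug than a shape error deep inside the forward pass.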