SAM

class SAM(model_path='facebook/sam-vit-huge', *, device='cpu')

Segment Anything Model (SAM) Architecture.

Meta AI’s zero-shot segmentation model. SAM is used for interactive general-purpose segmentation.

Currently supports SAM.

SAM accepts an RGB image patch along with a list of point and bounding box coordinates as prompts.
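As a rough illustration of the inputs described above, the sketch below builds a dummy RGB patch and prompt arrays with NumPy. The (x, y) point layout and the (x_min, y_min, x_max, y_max) box layout are assumptions made for illustration, since the coordinate order is not stated here, and `prompts_inside` is a hypothetical helper.

```python
import numpy as np

# Dummy 256x256 RGB patch in HWC layout (height, width, channels).
patch = np.zeros((256, 256, 3), dtype=np.uint8)

# Assumed layout: one (x, y) pair per point prompt.
point_coords = np.array([[64, 64], [128, 200]])

# Assumed layout: (x_min, y_min, x_max, y_max) for the box prompt.
box_coords = np.array([[32, 32, 224, 224]])


def prompts_inside(image, points, boxes):
    """Check that every prompt coordinate falls inside the image bounds."""
    height, width = image.shape[:2]
    points_ok = np.all((points >= 0) & (points < [width, height]))
    boxes_ok = np.all((boxes >= 0) & (boxes <= [width, height, width, height]))
    return bool(points_ok and boxes_ok)


print(prompts_inside(patch, point_coords, box_coords))
```

A check like this can catch prompts given in the wrong coordinate frame before they reach the model.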

Parameters:
  • model_path (str) – Hugging Face model identifier or path, e.g. facebook/sam-vit-huge.

  • device (str) – Device to run inference on.

Examples

>>> # instantiate SAM with a Hugging Face model path and a device
>>> sam = SAM(
...     model_path="facebook/sam-vit-b",
...     device="cuda",
... )

Initialize SAM.

Methods

forward

PyTorch method.

infer_batch

Run inference on an input batch.

preproc

Pre-processes an image by converting it into the format accepted by SAM (HWC).

to

Moves the model to the specified device.

Attributes

training

forward(imgs, point_coords=None, box_coords=None)

PyTorch method. Defines forward pass on each image in the batch.

Note: This architecture only uses a single layer, so only one forward pass is needed.

Parameters:
  • imgs (list) – List of images to process, of the shape NHWC.

  • point_coords (list) – List of point coordinates for each image.

  • box_coords (list) – Bounding box coordinates for each image.

Returns:

Arrays of masks and scores for each image.

Return type:

tuple[np.ndarray, np.ndarray]
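The returned tuple could be consumed along these lines. This is a sketch under stated assumptions: the masks for one image are taken to have shape (num_masks, H, W) and the scores shape (num_masks,), which this page does not confirm. `best_mask` is a hypothetical helper, and the arrays are dummy data rather than real model output.

```python
import numpy as np


def best_mask(masks, scores):
    """Return the mask with the highest score, plus that score."""
    idx = int(np.argmax(scores))
    return masks[idx], float(scores[idx])


# Dummy stand-ins for one image's output: three candidate 4x4 masks.
masks = np.zeros((3, 4, 4), dtype=bool)
masks[1, 1:3, 1:3] = True
scores = np.array([0.2, 0.9, 0.5])

mask, score = best_mask(masks, scores)
print(mask.sum(), score)
```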

static infer_batch(model, batch_data, point_coords=None, box_coords=None, *, device='cpu')

Run inference on an input batch.

Contains logic for forward operation as well as I/O aggregation. SAM accepts a list of points and a single bounding box per image.

Parameters:
  • model (nn.Module) – PyTorch defined model.

  • batch_data (list) – A batch of data generated by torch.utils.data.DataLoader.

  • point_coords (np.ndarray | None) – Point coordinates for each image in the batch.

  • box_coords (np.ndarray | None) – Bounding box coordinates for each image in the batch.

  • device (str) – Device to run inference on.

Returns:

Tuple of masks and scores for each image in the batch.

Return type:

pred_info (tuple[np.ndarray, np.ndarray])
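Because `infer_batch` is a static method, the model instance is passed explicitly as the first argument. The sketch below mirrors the documented signature with a stand-in class so that it runs without model weights. `StubSAM`, its dummy outputs, and the one-point-per-image prompt layout are all hypothetical and do not reflect the real implementation.

```python
import numpy as np


class StubSAM:
    """Hypothetical stand-in that mimics the documented infer_batch signature."""

    @staticmethod
    def infer_batch(model, batch_data, point_coords=None, box_coords=None, *, device="cpu"):
        # Dummy output: one empty mask and one zero score per image.
        n = len(batch_data)
        h, w = batch_data[0].shape[:2]
        return np.zeros((n, h, w), dtype=bool), np.zeros(n)


model = StubSAM()
batch = [np.zeros((8, 8, 3), dtype=np.uint8) for _ in range(2)]
points = np.array([[[2, 3]], [[4, 4]]])  # assumed layout: one point per image

# The static method receives the model explicitly, then the batch and prompts.
masks, scores = StubSAM.infer_batch(model, batch, point_coords=points, device="cpu")
print(masks.shape, scores.shape)
```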

static preproc(image)

Pre-processes an image by converting it into the format accepted by SAM (HWC).

Parameters:

image (ndarray)

Return type:

ndarray
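For reference, HWC means the array axes are ordered height, width, channels. The sketch below shows one way a channels-first (CHW) array could be rearranged into HWC with NumPy. `to_hwc` is a hypothetical helper for illustration only and is not claimed to match what `preproc` does internally.

```python
import numpy as np


def to_hwc(image):
    """Move the channel axis last if the array looks channels-first (CHW)."""
    looks_chw = (
        image.ndim == 3
        and image.shape[0] in (1, 3)
        and image.shape[-1] not in (1, 3)
    )
    if looks_chw:
        return np.transpose(image, (1, 2, 0))
    return image


chw = np.zeros((3, 32, 48), dtype=np.uint8)
print(to_hwc(chw).shape)
```

The shape heuristic is ambiguous for very small images, such as a 3x3 patch, so explicit layout information is preferable when it is available.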

to(device='cpu', dtype=None, *, non_blocking=False)

Moves the model to the specified device.

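Parameters:
  • device (str) – Device to move the model to. Defaults to 'cpu'.

  • dtype – Optional data type for the model. Defaults to None.

  • non_blocking (bool) – Whether to request a non-blocking transfer. Defaults to False.

Return type: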

ModelABC | DataParallel[ModelABC]