# Source code for tiatoolbox.models.models_abc

"""Define Abstract Base Class for Models defined in tiatoolbox."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable

import torch
from torch import device as torch_device

if TYPE_CHECKING:  # pragma: no cover
    from pathlib import Path

    import numpy as np


class IOConfigABC(ABC):
    """Abstract container for a predictor's input/output configuration.

    Concrete subclasses must provide both `input_resolutions` and
    `output_resolutions`; a subclass missing either one cannot be
    instantiated because both are declared abstract.
    """

    @property
    @abstractmethod
    def input_resolutions(self: IOConfigABC) -> None:
        """Abstract read-only property for the input resolution(s).

        Must be overridden by subclasses; the base version only raises.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def output_resolutions(self: IOConfigABC) -> None:
        """Abstract read-only property for the output resolution(s).

        Must be overridden by subclasses; the base version only raises.
        """
        raise NotImplementedError
class ModelABC(ABC, torch.nn.Module):
    """Abstract base class for models used in tiatoolbox.

    Subclasses must implement :meth:`forward` and :meth:`infer_batch`.
    Pre- and post-processing default to :meth:`preproc` and :meth:`postproc`
    (identity functions here) and can be replaced per instance through the
    :attr:`preproc_func` and :attr:`postproc_func` properties.
    """

    def __init__(self: ModelABC) -> None:
        """Initialize Abstract class ModelABC."""
        super().__init__()
        # Per-instance processing hooks; they start out as the class defaults.
        self._postproc = self.postproc
        self._preproc = self.preproc

    @abstractmethod
    # This is generic abc, else pylint will complain
    def forward(self: ModelABC, *args: Any, **kwargs: Any) -> None:
        """Torch method, this contains logic for using layers defined in init."""
        ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def infer_batch(
        model: torch.nn.Module,
        batch_data: np.ndarray,
        *,
        on_gpu: bool,
    ) -> None:
        """Run inference on an input batch.

        Contains logic for forward operation as well as I/O aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (np.ndarray):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        ...  # pragma: no cover

    @staticmethod
    def preproc(image: np.ndarray) -> np.ndarray:
        """Define the pre-processing of this class of model.

        The default implementation returns `image` unchanged.
        """
        return image

    @staticmethod
    def postproc(image: np.ndarray) -> np.ndarray:
        """Define the post-processing of this class of model.

        The default implementation returns `image` unchanged.
        """
        return image

    @property
    def preproc_func(self: ModelABC) -> Callable:
        """Return the current pre-processing function of this instance."""
        return self._preproc

    @preproc_func.setter
    def preproc_func(self: ModelABC, func: Callable | None) -> None:
        """Set the pre-processing function for this instance.

        If `func=None`, the method will default to `self.preproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError: If `func` is neither `None` nor callable.

        Examples:
            >>> # expected usage
            >>> # `MyModel` is a concrete subclass of ModelABC
            >>> # `func` is a user defined function
            >>> model = MyModel()
            >>> model.preproc_func = func
            >>> transformed_img = model.preproc_func(img)

        """
        if func is None:
            self._preproc = self.preproc
        elif callable(func):
            self._preproc = func
        else:
            msg = f"{func} is not callable!"
            raise ValueError(msg)

    @property
    def postproc_func(self: ModelABC) -> Callable:
        """Return the current post-processing function of this instance."""
        return self._postproc

    @postproc_func.setter
    def postproc_func(self: ModelABC, func: Callable | None) -> None:
        """Set the post-processing function for this instance of model.

        If `func=None`, the method will default to `self.postproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError: If `func` is neither `None` nor callable.

        Examples:
            >>> # expected usage
            >>> # `MyModel` is a concrete subclass of ModelABC
            >>> # `func` is a user defined function
            >>> model = MyModel()
            >>> model.postproc_func = func
            >>> transformed_img = model.postproc_func(img)

        """
        if func is None:
            self._postproc = self.postproc
        elif callable(func):
            self._postproc = func
        else:
            msg = f"{func} is not callable!"
            raise ValueError(msg)

    def to(self: ModelABC, device: str | torch.device = "cpu") -> torch.nn.Module:
        """Transfer the model to cpu/gpu.

        Args:
            device (str or torch.device):
                Device to move the model to. Default is "cpu".

        Returns:
            torch.nn.Module:
                The model after being moved to `device`. When `device` is a
                CUDA device and more than one GPU is available, the model is
                additionally wrapped in :class:`torch.nn.DataParallel`, so
                the returned object is that wrapper rather than `self`.

        """
        device = torch_device(device)
        model = super().to(device)

        # If the target device is CUDA and more than one GPU is available,
        # use DataParallel.
        if device.type == "cuda" and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)  # pragma: no cover

        return model

    def load_weights_from_file(self: ModelABC, weights: str | Path) -> torch.nn.Module:
        """Load pretrained weights from a file into this model.

        Args:
            weights (str or Path):
                Path to pretrained weights (a saved state dict).

        Returns:
            torch.nn.Module:
                This model (`self`) with the pretrained weights loaded.

        Raises:
            RuntimeError:
                Raised by `load_state_dict` when the keys in `weights` do
                not exactly match this model's keys (`strict=True`).

        """
        # ! assume to be saved in single GPU mode
        # Always deserialize onto the CPU so weights saved from a GPU can be
        # read on a CPU-only machine; load_state_dict then copies the values
        # into this model's existing parameters.
        # NOTE(review): torch.load unpickles the file -- only load weights
        # from trusted sources.
        saved_state_dict = torch.load(weights, map_location="cpu")
        # With strict=True any key mismatch raises, so the missing/unexpected
        # keys record returned by load_state_dict is always empty; return the
        # model itself, as the signature and docstring promise.
        super().load_state_dict(saved_state_dict, strict=True)
        return self