Source code for tiatoolbox.models.models_abc

"""Defines Abstract Base Class for Models defined in tiatoolbox."""
from abc import ABC, abstractmethod

import torch.nn as nn


[docs]class IOConfigABC(ABC): """Define an abstract class for holding predictor I/O information. Enforcing such that following attributes must always be defined by the subclass. """ @property @abstractmethod def input_resolutions(self): raise NotImplementedError @property @abstractmethod def output_resolutions(self): raise NotImplementedError
[docs]class ModelABC(ABC, nn.Module): """Abstract base class for models used in tiatoolbox.""" def __init__(self): super().__init__() self._postproc = self.postproc self._preproc = self.preproc
[docs] @abstractmethod # noqa # This is generic abc, else pylint will complain def forward(self, *args, **kwargs): """Torch method, this contains logic for using layers defined in init.""" ... # pragma: no cover
[docs] @staticmethod @abstractmethod def infer_batch(model, batch_data, on_gpu): """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. Args: model (nn.Module): PyTorch defined model. batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. on_gpu (bool): Whether to run inference on a GPU. """ ... # pragma: no cover
[docs] @staticmethod def preproc(image): """Define the pre-processing of this class of model.""" return image
[docs] @staticmethod def postproc(image): """Define the post-processing of this class of model.""" return image
@property def preproc_func(self): """Return the current pre-processing function of this instance.""" return self._preproc @preproc_func.setter def preproc_func(self, func): """Set the pre-processing function for this instance. If `func=None`, the method will default to `self.preproc`. Otherwise, `func` is expected to be callable. Examples: >>> # expected usage >>> # model is a subclass object of this ModelABC >>> # `func` is a user defined function >>> model.preproc_func = func >>> transformed_img = model.preproc_func(img) """ if func is not None and not callable(func): raise ValueError(f"{func} is not callable!") if func is None: self._preproc = self.preproc else: self._preproc = func @property def postproc_func(self): """Return the current post-processing function of this instance.""" return self._postproc @postproc_func.setter def postproc_func(self, func): """Set the pre-processing function for this instance of model. If `func=None`, the method will default to `self.postproc`. Otherwise, `func` is expected to be callable and behave as follows: Examples: >>> # expected usage >>> # model is a subclass object of this ModelABC >>> # `func` is a user defined function >>> model.postproc_func = func >>> transformed_img = model.postproc_func(img) """ if func is not None and not callable(func): raise ValueError(f"{func} is not callable!") if func is None: self._postproc = self.postproc else: self._postproc = func