# Source code for tiatoolbox.models.abc

# ***** BEGIN GPL LICENSE BLOCK *****
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
# All rights reserved.
# ***** END GPL LICENSE BLOCK *****

"""Defines Abstract Base Class for Models defined in tiatoolbox."""
from abc import ABC, abstractmethod

import torch.nn as nn


class IOConfigABC(ABC):
    """Define an abstract class for holding a predictor's input/output information.

    Enforces that the following attributes are always defined by the
    subclass.

    Attributes:
        input_resolutions (list):
            Define the resolution of each input, in case the predictor
            receives variable input. Must be in the same order as the
            network input.
        output_resolutions (list):
            Define the resolution of each output, in case the predictor
            returns variable output. Must be in the same order as the
            network output.

    """

    @property
    @abstractmethod
    def input_resolutions(self):
        """Resolution of each network input; must be provided by subclasses."""
        raise NotImplementedError

    @property
    @abstractmethod
    def output_resolutions(self):
        """Resolution of each network output; must be provided by subclasses."""
        raise NotImplementedError


class ModelABC(ABC, nn.Module):
    """Abstract base class for models used in tiatoolbox.

    Subclasses must implement :meth:`forward` and :meth:`infer_batch`.
    Pre- and post-processing default to the identity functions
    :meth:`preproc` and :meth:`postproc`, and can be replaced per
    instance through the ``preproc_func`` and ``postproc_func``
    properties.

    """

    def __init__(self):
        super().__init__()
        # Start from the class-level defaults; the property setters below
        # may replace these on a per-instance basis.
        self._postproc = self.postproc
        self._preproc = self.preproc

    # This is a generic ABC, so the signature is deliberately left open;
    # otherwise pylint will complain.
    @abstractmethod
    def forward(self, *args, **kwargs):
        """Torch method, this contains logic for using layers defined in init."""
        ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def infer_batch(model, batch_data, on_gpu):
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (ndarray):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        ...  # pragma: no cover

    @staticmethod
    def preproc(image):
        """Define the pre-processing of this class of model.

        The default implementation returns `image` unchanged.

        """
        return image

    @staticmethod
    def postproc(image):
        """Define the post-processing of this class of model.

        The default implementation returns `image` unchanged.

        """
        return image

    @property
    def preproc_func(self):
        """Return the current pre-processing function of this instance."""
        return self._preproc

    @preproc_func.setter
    def preproc_func(self, func):
        """Set the pre-processing function for this instance.

        If `func=None`, the method will default to `self.preproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError: If `func` is neither `None` nor callable.

        Examples:
            >>> # expected usage
            >>> # model is a subclass object of this ModelABC
            >>> # `func` is a user defined function
            >>> model.preproc_func = func
            >>> transformed_img = model.preproc_func(img)

        """
        if func is None:
            self._preproc = self.preproc
        elif callable(func):
            self._preproc = func
        else:
            raise ValueError(f"{func} is not callable!")

    @property
    def postproc_func(self):
        """Return the current post-processing function of this instance."""
        return self._postproc

    @postproc_func.setter
    def postproc_func(self, func):
        """Set the post-processing function for this instance of model.

        If `func=None`, the method will default to `self.postproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError: If `func` is neither `None` nor callable.

        Examples:
            >>> # expected usage
            >>> # model is a subclass object of this ModelABC
            >>> # `func` is a user defined function
            >>> model.postproc_func = func
            >>> transformed_img = model.postproc_func(img)

        """
        if func is None:
            self._postproc = self.postproc
        elif callable(func):
            self._postproc = func
        else:
            raise ValueError(f"{func} is not callable!")