"""Define Abstract Base Class for Models defined in tiatoolbox."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable
import torch
from torch import device as torch_device
if TYPE_CHECKING: # pragma: no cover
from pathlib import Path
import numpy as np
class IOConfigABC(ABC):
    """Define an abstract class for holding predictor I/O information.

    Subclasses must define both the :attr:`input_resolutions` and the
    :attr:`output_resolutions` properties; a subclass that omits either
    one is itself abstract and cannot be instantiated.

    """

    @property
    @abstractmethod
    def input_resolutions(self: IOConfigABC) -> Any:
        """Abstract read-only property returning the input resolutions.

        Must be overridden by subclasses; the base implementation only
        raises :class:`NotImplementedError`.

        """
        raise NotImplementedError

    @property
    @abstractmethod
    def output_resolutions(self: IOConfigABC) -> Any:
        """Abstract read-only property returning the output resolutions.

        Must be overridden by subclasses; the base implementation only
        raises :class:`NotImplementedError`.

        """
        raise NotImplementedError
class ModelABC(ABC, torch.nn.Module):
    """Abstract base class for models used in tiatoolbox.

    Subclasses must implement :meth:`forward` and :meth:`infer_batch`.
    Every instance carries a pre-processing and a post-processing
    function, exposed through the :attr:`preproc_func` and
    :attr:`postproc_func` properties. Both default to the static
    :meth:`preproc` / :meth:`postproc` methods, which return their input
    unchanged.

    """

    def __init__(self: ModelABC) -> None:
        """Initialize Abstract class ModelABC."""
        super().__init__()
        # Per-instance processing hooks; replaceable through the
        # `preproc_func` / `postproc_func` property setters.
        self._postproc = self.postproc
        self._preproc = self.preproc

    @abstractmethod
    # This is generic abc, else pylint will complain
    def forward(self: ModelABC, *args: Any, **kwargs: Any) -> None:
        """Torch method, this contains logic for using layers defined in init."""
        ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def infer_batch(
        model: torch.nn.Module,
        batch_data: np.ndarray,
        *,
        on_gpu: bool,
    ) -> None:
        """Run inference on an input batch.

        Contains logic for forward operation as well as I/O aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (np.ndarray):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        ...  # pragma: no cover

    @staticmethod
    def preproc(image: np.ndarray) -> np.ndarray:
        """Define the pre-processing of this class of model."""
        return image

    @staticmethod
    def postproc(image: np.ndarray) -> np.ndarray:
        """Define the post-processing of this class of model."""
        return image

    @staticmethod
    def _resolve_processing_func(
        func: Callable | None,
        default: Callable,
    ) -> Callable:
        """Return `default` when `func` is None, else the validated `func`.

        Raises:
            ValueError:
                If `func` is neither None nor callable.

        """
        if func is None:
            return default
        if not callable(func):
            msg = f"{func} is not callable!"
            raise ValueError(msg)
        return func

    @property
    def preproc_func(self: ModelABC) -> Callable:
        """Return the current pre-processing function of this instance."""
        return self._preproc

    @preproc_func.setter
    def preproc_func(self: ModelABC, func: Callable | None) -> None:
        """Set the pre-processing function for this instance.

        If `func=None`, the method will default to `self.preproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError:
                If `func` is neither None nor callable.

        Examples:
            >>> # expected usage
            >>> # `MyModel` is a concrete subclass of ModelABC
            >>> # `func` is a user defined function
            >>> model = MyModel()
            >>> model.preproc_func = func
            >>> transformed_img = model.preproc_func(img)

        """
        self._preproc = self._resolve_processing_func(func, self.preproc)

    @property
    def postproc_func(self: ModelABC) -> Callable:
        """Return the current post-processing function of this instance."""
        return self._postproc

    @postproc_func.setter
    def postproc_func(self: ModelABC, func: Callable | None) -> None:
        """Set the post-processing function for this instance of model.

        If `func=None`, the method will default to `self.postproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError:
                If `func` is neither None nor callable.

        Examples:
            >>> # expected usage
            >>> # `MyModel` is a concrete subclass of ModelABC
            >>> # `func` is a user defined function
            >>> model = MyModel()
            >>> model.postproc_func = func
            >>> transformed_img = model.postproc_func(img)

        """
        self._postproc = self._resolve_processing_func(func, self.postproc)

    def to(
        self: ModelABC,
        device: str | torch_device = "cpu",
    ) -> torch.nn.Module:
        """Transfer the model to cpu/gpu.

        Args:
            device (str or torch.device):
                Device to move the model to, e.g. "cpu" or "cuda".
                Default is "cpu".

        Returns:
            torch.nn.Module:
                The model after being moved to `device`. When `device`
                is a CUDA device and more than one GPU is available, the
                moved model is wrapped in :class:`torch.nn.DataParallel`,
                so the returned object is then not `self`.

        """
        device = torch_device(device)
        model = super().to(device)
        # If the target device is CUDA and more than one GPU is
        # available, spread the model across them with DataParallel.
        if device.type == "cuda" and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)  # pragma: no cover
        return model

    def load_weights_from_file(
        self: ModelABC,
        weights: str | Path,
    ) -> torch.nn.Module:
        """Load pretrained weights from a file into this model.

        The saved state dict is always mapped onto the CPU and loaded
        with ``strict=True``.

        Args:
            weights (str or Path):
                Path to the saved state dict. It is assumed to have been
                saved in single GPU mode (i.e. not from a DataParallel
                wrapper).

        Returns:
            torch.nn.Module:
                This model (`self`) with the pretrained weights loaded.

        """
        # NOTE(review): `torch.load` unpickles the file; only load weights
        # from trusted sources (consider `weights_only=True` where the
        # supported torch versions allow it).
        saved_state_dict = torch.load(weights, map_location="cpu")
        # With strict=True any key mismatch raises, so the result of
        # `load_state_dict` carries no information; return the model as
        # the signature and docstring promise.
        super().load_state_dict(saved_state_dict, strict=True)
        return self