"""Define Abstract Base Class for Models defined in tiatoolbox."""from__future__importannotationsfromabcimportABC,abstractmethodfromtypingimportTYPE_CHECKING,Any,Callableimporttorchimporttorch._dynamofromtorchimportdeviceastorch_devicetorch._dynamo.config.suppress_errors=True# skipcq: PYL-W0212 # noqa: SLF001ifTYPE_CHECKING:# pragma: no coverfrompathlibimportPathimportnumpyasnp
class IOConfigABC(ABC):
    """Abstract container for predictor input/output information.

    Concrete subclasses must provide both the ``output_resolutions`` and the
    ``input_resolutions`` properties; a subclass that leaves either one
    abstract cannot be instantiated.
    """

    @property
    @abstractmethod
    def output_resolutions(self: IOConfigABC) -> None:
        """Abstract property holding the output resolutions (must be overridden)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def input_resolutions(self: IOConfigABC) -> None:
        """Abstract property holding the input resolutions (must be overridden)."""
        raise NotImplementedError
def model_to(
    model: torch.nn.Module,
    device: str | torch.device = "cpu",
) -> torch.nn.Module:
    """Transfers model to specified device e.g., "cpu" or "cuda".

    The model is wrapped in :class:`torch.nn.DataParallel` only when the
    target device is a CUDA device and more than one GPU is available;
    otherwise the (unwrapped) model is simply moved to ``device``.

    Args:
        model (torch.nn.Module):
            PyTorch defined model.
        device (str or torch.device):
            Transfers model to the specified device. Default is "cpu".

    Returns:
        torch.nn.Module:
            The model after being moved to specified device.

    """
    # Normalise first so that strings ("cuda:0") and torch.device objects are
    # treated the same way; comparing a torch.device to the string "cpu" is
    # always unequal, which previously wrapped CPU models in DataParallel.
    device = torch.device(device)

    # DataParallel only helps when a batch can be split across several GPUs.
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)  # pragma: no cover

    return model.to(device)
class ModelABC(ABC, torch.nn.Module):
    """Abstract base class for models used in tiatoolbox."""

    def __init__(self: ModelABC) -> None:
        """Initialize Abstract class ModelABC."""
        super().__init__()
        # Start from the class-level default processing functions; they can be
        # replaced per instance via `preproc_func` / `postproc_func`.
        self._postproc = self.postproc
        self._preproc = self.preproc

    @abstractmethod
    # This is generic abc, else pylint will complain
    def forward(self: ModelABC, *args: Any, **kwargs: Any) -> Any:
        """Torch method, this contains logic for using layers defined in init."""
        ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def infer_batch(
        model: torch.nn.Module,
        batch_data: np.ndarray,
        device: str,
    ) -> Any:
        """Run inference on an input batch.

        Contains logic for forward operation as well as I/O aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (np.ndarray):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            device (str):
                Transfers model to the specified device. Default is "cpu".

        Returns:
            Predictions and other expected outputs (e.g. a dictionary); the
            exact structure depends on the network architecture.

        """
        ...  # pragma: no cover

    @staticmethod
    def preproc(image: np.ndarray) -> np.ndarray:
        """Define the pre-processing of this class of model."""
        return image

    @staticmethod
    def postproc(image: np.ndarray) -> np.ndarray:
        """Define the post-processing of this class of model."""
        return image

    @property
    def preproc_func(self: ModelABC) -> Callable:
        """Return the current pre-processing function of this instance."""
        return self._preproc

    @preproc_func.setter
    def preproc_func(self: ModelABC, func: Callable | None) -> None:
        """Set the pre-processing function for this instance.

        If `func=None`, the method will default to `self.preproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError:
                If `func` is neither `None` nor callable.

        Examples:
            >>> # expected usage
            >>> # `model` is an instance of a concrete subclass of ModelABC
            >>> # `func` is a user defined function
            >>> model.preproc_func = func
            >>> transformed_img = model.preproc_func(image)

        """
        if func is not None and not callable(func):
            msg = f"{func} is not callable!"
            raise ValueError(msg)

        if func is None:
            self._preproc = self.preproc
        else:
            self._preproc = func

    @property
    def postproc_func(self: ModelABC) -> Callable:
        """Return the current post-processing function of this instance."""
        return self._postproc

    @postproc_func.setter
    def postproc_func(self: ModelABC, func: Callable | None) -> None:
        """Set the post-processing function for this instance of model.

        If `func=None`, the method will default to `self.postproc`.
        Otherwise, `func` is expected to be callable.

        Raises:
            ValueError:
                If `func` is neither `None` nor callable.

        Examples:
            >>> # expected usage
            >>> # `model` is an instance of a concrete subclass of ModelABC
            >>> # `func` is a user defined function
            >>> model.postproc_func = func
            >>> transformed_img = model.postproc_func(image)

        """
        if func is not None and not callable(func):
            msg = f"{func} is not callable!"
            raise ValueError(msg)

        if func is None:
            self._postproc = self.postproc
        else:
            self._postproc = func

    def to(self: ModelABC, device: str | torch.device = "cpu") -> torch.nn.Module:
        """Transfers model to cpu/gpu.

        Args:
            device (str or torch.device):
                Transfers model to the specified device. Default is "cpu".

        Returns:
            torch.nn.Module:
                The model after being moved to cpu/gpu. When the target is a
                CUDA device and more than one GPU is available, the moved
                model is returned wrapped in :class:`torch.nn.DataParallel`.

        """
        device = torch_device(device)
        model = super().to(device)

        # If the target device is CUDA and more than one GPU is available,
        # use DataParallel to split batches across the GPUs.
        if device.type == "cuda" and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)  # pragma: no cover

        return model

    def load_weights_from_file(self: ModelABC, weights: str | Path) -> torch.nn.Module:
        """Helper function to load a torch model.

        Args:
            self (ModelABC):
                A torch model as :class:`ModelABC`.
            weights (str or Path):
                Path to pretrained weights.

        Returns:
            torch.nn.Module:
                This model (``self``) with the pretrained weights loaded.
                The weights are loaded onto the CPU first.

        Raises:
            RuntimeError:
                Raised by ``load_state_dict`` (``strict=True``) when the
                saved keys do not exactly match this model's keys.

        """
        # ! assume to be saved in single GPU mode, i.e. presumably without
        # the "module." key prefix added by DataParallel; strict loading
        # rejects mismatched keys.
        # always load on to the CPU
        # NOTE(review): torch.load unpickles the file; only load weights from
        # trusted sources (consider `weights_only=True`).
        saved_state_dict = torch.load(weights, map_location="cpu")
        # load_state_dict returns a (missing_keys, unexpected_keys) record,
        # not the module; return the model itself as documented.
        super().load_state_dict(saved_state_dict, strict=True)
        return self