Source code for tiatoolbox.models.dataset.dataset_abc
"""Define dataset abstract classes."""

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Union

if TYPE_CHECKING:  # pragma: no cover
    from collections.abc import Iterable

    try:
        from typing import TypeGuard
    except ImportError:
        from typing_extensions import TypeGuard  # to support python <3.10

import numpy as np
import torch

from tiatoolbox.utils import imread

# Accepted forms of dataset `inputs`: either a list whose items are image
# paths (`str` / `Path`) or image arrays, or a single numpy array.
input_type = Union[list[Union[str, Path, np.ndarray]], np.ndarray]
class PatchDatasetABC(ABC, torch.utils.data.Dataset):
    """Define abstract base class for patch dataset.

    Subclasses must implement :meth:`__getitem__`. The base class stores
    the inputs and labels, validates the inputs, and holds a replaceable
    pre-processing function.

    Attributes:
        inputs:
            Either a list of image paths / image arrays or a single
            numpy array. Initialized to an empty list.
        labels:
            Labels associated with `inputs`. Initialized to an empty
            list.
        data_is_npy_alike (bool):
            Set to `True` by :meth:`_check_input_integrity` when the
            inputs are numpy arrays rather than paths.

    """

    inputs: input_type
    labels: list[int] | np.ndarray

    def __init__(
        self: PatchDatasetABC,
    ) -> None:
        """Initialize :class:`PatchDatasetABC`."""
        super().__init__()
        # Default to the class-level pre-processing until a custom
        # function is assigned through `preproc_func`.
        self._preproc = self.preproc
        self.data_is_npy_alike = False
        self.inputs = []
        self.labels = []

    @staticmethod
    def _check_shape_integrity(shapes: list | np.ndarray) -> None:
        """Check that all sample shapes are three-dimensional and equal.

        Args:
            shapes (list or np.ndarray):
                One shape (a sequence of dimension sizes) per sample.

        Raises:
            ValueError:
                If `shapes` is empty, if any shape does not have exactly
                three dimensions, or if the shapes are not all identical.

        """
        # `shapes` may be a numpy array, so test its length explicitly
        # rather than relying on truthiness.
        if len(shapes) == 0:
            # Fail with a clear message instead of the opaque
            # "zero-size array to reduction operation" error that a
            # NumPy reduction over an empty array would raise.
            msg = "At least one input sample is required."
            raise ValueError(msg)

        if any(len(v) != 3 for v in shapes):  # noqa: PLR2004
            msg = "Each sample must be an array of the form HWC."
            raise ValueError(msg)

        # Every shape has length 3 here, so this is an (N, 3) array and
        # each row can be compared directly against the first one.
        shape_array = np.asarray(shapes)
        if (shape_array != shape_array[0]).any():
            msg = "Images must have the same dimensions."
            raise ValueError(msg)

    @staticmethod
    def _are_paths(inputs: input_type) -> TypeGuard[Iterable[Path]]:
        """TypeGuard to check that input array contains only paths.

        Both `str` and `Path` items count as paths. An empty input
        yields `True`, because `all` over no items is `True`.

        """
        return all(isinstance(v, (Path, str)) for v in inputs)

    @staticmethod
    def _are_npy_like(inputs: input_type) -> TypeGuard[Iterable[np.ndarray]]:
        """TypeGuard to check that input array contains only np.ndarray."""
        return all(isinstance(v, np.ndarray) for v in inputs)

    def _check_input_integrity(self: PatchDatasetABC, mode: str) -> None:
        """Check that variables received during init are valid.

        In `"patch"` mode these checks include:
            - Input is of a singular data type, such as a list of paths.
            - If it is a list of paths, every path exists.
            - All images have three dimensions and the same shape.
            - If the input is a numpy array, its dtype is numerical.

        For any other mode, only the container type of `inputs` is
        checked.

        Args:
            mode (str):
                `"patch"` enables the image checks listed above.

        Raises:
            ValueError:
                If any of the checks above fails.

        """
        if mode == "patch":
            # Reset first; only the numpy branches below switch it on.
            self.data_is_npy_alike = False

            msg = (
                "Input must be either a list/array of images "
                "or a list of valid image paths."
            )

            # When a list of paths is provided
            if self._are_paths(self.inputs):
                if any(not Path(v).exists() for v in self.inputs):
                    # at least one of the paths is invalid
                    raise ValueError(msg)
                # Preload every image so its shape can be validated.
                shapes = [self.load_img(v).shape for v in self.inputs]
            elif self._are_npy_like(self.inputs):
                shapes = [v.shape for v in self.inputs]
                self.data_is_npy_alike = True
            else:
                raise ValueError(msg)

            self._check_shape_integrity(shapes)

            # If input is a numpy array
            if isinstance(self.inputs, np.ndarray):
                # Check that input array is numerical
                if not np.issubdtype(self.inputs.dtype, np.number):
                    # ndarray of mixed data types
                    msg = "Provided input array is non-numerical."
                    raise ValueError(msg)
                self.data_is_npy_alike = True

        elif not isinstance(self.inputs, (list, np.ndarray)):
            msg = "`inputs` should be a list of patch coordinates."
            raise ValueError(msg)

    @staticmethod
    def load_img(path: str | Path) -> np.ndarray:
        """Load an image from a provided path.

        Args:
            path (str or Path):
                Path to an image file.

        Returns:
            :class:`numpy.ndarray`:
                Image as a numpy array.

        Raises:
            ValueError:
                If the file suffix is not one of the supported ones.

        """
        path = Path(path)

        # NOTE(review): the comparison is case-sensitive, so a suffix
        # such as `.PNG` is rejected - confirm whether that is intended.
        if path.suffix not in (".npy", ".jpg", ".jpeg", ".tif", ".tiff", ".png"):
            msg = f"Cannot load image data from `{path.suffix}` files."
            raise ValueError(msg)

        return imread(path, as_uint8=False)

    @staticmethod
    def preproc(image: np.ndarray) -> np.ndarray:
        """Define the pre-processing of this class of loader.

        The base implementation returns `image` unchanged.

        """
        return image

    @property
    def preproc_func(self: PatchDatasetABC) -> Callable:
        """Return the current pre-processing function of this instance.

        The returned function is expected to behave as follows:

        >>> transformed_img = func(img)

        """
        return self._preproc

    @preproc_func.setter
    def preproc_func(self: PatchDatasetABC, func: Callable | None) -> None:
        """Set the pre-processing function for this instance.

        If `func=None`, the method will default to `self.preproc`.
        Otherwise, `func` is expected to be callable and behaves as
        follows:

        >>> transformed_img = func(img)

        Raises:
            ValueError:
                If `func` is neither `None` nor callable.

        """
        if func is None:
            self._preproc = self.preproc
        elif callable(func):
            self._preproc = func
        else:
            msg = f"{func} is not callable!"
            raise ValueError(msg)

    def __len__(self: PatchDatasetABC) -> int:
        """Return the number of items held in `inputs`."""
        return len(self.inputs)

    @abstractmethod
    def __getitem__(self: PatchDatasetABC, idx: int) -> None:
        """Get an item from the dataset."""
        ...  # pragma: no cover