# Source code for tiatoolbox.models.dataset.abc
# ***** BEGIN GPL LICENSE BLOCK *****
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
# All rights reserved.
# ***** END GPL LICENSE BLOCK *****
import os
import pathlib
from abc import ABC, abstractmethod
import numpy as np
import torch
from tiatoolbox.utils.misc import imread
class PatchDatasetABC(ABC, torch.utils.data.Dataset):
    """Defines abstract base class for patch dataset.

    Concrete subclasses must implement `__getitem__`. This base class only
    provides the shared state, input validation, image loading and the
    pre-processing hook.

    Attributes:
        inputs: Data source of the dataset. Initialised to an empty list;
            `_check_input_integrity` accepts a list of image paths, a list
            of image arrays, or a single 4-dimensional numerical array.
        labels: Labels associated with `inputs`. Initialised to an empty
            list and not otherwise used by this base class.
        data_is_npy_alike (bool): Set by `_check_input_integrity` to `True`
            when `inputs` holds in-memory arrays and to `False` when it
            holds paths.
        return_labels (bool, False): `__getitem__` will return both the img
            and its label. If `labels` is `None`, `None` is returned.
            NOTE(review): this attribute is not assigned by this class's
            `__init__`; presumably subclasses define it - confirm.
        preproc_func: Preprocessing function used to transform the input
            data. Defaults to `preproc` (the identity) and may be replaced
            by any callable through the property setter, e.g. a torchvision
            transform (https://pytorch.org/vision/stable/transforms.html).

    """

    def __init__(
        self,
    ):
        """Initialise an empty dataset using the default pre-processing."""
        super().__init__()
        self._preproc = self.preproc
        self.data_is_npy_alike = False
        self.inputs = []
        self.labels = []

    def _check_input_integrity(self, mode):
        """Perform check to make sure variables received during init are valid.

        These checks include:
            - Input is of a singular data type, such as a list of paths.
            - If it is list of images, all images are of the same height
              and width.

        Args:
            mode (str): `"patch"` triggers the full image/path validation.
                Any other value only checks that `inputs` is a list or a
                numpy array (treated as patch coordinates).

        Raises:
            ValueError: If `inputs` does not satisfy the checks for `mode`.

        """
        if mode != "patch":
            if not isinstance(self.inputs, (list, np.ndarray)):
                raise ValueError("inputs should be a list of patch coordinates")
            return

        self.data_is_npy_alike = False

        # If input is a list - can contain a list of images or a list of
        # image paths.
        if isinstance(self.inputs, list):
            # An empty list trivially satisfies both `all(...)` checks below
            # and used to fail later inside numpy with an unrelated message;
            # reject it explicitly instead.
            if not self.inputs:
                raise ValueError(
                    "Input must be a non-empty list/array of images "
                    "or a non-empty list of valid image paths."
                )

            is_all_paths = all(
                isinstance(v, (pathlib.Path, str)) for v in self.inputs
            )
            is_all_npys = all(isinstance(v, np.ndarray) for v in self.inputs)
            if not (is_all_paths or is_all_npys):
                raise ValueError(
                    "Input must be either a list/array of images "
                    "or a list of valid image paths."
                )

            if is_all_paths:
                if any(not os.path.exists(v) for v in self.inputs):
                    # At least one of the paths is invalid.
                    raise ValueError(
                        "Input must be either a list/array of images "
                        "or a list of valid image paths."
                    )
                # Preload every image once as a sanity check on its shape.
                shapes = [self.load_img(v).shape for v in self.inputs]
                self.data_is_npy_alike = False
            else:
                shapes = [v.shape for v in self.inputs]
                self.data_is_npy_alike = True

            if any(len(v) != 3 for v in shapes):
                raise ValueError("Each sample must be an array of the form HWC.")

            # Every sample must share the shape of the first one.
            reference_shape = tuple(shapes[0])
            if any(tuple(v) != reference_shape for v in shapes):
                raise ValueError("Images must have the same dimensions.")

        # If input is a numpy array
        elif isinstance(self.inputs, np.ndarray):
            # Check that input array is numerical
            if not np.issubdtype(self.inputs.dtype, np.number):
                # ndarray of mixed data types
                raise ValueError("Provided input array is non-numerical.")
            # N H W C | N C H W
            if len(self.inputs.shape) != 4:
                raise ValueError(
                    "Input must be an array of images of the form NHWC. This can "
                    "be achieved by converting a list of images to a numpy array. "
                    " eg., np.array([img1, img2])."
                )
            self.data_is_npy_alike = True

        else:
            raise ValueError(
                "Input must be either a list/array of images "
                "or a list of valid paths to image."
            )

    @staticmethod
    def load_img(path):
        """Load an image from a provided path.

        Args:
            path (str): Path to an image file.

        Returns:
            The value returned by `imread(path, as_uint8=False)`.

        Raises:
            ValueError: If the file suffix is not one of `.npy`, `.jpg`,
                `.jpeg`, `.tif`, `.tiff` or `.png`. The comparison is
                case-sensitive.

        """
        path = pathlib.Path(path)
        if path.suffix not in (".npy", ".jpg", ".jpeg", ".tif", ".tiff", ".png"):
            raise ValueError(f"Can not load data of `{path.suffix}`")
        return imread(path, as_uint8=False)

    @staticmethod
    def preproc(image):
        """Define the pre-processing of this class of loader.

        The default implementation returns `image` unchanged.

        """
        return image

    @property
    def preproc_func(self):
        """Return the current pre-processing function of this instance.

        The returned function is expected to behave as follows:

        >>> transformed_img = func(img)

        """
        return self._preproc

    @preproc_func.setter
    def preproc_func(self, func):
        """Set the pre-processing function for this instance.

        If `func=None`, the method will default to `self.preproc`. Otherwise,
        `func` is expected to be callable and behave as follows:

        >>> transformed_img = func(img)

        Raises:
            ValueError: If `func` is neither `None` nor callable.

        """
        if func is None:
            self._preproc = self.preproc
        elif callable(func):
            self._preproc = func
        else:
            raise ValueError(f"{func} is not callable!")

    def __len__(self):
        """Return the number of entries in `inputs`."""
        return len(self.inputs)

    @abstractmethod
    def __getitem__(self, idx):
        """Return the sample at position `idx`; implemented by subclasses."""
        ...  # pragma: no cover