Source code for tiatoolbox.tools.patchextraction
"""This file defines patch extraction methods for deep learning models."""
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Tuple, Union
import numpy as np
from pandas import DataFrame
from tiatoolbox.utils import misc
from tiatoolbox.utils.exceptions import MethodNotSupported
from tiatoolbox.wsicore import wsireader
class PatchExtractorABC(ABC):
    """Interface that every patch extractor in tiatoolbox has to implement.

    A concrete extractor must act as an iterator (``__iter__`` and
    ``__next__``) and must be indexable with an integer
    (``__getitem__``). All three methods are abstract, so this class
    cannot be instantiated directly.

    """

    @abstractmethod
    def __getitem__(self, item: int):
        """Return the patch stored at position `item`."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self):
        """Return the iterator object."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self):
        """Return the next patch of the iteration."""
        raise NotImplementedError
[docs]class PatchExtractor(PatchExtractorABC):
"""Class for extracting and merging patches in standard and whole-slide images.
Args:
input_img(str, pathlib.Path, :class:`numpy.ndarray`):
Input image for patch extraction.
patch_size(int or tuple(int)):
Patch size tuple (width, height).
input_mask(str, pathlib.Path, :class:`numpy.ndarray`, or :obj:`WSIReader`):
Input mask that is used for position filtering when
extracting patches i.e., patches will only be extracted
based on the highlighted regions in the input_mask.
input_mask can be either path to the mask, a numpy array,
:class:`VirtualWSIReader`, or one of 'otsu' and
'morphological' options. In case of 'otsu' or
'morphological', a tissue mask is generated for the
input_image using tiatoolbox :class:`TissueMasker`
functionality.
resolution (int or float or tuple of float):
Resolution at which to read the image, default = 0. Either a
single number or a sequence of two numbers for x and y are
valid. This value is in terms of the corresponding units.
For example: resolution=0.5 and units="mpp" will read the
slide at 0.5 microns per-pixel, and resolution=3,
units="level" will read at level at pyramid level /
resolution layer 3.
units (str):
Units of resolution, default = "level". Supported units are:
microns per pixel (mpp), objective power (power), pyramid /
resolution level (level), Only pyramid / resolution levels
(level) embedded in the whole slide image are supported.
pad_mode (str):
Method for padding at edges of the WSI. Default to
'constant'. See :func:`numpy.pad` for more information.
pad_constant_values (int or tuple(int)):
Values to use with constant padding. Defaults to 0. See
:func:`numpy.pad` for more.
within_bound (bool):
Whether to restrict extraction to patches that lie entirely
inside the input image. If False, patches at the margins are
still extracted and padded according to `pad_mode` and
`pad_constant_values`. If True, patches whose bounds exceed
the mother image dimensions are neglected. Default is False.
min_mask_ratio (float):
Fraction (between 0 and 1) of the patch area that the positive
mask has to exceed for the patch to be included. Defaults to 0.
Attributes:
wsi(WSIReader):
Input image for patch extraction of type :obj:`WSIReader`.
patch_size(tuple(int)):
Patch size tuple (width, height).
resolution(tuple(int)):
Resolution at which to read the image.
units (str):
Units of resolution.
n (int):
Current state of the iterator.
locations_df (pd.DataFrame):
A table containing location and/or type of patches in
`(x_start, y_start, class)` format.
coordinate_list (:class:`numpy.ndarray`):
An array containing coordinates of patches in `(x_start,
y_start, x_end, y_end)` format to be used for
`slidingwindow` patch extraction.
pad_mode (str):
Method for padding at edges of the WSI. See
:func:`numpy.pad` for more information.
pad_constant_values (int or tuple(int)):
Values to use with constant padding. Defaults to 0. See
:func:`numpy.pad` for more.
stride (tuple(int)):
Stride in (x, y) direction for patch extraction. Not used
for :obj:`PointsPatchExtractor`
min_mask_ratio (float):
Only patches with positive area percentage above this value are included
"""
def __init__(
    self,
    input_img: Union[str, Path, np.ndarray],
    patch_size: Union[int, Tuple[int, int]],
    input_mask: Union[str, Path, np.ndarray, wsireader.WSIReader] = None,
    resolution: Union[int, float, Tuple[float, float]] = 0,
    units: str = "level",
    pad_mode: str = "constant",
    pad_constant_values: Union[int, Tuple[int, int]] = 0,
    within_bound: bool = False,
    min_mask_ratio: float = 0,
):
    """Store the extraction settings, open the image and prepare the mask.

    The image is opened through `wsireader.WSIReader.open`. The
    location table, coordinate list and stride start as None and the
    iterator state `n` starts at 0.

    """
    # A scalar patch size describes a square patch.
    is_pair = isinstance(patch_size, (tuple, list))
    self.patch_size = (
        (int(patch_size[0]), int(patch_size[1]))
        if is_pair
        else (int(patch_size), int(patch_size))
    )

    # Reading / padding configuration.
    self.resolution = resolution
    self.units = units
    self.pad_mode = pad_mode
    self.pad_constant_values = pad_constant_values
    self.within_bound = within_bound
    self.min_mask_ratio = min_mask_ratio

    # Iterator state and tables filled in later.
    self.n = 0
    self.locations_df = None
    self.coordinate_list = None
    self.stride = None

    self.wsi = wsireader.WSIReader.open(input_img=input_img)

    # Resolve the mask. The four cases below are mutually exclusive.
    wants_auto_mask = isinstance(input_mask, str) and input_mask in {
        "otsu",
        "morphological",
    }
    if input_mask is None:
        mask = None
    elif wants_auto_mask:
        # No automatic tissue mask is generated when the image itself is
        # backed by a VirtualWSIReader.
        mask = (
            None
            if isinstance(self.wsi, wsireader.VirtualWSIReader)
            else self.wsi.tissue_mask(
                method=input_mask, resolution=1.25, units="power"
            )
        )
    elif isinstance(input_mask, wsireader.VirtualWSIReader):
        mask = input_mask
    else:
        # Anything else (path, array, ...) is wrapped in a boolean
        # virtual reader that shares the image's metadata.
        mask = wsireader.VirtualWSIReader(
            input_mask, info=self.wsi.info, mode="bool"
        )
    self.mask = mask
def __iter__(self):
self.n = 0
return self
def __next__(self):
n = self.n
if n >= self.locations_df.shape[0]:
raise StopIteration
self.n = n + 1
return self[n]
def __getitem__(self, item: int):
if not isinstance(item, int):
raise TypeError("Index should be an integer.")
if item >= self.locations_df.shape[0]:
raise IndexError
x = self.locations_df["x"][item]
y = self.locations_df["y"][item]
return self.wsi.read_rect(
location=(int(x), int(y)),
size=self.patch_size,
resolution=self.resolution,
units=self.units,
pad_mode=self.pad_mode,
pad_constant_values=self.pad_constant_values,
coord_space="resolution",
)
def _generate_location_df(self):
    """Build the patch location table from the slide dimensions.

    The slide dimensions are requested at `self.resolution` and
    `self.units`, tiled with `get_coordinates` and, when a mask is
    available, reduced with `filter_coordinates_fast`. The `(x_start,
    y_start)` columns of the remaining boxes are stored in
    `self.locations_df`.

    Returns:
        The extractor itself.

    """
    dims = self.wsi.slide_dimensions(self.resolution, self.units)
    patch_w, patch_h = self.patch_size[0], self.patch_size[1]
    stride_x, stride_y = self.stride[0], self.stride[1]

    self.coordinate_list = self.get_coordinates(
        image_shape=(dims[0], dims[1]),
        patch_input_shape=(patch_w, patch_h),
        stride_shape=(stride_x, stride_y),
        input_within_bound=self.within_bound,
    )

    if self.mask is not None:
        # Express the requested resolution in every unit the reader can
        # provide and keep the first one that has a value.
        all_units = self.wsi.convert_resolution_units(
            input_res=self.resolution,
            input_unit=self.units,
        )
        usable = [
            (unit, value) for unit, value in all_units.items() if value is not None
        ]
        chosen_unit, chosen_resolution = usable[0]

        keep = self.filter_coordinates_fast(
            self.mask,
            self.coordinate_list,
            coordinate_resolution=chosen_resolution,
            coordinate_units=chosen_unit,
            min_mask_ratio=self.min_mask_ratio,
        )
        self.coordinate_list = self.coordinate_list[keep]

        if len(self.coordinate_list) == 0:
            warnings.warn(
                "No candidate coordinates left after "
                "filtering by `input_mask` positions."
            )

    # Only the top-left corner of each box goes into the table.
    top_left = self.coordinate_list[:, :2]
    self.locations_df = misc.read_locations(input_table=np.array(top_left))

    return self
@staticmethod
def filter_coordinates_fast(
    mask_reader: wsireader.VirtualWSIReader,
    coordinates_list: np.ndarray,
    coordinate_resolution: float,
    coordinate_units: str,
    mask_resolution: float = None,
    min_mask_ratio: float = 0,
):
    """Flag which bounding boxes overlap the mask, using a low-res thumbnail.

    The mask is read once as a thumbnail at `mask_resolution`. Every
    box is rescaled to that thumbnail, clipped to its extent and
    checked against the positive pixels it covers.

    Args:
        mask_reader (:class:`.VirtualReader`):
            A virtual pyramidal reader of the mask related to the
            WSI from which patches are to be extracted.
        coordinates_list (ndarray of integer type):
            Bounding boxes of shape (N, 4) in `[start_x, start_y,
            end_x, end_y]` format, expressed at
            `coordinate_resolution` / `coordinate_units`.
        coordinate_resolution (float):
            Resolution value at which `coordinates_list` is
            generated. A single number is used for both axes.
        coordinate_units (str):
            Resolution unit at which `coordinates_list` is
            generated.
        mask_resolution (float):
            Resolution, in `coordinate_units`, at which the mask
            thumbnail is read. If None, a default is looked up for
            "mpp" (8), "power" (1.25) or "baseline" (0.03125).
        min_mask_ratio (float):
            Value between 0 and 1. A box that is not fully positive
            is kept only when its positive area is greater than this
            fraction of the box area. Defaults to 0.

    Returns:
        :class:`numpy.ndarray`:
            One flag per box, True where the box is valid.

    Raises:
        ValueError:
            If `mask_reader` is not a `VirtualWSIReader`, if
            `coordinates_list` is not an integer ndarray with a last
            dimension of 4, or if `min_mask_ratio` is outside [0, 1].

    """
    if not isinstance(mask_reader, wsireader.VirtualWSIReader):
        raise ValueError("`mask_reader` should be wsireader.VirtualWSIReader.")

    is_integer_array = isinstance(coordinates_list, np.ndarray) and np.issubdtype(
        coordinates_list.dtype, np.integer
    )
    if not is_integer_array:
        raise ValueError("`coordinates_list` should be ndarray of integer type.")

    if coordinates_list.shape[-1] != 4:
        raise ValueError("`coordinates_list` must be of shape [N, 4].")

    if isinstance(coordinate_resolution, (int, float)):
        coordinate_resolution = [coordinate_resolution, coordinate_resolution]

    if not 0 <= min_mask_ratio <= 1:
        raise ValueError("`min_mask_ratio` must be between 0 and 1.")

    # Fall back to a per-unit default when no mask resolution is given.
    if mask_resolution is None:
        default_resolutions = {"mpp": 8, "power": 1.25, "baseline": 0.03125}
        mask_resolution = default_resolutions[coordinate_units]

    tissue_mask = mask_reader.slide_thumbnail(
        resolution=mask_resolution, units=coordinate_units
    )
    mask_height, mask_width = tissue_mask.shape[0], tissue_mask.shape[1]

    # Bring the boxes to the thumbnail resolution and clip them to it.
    x_cols, y_cols = [0, 2], [1, 3]
    boxes = coordinates_list.copy().astype(np.float32)
    boxes[:, x_cols] *= coordinate_resolution[0] / mask_resolution
    boxes[:, x_cols] = np.clip(boxes[:, x_cols], 0, mask_width)
    boxes[:, y_cols] *= coordinate_resolution[1] / mask_resolution
    boxes[:, y_cols] = np.clip(boxes[:, y_cols], 0, mask_height)
    boxes = np.int32(boxes)

    def is_valid(box):
        """Check one scaled box against the thumbnail."""
        region = tissue_mask[box[1] : box[3], box[0] : box[2]]
        patch_area = np.prod(region.shape)
        pos_area = np.count_nonzero(region)
        # An empty region or one without positive pixels never qualifies.
        if pos_area <= 0 or patch_area <= 0:
            return False
        return bool(
            pos_area == patch_area or pos_area > patch_area * min_mask_ratio
        )

    return np.array([is_valid(box) for box in boxes])
@staticmethod
def filter_coordinates(
    mask_reader: wsireader.VirtualWSIReader,
    coordinates_list: np.ndarray,
    func: Callable = None,
    resolution: float = None,
    units: str = None,
):
    """Flag which coordinates are valid for mask-based patch extraction.

    Each coordinate is passed, together with `mask_reader`, to a
    validator. The built-in validator reads the box from the mask and
    accepts it when at least one value is positive.

    Args:
        mask_reader (:class:`.VirtualReader`):
            A virtual pyramidal reader of the mask related to the
            WSI from which patches are to be extracted.
        coordinates_list (ndarray of integer type):
            Coordinates of shape (N, K), expressed at `resolution`
            and `units`. With the built-in validator K must be 4,
            i.e. `[start_x, start_y, end_x, end_y]` boxes.
        func:
            Optional validator called as `func(reader, coordinate)`
            and returning True or False. If None, the built-in
            validator is used.
        resolution (float):
            Resolution value at which `coordinates_list` is
            generated. If None, the built-in validator reads at
            `reader.info.mpp`.
        units (str):
            Resolution unit at which `coordinates_list` is
            generated. If None, the built-in validator uses "mpp".

    Returns:
        :class:`numpy.ndarray`:
            One flag per coordinate, True where it is valid.

    Raises:
        ValueError:
            If `mask_reader` is not a `VirtualWSIReader`, if
            `coordinates_list` is not an integer ndarray, or if the
            built-in validator is used with a last dimension other
            than 4.

    """
    if not isinstance(mask_reader, wsireader.VirtualWSIReader):
        raise ValueError("`mask_reader` should be wsireader.VirtualWSIReader.")

    is_integer_array = isinstance(coordinates_list, np.ndarray) and np.issubdtype(
        coordinates_list.dtype, np.integer
    )
    if not is_integer_array:
        raise ValueError("`coordinates_list` should be ndarray of integer type.")

    use_default = func is None
    if use_default and coordinates_list.shape[-1] != 4:
        raise ValueError(
            f"Default `func` does not support "
            f"`coordinates_list` of shape {coordinates_list.shape}."
        )

    def box_contains_mask(reader: wsireader.VirtualWSIReader, coord: np.ndarray):
        """Accept `coord` when its box holds any positive mask value."""
        read_resolution = reader.info.mpp if resolution is None else resolution
        read_units = "mpp" if units is None else units
        roi = reader.read_bounds(
            coord,
            resolution=read_resolution,
            units=read_units,
            interpolation="nearest",
            coord_space="resolution",
        )
        return np.any(roi > 0)

    selector = box_contains_mask if use_default else func
    return np.array([selector(mask_reader, coord) for coord in coordinates_list])
[docs] @staticmethod
def get_coordinates(
image_shape: Union[Tuple[int, int], np.ndarray] = None,
patch_input_shape: Union[Tuple[int, int], np.ndarray] = None,
patch_output_shape: Union[Tuple[int, int], np.ndarray] = None,
stride_shape: Union[Tuple[int, int], np.ndarray] = None,
input_within_bound: bool = False,
output_within_bound: bool = False,
):
"""Calculate patch tiling coordinates.
Args:
image_shape (tuple (int, int) or :class:`numpy.ndarray`):
This argument specifies the shape of mother image (the
image we want to extract patches from) at requested
`resolution` and `units` and it is expected to be in
(width, height) format.
patch_input_shape (tuple (int, int) or :class:`numpy.ndarray`):
Specifies the input shape of requested patches to be
extracted from mother image at desired `resolution` and
`units`. This argument is also expected to be in (width,
height) format.
patch_output_shape (tuple (int, int) or :class:`numpy.ndarray`):
Specifies the output shape of requested patches to be
extracted from mother image at desired `resolution` and
`units`. This argument is also expected to be in (width,
height) format. If this is not provided,
`patch_output_shape` will be the same as
`patch_input_shape`.
stride_shape (tuple (int, int) or :class:`numpy.ndarray`):
The stride that is used to calculate the patch location
during the patch extraction. If `patch_output_shape` is
provided, next stride location will base on the output
rather than the input.
input_within_bound (bool):
Whether to include the patches where their `input`
location exceed the margins of mother image. If `True`,
the patches with input location exceeds the
`image_shape` would be neglected. Otherwise, those
patches would be extracted with `Reader` function and
appropriate padding.
output_within_bound (bool):
Whether to include the patches where their `output`
location exceed the margins of mother image. If `True`,
the patches with output location exceeds the
`image_shape` would be neglected. Otherwise, those
patches would be extracted with `Reader` function and
appropriate padding.
Return:
coord_list:
A list of coordinates in `[start_x, start_y, end_x,
end_y]` format to be used for patch extraction.
"""
return_output_bound = patch_output_shape is not None
image_shape = np.array(image_shape)
patch_input_shape = np.array(patch_input_shape)
if patch_output_shape is None:
output_within_bound = False
patch_output_shape = patch_input_shape
patch_output_shape = np.array(patch_output_shape)
stride_shape = np.array(stride_shape)
def validate_shape(shape):
"""Tests if the shape is valid for an image."""
return (
not np.issubdtype(shape.dtype, np.integer)
or np.size(shape) > 2
or np.any(shape < 0)
)
if validate_shape(image_shape):
raise ValueError(f"Invalid `image_shape` value {image_shape}.")
if validate_shape(patch_input_shape):
raise ValueError(f"Invalid `patch_input_shape` value {patch_input_shape}.")
if validate_shape(patch_output_shape):
raise ValueError(
f"Invalid `patch_output_shape` value {patch_output_shape}."
)
if validate_shape(stride_shape):
raise ValueError(f"Invalid `stride_shape` value {stride_shape}.")
if np.any(patch_input_shape < patch_output_shape):
raise ValueError(
(
f"`patch_input_shape` must larger than `patch_output_shape`"
f" {patch_input_shape} must > {patch_output_shape}."
)
)
if np.any(stride_shape < 1):
raise ValueError(f"`stride_shape` value {stride_shape} must > 1.")
def flat_mesh_grid_coord(x, y):
"""Helper function to obtain coordinate grid."""
x, y = np.meshgrid(x, y)
return np.stack([x.flatten(), y.flatten()], axis=-1)
output_x_end = (
np.ceil(image_shape[0] / patch_output_shape[0]) * patch_output_shape[0]
)
output_x_list = np.arange(0, int(output_x_end), stride_shape[0])
output_y_end = (
np.ceil(image_shape[1] / patch_output_shape[1]) * patch_output_shape[1]
)
output_y_list = np.arange(0, int(output_y_end), stride_shape[1])
output_tl_list = flat_mesh_grid_coord(output_x_list, output_y_list)
output_br_list = output_tl_list + patch_output_shape[None]
io_diff = patch_input_shape - patch_output_shape
input_tl_list = output_tl_list - (io_diff // 2)[None]
input_br_list = input_tl_list + patch_input_shape[None]
sel = np.zeros(input_tl_list.shape[0], dtype=bool)
if output_within_bound:
sel |= np.any(output_br_list > image_shape[None], axis=1)
if input_within_bound:
sel |= np.any(input_br_list > image_shape[None], axis=1)
sel |= np.any(input_tl_list < 0, axis=1)
####
input_bound_list = np.concatenate(
[input_tl_list[~sel], input_br_list[~sel]], axis=-1
)
output_bound_list = np.concatenate(
[output_tl_list[~sel], output_br_list[~sel]], axis=-1
)
if return_output_bound:
return input_bound_list, output_bound_list
return input_bound_list
class SlidingWindowPatchExtractor(PatchExtractor):
    """Extract fixed-size patches by sliding a window over the image.

    Args:
        input_img(str, pathlib.Path, :class:`numpy.ndarray`):
            Input image for patch extraction.
        patch_size(int or tuple(int)):
            Patch size tuple (width, height). A single number is used
            for both width and height.
        input_mask(str, pathlib.Path, :class:`numpy.ndarray`, or :obj:`WSIReader`):
            Mask used to filter patch positions, so that patches are
            only extracted from the highlighted regions of the mask.
            It can be a path to the mask, a numpy array, a
            :class:`VirtualWSIReader`, or one of the 'otsu' and
            'morphological' options, in which case a tissue mask is
            generated from the input image.
        resolution (int or float or tuple of float):
            Resolution at which to read the image, default = 0. Either
            a single number or a sequence of two numbers for x and y.
            The value is interpreted in terms of `units`, e.g.
            resolution=0.5 with units="mpp" reads the slide at 0.5
            microns per-pixel and resolution=3 with units="level"
            reads at pyramid / resolution level 3.
        units (str):
            Units of resolution, default = "level". Supported units
            are microns per pixel (mpp), objective power (power) and
            pyramid / resolution level (level). Only levels embedded
            in the whole slide image are supported.
        stride(int or tuple(int)):
            Stride in (x, y) direction for patch extraction. A single
            number is used for both directions. Defaults to
            `patch_size`.
        pad_mode (str):
            Method for padding at edges of the WSI. Default to
            'constant'. See :func:`numpy.pad` for more information.
        pad_constant_values (int or tuple(int)):
            Values to use with constant padding. Defaults to 0. See
            :func:`numpy.pad` for more information.
        within_bound (bool):
            Whether to restrict extraction to patches that lie
            entirely inside the input image. If False, patches at the
            margins are still extracted and padded according to
            `pad_mode` and `pad_constant_values`. If True, patches
            whose bounds exceed the image dimensions are neglected.
            Default is False.
        min_mask_ratio (float):
            Fraction (between 0 and 1) of the patch area that the
            positive mask has to exceed for the patch to be included.
            Defaults to 0.

    Attributes:
        stride(tuple(int)):
            Stride in (x, y) direction for patch extraction.

    """

    def __init__(
        self,
        input_img: Union[str, Path, np.ndarray],
        patch_size: Union[int, Tuple[int, int]],
        input_mask: Union[str, Path, np.ndarray, wsireader.WSIReader] = None,
        resolution: Union[int, float, Tuple[float, float]] = 0,
        units: str = "level",
        stride: Union[int, Tuple[int, int]] = None,
        pad_mode: str = "constant",
        pad_constant_values: Union[int, Tuple[int, int]] = 0,
        within_bound: bool = False,
        min_mask_ratio: float = 0,
    ):
        """Initialise the base extractor, set the stride and tile the image."""
        super().__init__(
            input_img=input_img,
            input_mask=input_mask,
            patch_size=patch_size,
            resolution=resolution,
            units=units,
            pad_mode=pad_mode,
            pad_constant_values=pad_constant_values,
            within_bound=within_bound,
            min_mask_ratio=min_mask_ratio,
        )

        # Without an explicit stride, consecutive patches do not overlap.
        if stride is None:
            self.stride = self.patch_size
        elif isinstance(stride, (tuple, list)):
            self.stride = (int(stride[0]), int(stride[1]))
        else:
            self.stride = (int(stride), int(stride))

        self._generate_location_df()
class PointsPatchExtractor(PatchExtractor):
    """Extract patches centred on user-specified points.

    Args:
        input_img(str, pathlib.Path, :class:`numpy.ndarray`):
            Input image for patch extraction.
        locations_list(ndarray, pd.DataFrame, str, pathlib.Path):
            Contains location and/or type of patch. This can be path to
            csv, npy or json files. Input can also be a
            :class:`numpy.ndarray` or :class:`pandas.DataFrame`. NOTE:
            value of location $(x,y)$ is expected to be based on the
            specified `resolution` and `units` (not the `'baseline'`
            resolution).
        patch_size(int or tuple(int)):
            Patch size tuple (width, height).
        resolution (int or float or tuple of float):
            Resolution at which to read the image, default = 0. Either a
            single number or a sequence of two numbers for x and y are
            valid. This value is in terms of the corresponding units.
            For example: resolution=0.5 and units="mpp" will read the
            slide at 0.5 microns per-pixel, and resolution=3,
            units="level" will read at pyramid / resolution level 3.
        units (str):
            The units of resolution, default = "level". Supported units
            are: microns per pixel (mpp), objective power (power),
            pyramid / resolution level (level). Only pyramid /
            resolution levels (level) embedded in the whole slide image
            are supported.
        pad_mode (str):
            Method for padding at edges of the WSI. Default to
            'constant'. See :func:`numpy.pad` for more information.
        pad_constant_values (int or tuple(int)):
            Values to use with constant padding. Defaults to 0. See
            :func:`numpy.pad` for more.
        within_bound (bool):
            Forwarded to :class:`PatchExtractor` and stored as
            `self.within_bound`. This class does not use it to discard
            any of the supplied points. Default is False.

    """

    def __init__(
        self,
        input_img: Union[str, Path, np.ndarray],
        locations_list: Union[np.ndarray, DataFrame, str, Path],
        patch_size: Union[int, Tuple[int, int]] = (224, 224),
        resolution: Union[int, float, Tuple[float, float]] = 0,
        units: str = "level",
        pad_mode: str = "constant",
        pad_constant_values: Union[int, Tuple[int, int]] = 0,
        within_bound: bool = False,
    ):
        """Load the point locations and convert them to top-left corners."""
        super().__init__(
            input_img=input_img,
            patch_size=patch_size,
            resolution=resolution,
            units=units,
            pad_mode=pad_mode,
            pad_constant_values=pad_constant_values,
            within_bound=within_bound,
        )

        self.locations_df = misc.read_locations(input_table=locations_list)

        # The supplied points are patch centres, whereas patches are read
        # from their top-left corner. `patch_size` is (width, height), so
        # x is shifted by half the width and y by half the height. The x
        # shift previously used the height, which mis-centred non-square
        # patches.
        patch_width, patch_height = self.patch_size[0], self.patch_size[1]
        self.locations_df["x"] = self.locations_df["x"] - int((patch_width - 1) / 2)
        self.locations_df["y"] = self.locations_df["y"] - int((patch_height - 1) / 2)
def get_patch_extractor(method_name: str, **kwargs: str):
    """Create the patch extractor that matches `method_name`.

    Args:
        method_name (str):
            Name of patch extraction method, must be one of "point" or
            "slidingwindow". The method name is case-insensitive.
        **kwargs:
            Keyword arguments forwarded to the constructor of the
            selected extractor class.

    Returns:
        PatchExtractor:
            A :obj:`PointsPatchExtractor` for "point" or a
            :obj:`SlidingWindowPatchExtractor` for "slidingwindow".

    Raises:
        MethodNotSupported:
            If `method_name` is neither "point" nor "slidingwindow".

    Examples:
        >>> from tiatoolbox.tools.patchextraction import get_patch_extractor
        >>> # PointsPatchExtractor centred on user supplied points
        >>> patch_extract = get_patch_extractor(
        ...     'point', input_img=img, locations_list=points, patch_size=200)

    """
    method = method_name.lower()

    if method == "point":
        return PointsPatchExtractor(**kwargs)
    if method == "slidingwindow":
        return SlidingWindowPatchExtractor(**kwargs)

    raise MethodNotSupported(f"{method} method is not currently supported.")