Source code for tiatoolbox.tools.patchextraction
# ***** BEGIN GPL LICENSE BLOCK *****
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
# All rights reserved.
# ***** END GPL LICENSE BLOCK *****
"""This file defines patch extraction methods for deep learning models."""
from abc import ABC
import numpy as np
from tiatoolbox.utils import misc
from tiatoolbox.utils.exceptions import MethodNotSupported
from tiatoolbox.wsicore import wsireader
class PatchExtractor(ABC):
    """Class for extracting and merging patches in standard and whole-slide images.

    Args:
        input_img(str, pathlib.Path, :class:`numpy.ndarray`): input image for
          patch extraction.
        patch_size(int or tuple(int)): patch size tuple (width, height). A single
          number is used for both width and height.
        input_mask(str, pathlib.Path, :class:`numpy.ndarray`, or :obj:`WSIReader`):
          input mask that is used for position filtering when extracting patches
          i.e., patches will only be extracted based on the highlighted regions in
          the input_mask. input_mask can be either path to the mask, a numpy
          array, :class:`VirtualWSIReader`, or one of 'otsu' and 'morphological'
          options. In case of 'otsu' or 'morphological', a tissue mask is generated
          for the input_image using tiatoolbox :class:`TissueMasker` functionality.
          No mask is generated when the input image itself is opened as a
          :class:`VirtualWSIReader`.
        resolution (int or float or tuple of float): resolution at
          which to read the image, default = 0. Either a single
          number or a sequence of two numbers for x and y are
          valid. This value is in terms of the corresponding
          units. For example: resolution=0.5 and units="mpp" will
          read the slide at 0.5 microns per-pixel, and
          resolution=3, units="level" will read at level at
          pyramid level / resolution layer 3.
        units (str): the units of resolution, default = "level".
          Supported units are: microns per pixel (mpp), objective
          power (power), pyramid / resolution level (level),
          Only pyramid / resolution levels (level) embedded in
          the whole slide image are supported.
        pad_mode (str): Method for padding at edges of the WSI. Default
          to 'constant'. See :func:`numpy.pad` for more information.
        pad_constant_values (int or tuple(int)): Values to use with
          constant padding. Defaults to 0. See :func:`numpy.pad` for
          more.
        within_bound (bool): whether to keep only the patches that lie fully
          inside the input image. If False, patches at the margins are still
          extracted and padded based on `pad_constant_values` and `pad_mode`.
          If True, patches at the margin whose bounds exceed the mother image
          dimensions are neglected. Default is False.

    Attributes:
        wsi(WSIReader): input image for patch extraction of type :obj:`WSIReader`.
        patch_size(tuple(int)): patch size tuple (width, height).
        resolution(tuple(int)): resolution at which to read the image.
        units (str): the units of resolution.
        n(int): current state of the iterator.
        mask(VirtualWSIReader or None): reader of the mask used to filter patch
          positions, or None when no filtering is requested.
        within_bound(bool): value of the `within_bound` argument.
        locations_df(pd.DataFrame): A table containing location and/or type of
          patches in `(x_start, y_start, class)` format.
        coord_list(:class:`numpy.ndarray`): An array containing coordinates of patches
          in `(x_start, y_start, x_end, y_end)` format to be used for `slidingwindow`
          patch extraction.
        pad_mode (str): Method for padding at edges of the WSI.
          See :func:`numpy.pad` for more information.
        pad_constant_values (int or tuple(int)): Values to use with
          constant padding. Defaults to 0. See :func:`numpy.pad` for
          more.
        stride (tuple(int)): stride in (x, y) direction for patch extraction. Not used
          for :obj:`PointsPatchExtractor`

    """

    def __init__(
        self,
        input_img,
        patch_size,
        input_mask=None,
        resolution=0,
        units="level",
        pad_mode="constant",
        pad_constant_values=0,
        within_bound=False,
    ):
        # Normalise `patch_size` to a (width, height) tuple of ints.
        if isinstance(patch_size, (tuple, list)):
            self.patch_size = (int(patch_size[0]), int(patch_size[1]))
        else:
            self.patch_size = (int(patch_size), int(patch_size))
        self.resolution = resolution
        self.units = units
        self.pad_mode = pad_mode
        self.pad_constant_values = pad_constant_values
        self.n = 0
        self.wsi = wsireader.get_wsireader(input_img=input_img)
        # Filled in by the subclasses.
        self.locations_df = None
        self.coord_list = None
        self.stride = None

        if input_mask is None:
            self.mask = None
        elif isinstance(input_mask, str) and input_mask in {"otsu", "morphological"}:
            # A tissue mask is only generated for non-virtual readers.
            if isinstance(self.wsi, wsireader.VirtualWSIReader):
                self.mask = None
            else:
                self.mask = self.wsi.tissue_mask(
                    method=input_mask, resolution=1.25, units="power"
                )
        elif isinstance(input_mask, wsireader.VirtualWSIReader):
            self.mask = input_mask
        else:
            # Path or array: wrap it in a virtual reader sharing the WSI metadata.
            self.mask = wsireader.VirtualWSIReader(
                input_mask, info=self.wsi.info, mode="bool"
            )
        self.within_bound = within_bound

    def __iter__(self):
        """Reset the iteration state and return the extractor itself."""
        self.n = 0
        return self

    def __next__(self):
        """Return the next patch, raising StopIteration after the last one."""
        n = self.n
        if n >= self.locations_df.shape[0]:
            raise StopIteration
        self.n = n + 1
        return self[n]

    def __getitem__(self, item):
        """Read the patch whose top-left corner is stored at row `item`.

        Args:
            item (int): index of the patch in `locations_df`. Negative values
              count from the end.

        Returns:
            The output of `self.wsi.read_rect` for the requested location.

        Raises:
            TypeError: if `item` is not an integer.
            IndexError: if `item` is out of range.

        """
        # bool is deliberately rejected although it is a subclass of int.
        if isinstance(item, bool) or not isinstance(item, (int, np.integer)):
            raise TypeError("Index should be an integer.")
        item = int(item)
        n_locations = self.locations_df.shape[0]
        if item < 0:
            # Negative indices would otherwise fail in the label lookup below.
            item += n_locations
        if not 0 <= item < n_locations:
            raise IndexError

        x = self.locations_df["x"][item]
        y = self.locations_df["y"][item]
        return self.wsi.read_rect(
            location=(int(x), int(y)),
            size=self.patch_size,
            resolution=self.resolution,
            units=self.units,
            pad_mode=self.pad_mode,
            pad_constant_values=self.pad_constant_values,
            coord_space="resolution",
        )

    def _generate_location_df(self):
        """Generate location list based on slide dimension.

        The slide dimension is calculated using units and resolution. When a
        mask is available, only the coordinates accepted by
        `filter_coordinates_fast` are kept.

        Returns:
            PatchExtractor: the extractor itself, with `coord_list` and
            `locations_df` populated.

        Raises:
            ValueError: if no coordinate is left after mask filtering.

        """
        slide_dimension = self.wsi.slide_dimensions(self.resolution, self.units)
        self.coord_list = self.get_coordinates(
            image_shape=(slide_dimension[0], slide_dimension[1]),
            patch_input_shape=(self.patch_size[0], self.patch_size[1]),
            stride_shape=(self.stride[0], self.stride[1]),
            input_within_bound=self.within_bound,
        )

        if self.mask is not None:
            # convert the coord_list resolution unit to acceptable units
            converted_units = self.wsi.convert_resolution_units(
                input_res=self.resolution,
                input_unit=self.units,
            )
            # find the first unit which is not None
            converted_units = {
                k: v for k, v in converted_units.items() if v is not None
            }
            converted_units_keys = list(converted_units.keys())
            selected_coord_idxs = self.filter_coordinates_fast(
                self.mask,
                self.coord_list,
                coord_resolution=converted_units[converted_units_keys[0]],
                coord_units=converted_units_keys[0],
            )
            self.coord_list = self.coord_list[selected_coord_idxs]
            if len(self.coord_list) == 0:
                raise ValueError(
                    "No candidate coordinates left after "
                    "filtering by input_mask positions."
                )

        data = self.coord_list[:, :2]  # only use the x_start and y_start
        self.locations_df = misc.read_locations(input_table=np.array(data))
        return self

    @staticmethod
    def filter_coordinates_fast(
        mask_reader,
        coordinates_list,
        coord_resolution,
        coord_units,
        mask_resolution=None,
    ):
        """Validate patch extraction coordinates based on the input mask.

        This function indicates which coordinate is valid for mask-based
        patch extraction based on checks in low resolution.

        Args:
            mask_reader (:class:`.VirtualReader`): a virtual pyramidal reader of the
              mask related to the WSI from which we want to extract the patches.
            coordinates_list (ndarray and np.int32): Bounding boxes to be checked,
              expressed at `coord_resolution` and `coord_units`. The shape of
              `coordinates_list` must be (N, 4), each row being in
              `[start_x, start_y, end_x, end_y]` format.
            coord_resolution (float or pair of float): the resolution value at
              which coordinates_list are generated. A single number is used
              for both x and y.
            coord_units (str): the resolution unit at which coordinates_list are
              generated.
            mask_resolution (float): resolution at which mask array is extracted. It is
              supposed to be in the same units as `coord_resolution` i.e.,
              `coord_units`. If not provided, a default value will be selected based on
              `coord_units` ("mpp": 8, "power": 1.25, "baseline": 0.03125).

        Returns:
            ndarray: boolean flags indicating which coordinate is valid, i.e.
            which box contains at least one positive mask pixel.

        Raises:
            ValueError: if `mask_reader` is not a `VirtualWSIReader`, or if
              `coordinates_list` is not an integer ndarray of shape (N, 4).

        """
        if not isinstance(mask_reader, wsireader.VirtualWSIReader):
            raise ValueError("`mask_reader` should be wsireader.VirtualWSIReader.")
        if not isinstance(coordinates_list, np.ndarray) or not np.issubdtype(
            coordinates_list.dtype, np.integer
        ):
            raise ValueError("`coordinates_list` should be ndarray of integer type.")
        if coordinates_list.shape[-1] != 4:
            raise ValueError("`coordinates_list` must be of shape [N, 4].")
        # Any scalar (including numpy scalars) is used for both axes.
        if np.ndim(coord_resolution) == 0:
            coord_resolution = [coord_resolution, coord_resolution]

        # define default mask_resolution based on the input coord_units
        if mask_resolution is None:
            mask_res_dict = {"mpp": 8, "power": 1.25, "baseline": 0.03125}
            mask_resolution = mask_res_dict[coord_units]

        tissue_mask = mask_reader.slide_thumbnail(
            resolution=mask_resolution, units=coord_units
        )

        # Scaling the coordinates_list to the `tissue_mask` array resolution.
        # For "power" and "baseline" a larger value means a larger image, so
        # the ratio is the inverse of the one used for the other units ("mpp").
        if coord_units in ("power", "baseline"):
            scale_x = mask_resolution / coord_resolution[0]
            scale_y = mask_resolution / coord_resolution[1]
        else:
            scale_x = coord_resolution[0] / mask_resolution
            scale_y = coord_resolution[1] / mask_resolution

        scaled_coords = coordinates_list.copy().astype(np.float32)
        scaled_coords[:, [0, 2]] *= scale_x
        scaled_coords[:, [0, 2]] = np.clip(
            scaled_coords[:, [0, 2]], 0, tissue_mask.shape[1]
        )
        scaled_coords[:, [1, 3]] *= scale_y
        scaled_coords[:, [1, 3]] = np.clip(
            scaled_coords[:, [1, 3]], 0, tissue_mask.shape[0]
        )
        scaled_coords = np.int32(scaled_coords)

        flag_list = [
            np.any(tissue_mask[coord[1] : coord[3], coord[0] : coord[2]] > 0)
            for coord in scaled_coords
        ]
        # Explicit dtype keeps an empty result usable as a boolean mask.
        return np.array(flag_list, dtype=bool)

    @staticmethod
    def filter_coordinates(
        mask_reader, coordinates_list, func=None, resolution=None, units=None
    ):
        """Indicates which coordinate is valid for mask-based patch extraction.

        Locations are being validated by a custom or built-in `func`.

        Args:
            mask_reader (:class:`.VirtualReader`): a virtual pyramidal reader of the
              mask related to the WSI from which we want to extract the patches.
            coordinates_list (ndarray and np.int32): Coordinates to be checked
              via the `func`. They must be in the same resolution as requested
              `resolution` and `units`. The shape of `coordinates_list` is (N, K)
              where N is the number of coordinate sets and K is either 2 for centroids
              or 4 for bounding boxes. When using the default `func=None`, K should be
              4, as we expect the `coordinates_list` to refer to bounding boxes in
              `[start_x, start_y, end_x, end_y]` format.
            func: The coordinate validator function. A function that takes `reader` and
              `coordinate` as arguments and return True or False as indication of
              coordinate validity.
            resolution: resolution used by the default `func` when reading the
              mask. Defaults to `mask_reader.info.mpp` when None.
            units (str): units of `resolution` used by the default `func`.
              Defaults to "mpp" when None.

        Returns:
            ndarray: list of flags to indicate which coordinate is valid.

        Raises:
            ValueError: if `mask_reader` is not a `VirtualWSIReader`, if
              `coordinates_list` is not an integer ndarray, or if the default
              `func` is used with coordinates that are not of shape (N, 4).

        """

        def default_sel_func(reader: wsireader.VirtualWSIReader, coord: np.ndarray):
            """Accept coord as long as its box contains bits of mask."""
            roi = reader.read_bounds(
                coord,
                resolution=reader.info.mpp if resolution is None else resolution,
                units="mpp" if units is None else units,
                interpolation="nearest",
                coord_space="resolution",
            )
            return np.sum(roi > 0) > 0

        if not isinstance(mask_reader, wsireader.VirtualWSIReader):
            raise ValueError("`mask_reader` should be wsireader.VirtualWSIReader.")
        if not isinstance(coordinates_list, np.ndarray) or not np.issubdtype(
            coordinates_list.dtype, np.integer
        ):
            raise ValueError("`coordinates_list` should be ndarray of integer type.")
        if func is None and coordinates_list.shape[-1] != 4:
            raise ValueError(
                "Default `func` does not support "
                "`coordinates_list` of shape {}.".format(coordinates_list.shape)
            )
        func = default_sel_func if func is None else func
        flag_list = [func(mask_reader, coord) for coord in coordinates_list]
        return np.array(flag_list)

    @staticmethod
    def get_coordinates(
        image_shape=None,
        patch_input_shape=None,
        patch_output_shape=None,
        stride_shape=None,
        input_within_bound=False,
        output_within_bound=False,
    ):
        """Calculate patch tiling coordinates.

        Args:
            image_shape (a tuple (int, int) or :class:`numpy.ndarray` of shape (2,)):
              This argument specifies the shape of mother image (the image we want to
              extract patches from) at requested `resolution` and `units` and it is
              expected to be in (width, height) format.
            patch_input_shape (a tuple (int, int) or
              :class:`numpy.ndarray` of shape (2,)): Specifies the input shape of
              requested patches to be extracted from mother image at desired
              `resolution` and `units`. This argument is also expected to be in
              (width, height) format.
            patch_output_shape (a tuple (int, int) or
              :class:`numpy.ndarray` of shape (2,)): Specifies the output shape of
              requested patches to be extracted from mother image at desired
              `resolution` and `units`. This argument is also expected to be in
              (width, height) format. If this is not provided, `patch_output_shape`
              will be the same as `patch_input_shape`.
            stride_shape (a tuple (int, int) or :class:`numpy.ndarray` of shape (2,)):
              The stride that is used to calculate the patch location during the patch
              extraction. If `patch_output_shape` is provided, next stride location
              will base on the output rather than the input. Each value must be
              at least 1.
            input_within_bound (bool): Whether to include the patches where their
              `input` location exceed the margins of mother image. If `True`, the
              patches with input location exceeds the `image_shape` would be
              neglected. Otherwise, those patches would be extracted with `Reader`
              function and appropriate padding.
            output_within_bound (bool): Whether to include the patches where their
              `output` location exceed the margins of mother image. If `True`, the
              patches with output location exceeds the `image_shape` would be
              neglected. Otherwise, those patches would be extracted with `Reader`
              function and appropriate padding. Ignored when `patch_output_shape`
              is not provided.

        Return:
            coord_list: an array of coordinates in `[start_x, start_y, end_x, end_y]`
              format to be used for patch extraction. When `patch_output_shape`
              is provided, a tuple `(input_coord_list, output_coord_list)` is
              returned instead.

        Raises:
            ValueError: if any shape is not a pair of non-negative integers, if
              `patch_input_shape` is smaller than `patch_output_shape`, or if
              `stride_shape` has a value smaller than 1.

        """
        return_output_bound = patch_output_shape is not None
        image_shape = np.array(image_shape)
        patch_input_shape = np.array(patch_input_shape)
        if patch_output_shape is None:
            output_within_bound = False
            patch_output_shape = patch_input_shape
        patch_output_shape = np.array(patch_output_shape)
        stride_shape = np.array(stride_shape)

        def is_invalid_shape(shape):
            """Return True unless `shape` is a pair of non-negative integers."""
            return (
                not np.issubdtype(shape.dtype, np.integer)
                or np.size(shape) != 2
                or np.any(shape < 0)
            )

        if is_invalid_shape(image_shape):
            raise ValueError(f"Invalid `image_shape` value {image_shape}.")
        if is_invalid_shape(patch_input_shape):
            raise ValueError(f"Invalid `patch_input_shape` value {patch_input_shape}.")
        if is_invalid_shape(patch_output_shape):
            raise ValueError(
                f"Invalid `patch_output_shape` value {patch_output_shape}."
            )
        if is_invalid_shape(stride_shape):
            raise ValueError(f"Invalid `stride_shape` value {stride_shape}.")
        if np.any(patch_input_shape < patch_output_shape):
            raise ValueError(
                (
                    f"`patch_input_shape` must larger than `patch_output_shape`"
                    f" {patch_input_shape} must > {patch_output_shape}."
                )
            )
        # A stride of exactly 1 is accepted; only values below 1 are rejected.
        if np.any(stride_shape < 1):
            raise ValueError(f"`stride_shape` value {stride_shape} must > 1.")

        def flat_mesh_grid_coord(x, y):
            """Helper function to obtain coordinate grid."""
            x, y = np.meshgrid(x, y)
            return np.stack([x.flatten(), y.flatten()], axis=-1)

        # Round the image extent up to a whole number of output patches so
        # that the last row/column of the image is also covered.
        output_x_end = (
            np.ceil(image_shape[0] / patch_output_shape[0]) * patch_output_shape[0]
        )
        output_x_list = np.arange(0, int(output_x_end), stride_shape[0])
        output_y_end = (
            np.ceil(image_shape[1] / patch_output_shape[1]) * patch_output_shape[1]
        )
        output_y_list = np.arange(0, int(output_y_end), stride_shape[1])
        output_tl_list = flat_mesh_grid_coord(output_x_list, output_y_list)
        output_br_list = output_tl_list + patch_output_shape[None]

        # The input patch is centred on the output patch.
        io_diff = patch_input_shape - patch_output_shape
        input_tl_list = output_tl_list - (io_diff // 2)[None]
        input_br_list = input_tl_list + patch_input_shape[None]

        # `sel` marks the patches to be discarded.
        sel = np.zeros(input_tl_list.shape[0], dtype=bool)
        if output_within_bound:
            sel |= np.any(output_br_list > image_shape[None], axis=1)
        if input_within_bound:
            sel |= np.any(input_br_list > image_shape[None], axis=1)
            sel |= np.any(input_tl_list < 0, axis=1)

        input_bound_list = np.concatenate(
            [input_tl_list[~sel], input_br_list[~sel]], axis=-1
        )
        output_bound_list = np.concatenate(
            [output_tl_list[~sel], output_br_list[~sel]], axis=-1
        )
        if return_output_bound:
            return input_bound_list, output_bound_list
        return input_bound_list
class SlidingWindowPatchExtractor(PatchExtractor):
    """Extract patches using sliding fixed sized window for images and labels.

    The patch locations are generated as soon as the extractor is created.

    Args:
        stride(int or tuple(int)): stride in (x, y) direction for patch extraction,
          default = patch_size. A single number is used for both directions.

    Attributes:
        stride(tuple(int)): stride in (x, y) direction for patch extraction.

    """

    def __init__(
        self,
        input_img,
        patch_size,
        input_mask=None,
        resolution=0,
        units="level",
        stride=None,
        pad_mode="constant",
        pad_constant_values=0,
        within_bound=False,
    ):
        super().__init__(
            input_img=input_img,
            input_mask=input_mask,
            patch_size=patch_size,
            resolution=resolution,
            units=units,
            pad_mode=pad_mode,
            pad_constant_values=pad_constant_values,
            within_bound=within_bound,
        )

        # Resolve the stride to an (x, y) pair before storing it.
        if stride is None:
            stride_xy = self.patch_size
        elif isinstance(stride, (tuple, list)):
            stride_xy = (int(stride[0]), int(stride[1]))
        else:
            stride_xy = (int(stride), int(stride))
        self.stride = stride_xy

        self._generate_location_df()
class PointsPatchExtractor(PatchExtractor):
    """Extracting patches with specified points as a centre.

    The given points are converted to top-left patch corners by subtracting
    half of the patch width from `x` and half of the patch height from `y`.

    Args:
        locations_list(ndarray, pd.DataFrame, str, pathlib.Path): contains location
          and/or type of patch. This can be path to csv, npy or json files. Input can
          also be a :class:`numpy.ndarray` or :class:`pandas.DataFrame`.
          NOTE: value of location $(x,y)$ is expected to be based on the specified
          `resolution` and `units` (not the `'baseline'` resolution).

    """

    def __init__(
        self,
        input_img,
        locations_list,
        patch_size=(224, 224),
        resolution=0,
        units="level",
        pad_mode="constant",
        pad_constant_values=0,
        within_bound=False,
    ):
        super().__init__(
            input_img=input_img,
            patch_size=patch_size,
            resolution=resolution,
            units=units,
            pad_mode=pad_mode,
            pad_constant_values=pad_constant_values,
            within_bound=within_bound,
        )

        self.locations_df = misc.read_locations(input_table=locations_list)
        # `patch_size` is (width, height): shift x by half the width and
        # y by half the height so that each point becomes the patch centre.
        self.locations_df["x"] = self.locations_df["x"] - int(
            (self.patch_size[0] - 1) / 2
        )
        self.locations_df["y"] = self.locations_df["y"] - int(
            (self.patch_size[1] - 1) / 2
        )
def get_patch_extractor(method_name, **kwargs):
    """Return a patch extractor object as requested.

    Args:
        method_name (str): name of patch extraction method, must be one of "point" or
          "slidingwindow". The comparison is case-insensitive.
        **kwargs: Keyword arguments passed to the constructor of the selected
          extractor class.

    Returns:
        PatchExtractor: an object with base :obj:`PatchExtractor` as base class.

    Raises:
        MethodNotSupported: if `method_name` is not one of the supported names.

    Examples:
        >>> from tiatoolbox.tools.patchextraction import get_patch_extractor
        >>> # PointsPatchExtractor with a custom patch size
        >>> patch_extract = get_patch_extractor(
        ...     'point',
        ...     input_img='sample.svs',
        ...     locations_list='locations.csv',
        ...     patch_size=(200, 200))

    """
    method = method_name.lower()
    if method == "point":
        return PointsPatchExtractor(**kwargs)
    if method == "slidingwindow":
        return SlidingWindowPatchExtractor(**kwargs)
    raise MethodNotSupported