# ***** BEGIN GPL LICENSE BLOCK *****
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
# All rights reserved.
# ***** END GPL LICENSE BLOCK *****
"""This module implements patch-level prediction."""
import copy
import os
import pathlib
import warnings
from collections import OrderedDict
from typing import Callable, Tuple, Union
import numpy as np
import torch
import tqdm
from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset
from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig
from tiatoolbox.utils import misc
from tiatoolbox.utils.misc import save_as_json
from tiatoolbox.wsicore.wsireader import VirtualWSIReader, get_wsireader
class IOPatchPredictorConfig(IOSegmentorConfig):
    """Input/output configuration used by the patch predictor.

    Thin specialisation of :class:`IOSegmentorConfig` that forwards a fixed
    set of values to the parent constructor:

    - ``output_resolutions`` is always an empty list,
    - ``save_resolution`` is always ``None``,
    - ``patch_output_shape`` is set to ``patch_input_shape``,
    - ``stride_shape`` falls back to ``patch_input_shape`` when not given.

    Args:
        patch_input_shape: Shape of the patches fed to the model.
        input_resolutions: Resolutions at which the input is read; passed
            through unchanged to the parent class.
        stride_shape: Stride between consecutive patches. When ``None`` the
            value of ``patch_input_shape`` is used instead.
        **kwargs: Any further keyword arguments, forwarded verbatim to
            :class:`IOSegmentorConfig`.

    """

    def __init__(
        self,
        patch_input_shape=None,
        input_resolutions=None,
        stride_shape=None,
        **kwargs,
    ):
        # No explicit stride: step by exactly one patch.
        if stride_shape is None:
            stride_shape = patch_input_shape

        super().__init__(
            patch_input_shape=patch_input_shape,
            patch_output_shape=patch_input_shape,
            stride_shape=stride_shape,
            input_resolutions=input_resolutions,
            output_resolutions=[],
            save_resolution=None,
            **kwargs,
        )
class PatchPredictor:
    """Patch-level predictor.

    Args:
        model (nn.Module): Externally defined PyTorch model, with weights
            already loaded, to use for prediction. Default is `None`. If
            provided, `pretrained_model` is not used to build the model (it is
            only stored on the instance).
        pretrained_model (str): Name of the existing models support by
            tiatoolbox for processing the data. For a full list of pretrained
            models, refer to the
            `docs <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_
            By default, the corresponding pretrained weights will also be
            downloaded. However, you can override with your own set of weights
            via the `pretrained_weights` argument. Argument is
            case-insensitive.
        pretrained_weights (str): Path to the weight of the corresponding
            `pretrained_model`.

            >>> predictor = PatchPredictor(
            ...    pretrained_model="resnet18-kather100k",
            ...    pretrained_weights="resnet18_local_weight")

        batch_size (int): Number of images fed into the model each time.
        num_loader_workers (int): Number of workers to load the data.
            Take note that they will also perform preprocessing.
        verbose (bool): Whether to output logging information.

    Attributes:
        imgs: Initialised to `None` by the constructor.
        mode (str): Initialised to `None` by the constructor.
        model (nn.Module): Defined PyTorch model.
        ioconfig: Input/output configuration returned together with the
            pretrained model; `None` when a custom `model` was supplied.
        pretrained_model (str): Name of the pretrained model requested at
            construction time (may be `None`).
        batch_size (int): Number of images fed into the model each time.
        num_loader_worker (int): Number of workers used in
            torch.utils.data.DataLoader.
        verbose (bool): Whether to output logging information.

    Examples:
        >>> # list of 2 image patches as input
        >>> data = [img1, img2]
        >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
        >>> output = predictor.predict(data, mode='patch')

        >>> # array of list of 2 image patches as input
        >>> data = np.array([img1, img2])
        >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
        >>> output = predictor.predict(data, mode='patch')

        >>> # list of 2 image patch files as input
        >>> data = ['path/img.png', 'path/img.png']
        >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
        >>> output = predictor.predict(data, mode='patch')

        >>> # list of 2 image tile files as input
        >>> tile_file = ['path/tile1.png', 'path/tile2.png']
        >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
        >>> output = predictor.predict(tile_file, mode='tile')

        >>> # list of 2 wsi files as input
        >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs']
        >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
        >>> output = predictor.predict(wsi_file, mode='wsi')

    """

    def __init__(
        self,
        batch_size=8,
        num_loader_workers=0,
        model=None,
        pretrained_model=None,
        pretrained_weights=None,
        verbose=True,
    ):
        """Build the predictor from a custom model or a pretrained model name.

        Raises:
            ValueError: If neither `model` nor `pretrained_model` is given.

        """
        super().__init__()

        self.imgs = None
        self.mode = None

        if model is None and pretrained_model is None:
            raise ValueError("Must provide either of `model` or `pretrained_model`")

        if model is not None:
            # A user-supplied model carries no known io configuration.
            ioconfig = None
        else:
            model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights)

        self.ioconfig = ioconfig  # original configuration, never mutated
        self._ioconfig = None  # configuration actually used by the last run
        self.model = model  # for runtime, such as after wrapping with nn.DataParallel
        self.pretrained_model = pretrained_model
        self.batch_size = batch_size
        self.num_loader_worker = num_loader_workers
        self.verbose = verbose

    @staticmethod
    def merge_predictions(
        img: Union[str, pathlib.Path, np.ndarray],
        output: dict,
        resolution: float = None,
        units: str = None,
        postproc_func: Callable = None,
        return_raw: bool = False,
    ):
        """Merge patch-level predictions to form a 2-dimensional prediction map.

        When `output` contains per-patch `probabilities`, they are painted
        onto a canvas, averaged where patches overlap, and (unless
        `return_raw` is set) converted to labels. In that case every processed
        pixel is shifted by one so that 0 means "not processed" and classes
        are numbered 1..N.

        When `output` has no `probabilities`, the per-patch `predictions` are
        painted onto the canvas as they are and returned without any
        post-processing or shifting.

        Args:
            img (:obj:`str` or :obj:`pathlib.Path` or :class:`numpy.ndarray`):
                A HWC image or a path to WSI.
            output (dict): Output generated by the model. Must contain
                `coordinates`, `resolution`, `units` and either
                `probabilities` or `predictions`.
            resolution (float): Resolution of merged predictions.
            units (str): Units of resolution used when merging predictions.
                This must be the same `units` used when processing the data.
            postproc_func (callable): A function to post-process raw
                prediction from model. By default, internal code uses the
                `np.argmax` function.
            return_raw (bool): Return raw result without applying the
                `postproc_func` on the assembled image.

        Returns:
            prediction_map (ndarray): Merged predictions as a 2D array, or an
            HxWxC array of averaged probabilities when `return_raw` is `True`
            and `probabilities` are available.

        Examples:
            >>> # pseudo output dict from model with 2 patches
            >>> output = {
            ...     'resolution': 1.0,
            ...     'units': 'baseline',
            ...     'probabilities': [[0.45, 0.55], [0.90, 0.10]],
            ...     'predictions': [1, 0],
            ...     'coordinates': [[0, 0, 2, 2], [2, 2, 4, 4]],
            ... }
            >>> merged = PatchPredictor.merge_predictions(
            ...         np.zeros([4, 4]),
            ...         output,
            ...         resolution=1.0,
            ...         units='baseline'
            ... )
            >>> merged
            array([[2, 2, 0, 0],
                   [2, 2, 0, 0],
                   [0, 0, 1, 1],
                   [0, 0, 1, 1]])

        """
        reader = get_wsireader(img)
        if isinstance(reader, VirtualWSIReader):
            warnings.warn(
                (
                    "Image is not pyramidal hence read is forced to be "
                    "at `units='baseline'` and `resolution=1.0`."
                )
            )
            resolution = 1.0
            units = "baseline"

        canvas_shape = reader.slide_dimensions(resolution=resolution, units=units)
        canvas_shape = canvas_shape[::-1]  # XY to YX

        # may crash here, do we need to deal with this ?
        output_shape = reader.slide_dimensions(
            resolution=output["resolution"], units=output["units"]
        )
        output_shape = output_shape[::-1]  # XY to YX

        # Ratio between the canvas and the space the coordinates live in.
        # Computed in YX order, then flipped because bounds are stored as XY.
        scale_yx = np.array(canvas_shape) / np.array(output_shape)
        scale_xy = scale_yx[::-1]

        coordinates = output["coordinates"]
        if "probabilities" in output:
            predictions = output["probabilities"]
            num_class = np.array(predictions[0]).shape[0]
            # counts how many patches touched each pixel
            denominator = np.zeros(canvas_shape)
            canvas = np.zeros(list(canvas_shape) + [num_class], dtype=np.float32)
        else:
            predictions = output["predictions"]
            denominator = None
            canvas = np.zeros(list(canvas_shape), dtype=np.float32)

        for idx, bound in enumerate(coordinates):
            prediction = predictions[idx]
            # bounds are assumed to be (x_start, y_start, x_end, y_end)
            top_left = np.ceil(np.array(bound[:2]) * scale_xy).astype(np.int32)
            bot_right = np.ceil(np.array(bound[2:]) * scale_xy).astype(np.int32)
            region = (
                slice(top_left[1], bot_right[1]),
                slice(top_left[0], bot_right[0]),
            )
            if denominator is None:
                # hard labels cannot be averaged: last patch wins
                canvas[region] = prediction
            else:
                # accumulate so that the division below yields a true average
                canvas[region] += prediction
                denominator[region] += 1

        if denominator is None:
            # Nothing to average or post-process without probabilities.
            return canvas

        # deal with overlapping regions
        canvas = canvas / (np.expand_dims(denominator, -1) + 1.0e-8)
        if return_raw:
            return canvas

        # convert raw probabilities to predictions
        if postproc_func is not None:
            canvas = postproc_func(canvas)
        else:
            canvas = np.argmax(canvas, axis=-1)
        # to make sure background is 0 while class will be 1..N
        canvas[denominator > 0] += 1
        return canvas

    def _predict_engine(
        self,
        dataset,
        return_probabilities=False,
        return_labels=False,
        return_coordinates=False,
        on_gpu=True,
    ):
        """Make a prediction on a dataset. The dataset may be mutated.

        The dataset's `preproc_func` attribute is overwritten with the one
        defined by the model before any data is loaded.

        Args:
            dataset (torch.utils.data.Dataset): PyTorch dataset object created
                using tiatoolbox.models.data.classification.Patch_Dataset.
            return_probabilities (bool): Whether to return per-class
                probabilities.
            return_labels (bool): Whether to return labels.
            return_coordinates (bool): Whether to return patch coordinates.
            on_gpu (bool): whether to run model on the GPU.

        Returns:
            output (dict): Always contains `predictions`; additionally
            contains `probabilities`, `labels` and `coordinates` when the
            matching `return_*` flag is set. Each value is a list with one
            entry per sample.

        """
        # preprocessing must be defined with the dataset
        dataset.preproc_func = self.model.preproc_func

        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=self.num_loader_worker,
            batch_size=self.batch_size,
            drop_last=False,
            shuffle=False,
        )

        pbar = None
        if self.verbose:
            pbar = tqdm.tqdm(
                total=int(len(dataloader)), leave=True, ncols=80, ascii=True, position=0
            )

        # use external for testing
        model = misc.model_to(on_gpu, self.model)

        cum_output = {
            "probabilities": [],
            "predictions": [],
            "coordinates": [],
            "labels": [],
        }
        try:
            for batch_data in dataloader:
                batch_output_probabilities = self.model.infer_batch(
                    model, batch_data["image"], on_gpu
                )
                # We get the index of the class with the maximum probability
                batch_output_predictions = self.model.postproc_func(
                    batch_output_probabilities
                )

                # tolist might be very expensive, so only pay for it on demand
                if return_probabilities:
                    cum_output["probabilities"].extend(
                        batch_output_probabilities.tolist()
                    )
                cum_output["predictions"].extend(batch_output_predictions.tolist())
                if return_coordinates:
                    cum_output["coordinates"].extend(batch_data["coords"].tolist())
                if return_labels:  # be careful of `s`
                    # We do not use tolist here because label may be of mixed
                    # types and hence collated as list by torch
                    cum_output["labels"].extend(list(batch_data["label"]))

                if pbar is not None:
                    pbar.update()
        finally:
            # release the progress bar even when inference fails
            if pbar is not None:
                pbar.close()

        requested = (
            ("probabilities", return_probabilities),
            ("labels", return_labels),
            ("coordinates", return_coordinates),
        )
        for key, wanted in requested:
            if not wanted:
                cum_output.pop(key)

        return cum_output

    def predict(
        self,
        imgs,
        masks=None,
        labels=None,
        mode="patch",
        return_probabilities=False,
        return_labels=False,
        on_gpu=True,
        ioconfig: "IOPatchPredictorConfig" = None,
        patch_input_shape: Tuple[int, int] = None,
        stride_shape: Tuple[int, int] = None,
        resolution=None,
        units=None,
        merge_predictions=False,
        save_dir=None,
        save_output=False,
    ):
        """Make a prediction for a list of input data.

        Args:
            imgs (list, ndarray): List of inputs to process. When using
                `patch` mode, the input must be either a list of images, a
                list of image file paths or a numpy array of an image list.
                When using `tile` or `wsi` mode, the input must be a list of
                file paths.
            masks (list): List of masks. Only utilised when processing image
                tiles and whole-slide images. Patches are only processed if
                they are within a masked area. If not provided, then a tissue
                mask will be automatically generated for whole-slide images
                or the entire image is processed for image tiles.
            labels: List of labels. If using `tile` or `wsi` mode, then only a
                single label per image tile or whole-slide image is supported.
            mode (str): Type of input to process. Choose from either `patch`,
                `tile` or `wsi`.
            return_probabilities (bool): Whether to return per-class
                probabilities.
            return_labels (bool): Whether to return the labels with the
                predictions. In `patch` mode this is overridden whenever
                `labels` is provided.
            on_gpu (bool): whether to run model on the GPU.
            ioconfig (IOPatchPredictorConfig): Input/output configuration.
                When given it is used as it is, and `patch_input_shape`,
                `stride_shape`, `resolution` and `units` are not applied on
                top of it.
            patch_input_shape (tuple): Size of patches input to the model.
                Patches are at requested read resolution, not with respect to
                level 0, and must be positive.
            stride_shape (tuple): Stride using during tile and WSI processing.
                Stride is at requested read resolution, not with respect to
                level 0, and must be positive. If not provided,
                `stride_shape=patch_input_shape`.
            resolution (float): Resolution used for reading the image. Please
                see :obj:`WSIReader` for details.
            units (str): Units of resolution used for reading the image.
                Choose from either `level`, `power` or `mpp`. Please see
                :obj:`WSIReader` for details.
            merge_predictions (bool): Whether to merge the predictions to form
                a 2-dimensional map. This is only applicable for `mode='wsi'`
                or `mode='tile'`.
            save_dir (str or pathlib.Path): Output directory when processing
                multiple tiles and whole-slide images. By default, it is
                folder `output` where the running script is invoked. The
                directory is created and must not exist beforehand.
            save_output (bool): Whether to save output for a single file.
                default=False

        Returns:
            output (dict, list): In `patch` mode, the dictionary of model
            outputs. In `tile` / `wsi` mode with a single input and
            `save_output=False`, a list holding the raw output dictionary
            followed by the merged map when `merge_predictions` is `True`.
            If multiple image tiles or whole-slide images are provided as
            input, or `save_output` is `True`, then results are saved to
            `save_dir` and a dictionary indicating save location for each
            input is returned. The dict is keyed by the input path and each
            value has the following format:

            - raw: path to save location for raw prediction, saved in .json.
            - merged: path to .npy contain merged predictions if
              `merge_predictions` is `True`.

        Raises:
            ValueError: On an unknown `mode`, on `labels` / `masks` whose
                length differs from `imgs`, on non-list `imgs` in `tile` /
                `wsi` mode, or when no io configuration can be derived.

        Examples:
            >>> wsis = ['wsi1.svs', 'wsi2.svs']
            >>> predictor = PatchPredictor(
            ...                 pretrained_model="resnet18-kather100k")
            >>> output = predictor.predict(wsis, mode="wsi")
            >>> output.keys()
            ['wsi1.svs', 'wsi2.svs']
            >>> output['wsi1.svs']
            {'raw': '0.raw.json', 'merged': '0.merged.npy'}
            >>> output['wsi2.svs']
            {'raw': '1.raw.json', 'merged': '1.merged.npy'}

        """
        if mode not in ["patch", "wsi", "tile"]:
            raise ValueError(
                f"{mode} is not a valid mode. Use either `patch`, `tile` or `wsi`"
            )

        if mode == "patch" and labels is not None:
            # if a labels is provided, then return with the prediction
            # (len() rather than bool() so that numpy arrays are accepted)
            return_labels = len(labels) > 0
            if len(labels) != len(imgs):
                raise ValueError(
                    f"len(labels) != len(imgs) : {len(labels)} != {len(imgs)}"
                )

        if mode == "wsi" and masks is not None and len(masks) != len(imgs):
            raise ValueError(f"len(masks) != len(imgs) : {len(masks)} != {len(imgs)}")

        if mode == "patch":
            # don't return coordinates if patches are already extracted
            return_coordinates = False
            dataset = PatchDataset(imgs, labels)
            return self._predict_engine(
                dataset, return_probabilities, return_labels, return_coordinates, on_gpu
            )

        # ---- `tile` / `wsi` mode -------------------------------------------
        # Validate before touching the file system or building any config.
        if not isinstance(imgs, list):
            raise ValueError(
                "Input to `tile` and `wsi` mode must be a list of file paths."
            )

        if stride_shape is None:
            stride_shape = patch_input_shape

        # ! not sure if there is any way to make this nicer
        make_config_flag = (
            patch_input_shape is None,
            resolution is None,
            units is None,
        )

        if ioconfig is None and self.ioconfig is None and any(make_config_flag):
            raise ValueError(
                "Must provide either `ioconfig` or "
                "`patch_input_shape`, `resolution`, and `units`."
            )

        if ioconfig is None and self.ioconfig is not None:
            # Start from the model's own configuration and apply overrides on
            # a copy so the stored original is never modified.
            ioconfig = copy.deepcopy(self.ioconfig)
            if patch_input_shape is not None:
                ioconfig.patch_input_shape = patch_input_shape
            if stride_shape is not None:
                ioconfig.stride_shape = stride_shape
            if resolution is not None:
                ioconfig.input_resolutions[0]["resolution"] = resolution
            if units is not None:
                ioconfig.input_resolutions[0]["units"] = units
        elif ioconfig is None and not any(make_config_flag):
            ioconfig = IOPatchPredictorConfig(
                input_resolutions=[{"resolution": resolution, "units": units}],
                patch_input_shape=patch_input_shape,
                stride_shape=stride_shape,
            )

        fx_list = ioconfig.scale_to_highest(
            ioconfig.input_resolutions, ioconfig.input_resolutions[0]["units"]
        )
        fx_list = zip(fx_list, ioconfig.input_resolutions)
        fx_list = sorted(fx_list, key=lambda x: x[0])
        highest_input_resolution = fx_list[0][1]

        if mode == "tile":
            warnings.warn(
                "WSIPatchDataset only reads image tile at "
                '`units="baseline"`. Resolutions will be converted '
                "to baseline value."
            )
            ioconfig = ioconfig.to_baseline()

        save_to_disk = len(imgs) > 1 or save_output

        if len(imgs) > 1:
            warnings.warn(
                "When providing multiple whole-slide images / tiles, "
                "we save the outputs and return the locations "
                "to the corresponding files."
            )

        if save_to_disk and save_dir is None:
            reason = "> 1 WSIs detected" if len(imgs) > 1 else "`save_output=True`"
            warnings.warn(
                f"{reason} but there is no save directory set. "
                "All subsequent output will be saved to current runtime "
                "location under folder 'output'. Overwriting may happen!"
            )
            save_dir = pathlib.Path(os.getcwd()).joinpath("output")

        if save_dir is not None:
            save_dir = pathlib.Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=False)

        # return coordinates of patches processed within a tile / whole-slide image
        return_coordinates = True

        # None if no output
        outputs = None

        self._ioconfig = ioconfig
        # maps each input path to the files written for it
        file_dict = OrderedDict()
        for idx, img_path in enumerate(imgs):
            img_path = pathlib.Path(img_path)
            img_label = None if labels is None else labels[idx]
            img_mask = None if masks is None else masks[idx]

            dataset = WSIPatchDataset(
                img_path,
                mode=mode,
                mask_path=img_mask,
                patch_input_shape=ioconfig.patch_input_shape,
                stride_shape=ioconfig.stride_shape,
                resolution=ioconfig.input_resolutions[0]["resolution"],
                units=ioconfig.input_resolutions[0]["units"],
            )
            output_model = self._predict_engine(
                dataset,
                return_labels=False,
                return_probabilities=return_probabilities,
                return_coordinates=return_coordinates,
                on_gpu=on_gpu,
            )
            output_model["label"] = img_label
            # add extra information useful for downstream analysis
            output_model["pretrained_model"] = self.pretrained_model
            output_model["resolution"] = highest_input_resolution["resolution"]
            output_model["units"] = highest_input_resolution["units"]

            outputs = [output_model]  # assign to a list
            merged_prediction = None
            if merge_predictions:
                merged_prediction = self.merge_predictions(
                    img_path,
                    output_model,
                    resolution=output_model["resolution"],
                    units=output_model["units"],
                    # same hook `_predict_engine` applies to each batch
                    postproc_func=self.model.postproc_func,
                )
                outputs.append(merged_prediction)

            if save_to_disk:
                # dynamic 0 padding
                img_code = f"{idx:0{len(str(len(imgs)))}d}"

                save_info = {}
                save_path = os.path.join(str(save_dir), img_code)
                raw_save_path = f"{save_path}.raw.json"
                save_info["raw"] = raw_save_path
                save_as_json(output_model, raw_save_path)
                if merge_predictions:
                    merged_file_path = f"{save_path}.merged.npy"
                    np.save(merged_file_path, merged_prediction)
                    save_info["merged"] = merged_file_path
                file_dict[str(img_path)] = save_info

        return file_dict if save_to_disk else outputs