"""This module implements semantic segmentation."""
from __future__ import annotations
import copy
import logging
import shutil
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import TYPE_CHECKING, Callable
import cv2
import joblib
import numpy as np
import torch
import torch.multiprocessing as torch_mp
import torch.utils.data as torch_data
import tqdm
from tiatoolbox import logger
from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.models.models_abc import IOConfigABC
from tiatoolbox.tools.patchextraction import PatchExtractor
from tiatoolbox.utils import imread, misc
from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader
if TYPE_CHECKING: # pragma: no cover
from multiprocessing.managers import Namespace
from tiatoolbox.typing import IntPair, Resolution, Units
def _estimate_canvas_parameters(
sample_prediction: np.ndarray,
canvas_shape: np.ndarray,
) -> tuple[tuple, tuple, bool]:
"""Estimates canvas parameters.
Args:
sample_prediction (:class:`numpy.ndarray`):
            Patch prediction assumed to be of shape HWC.
canvas_shape (:class:`numpy.ndarray`):
HW of the supposed assembled image.
Returns:
(tuple, tuple, bool):
            Canvas cumulative shape, canvas count shape, and whether to
            add a singleton dimension.
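    Example:
        >>> # a minimal sketch of the two cases handled here
        >>> pred = np.zeros((64, 64, 2))  # HWC patch with 2 channels
        >>> _ = _estimate_canvas_parameters(pred, np.array([100, 200]))
        >>> # -> ((100, 200, 2), (100, 200, 1), False)
        >>> pred = np.zeros((64, 64))  # plain HW patch
        >>> _ = _estimate_canvas_parameters(pred, np.array([100, 200]))
        >>> # -> ((100, 200, 1), (100, 200, 1), True)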
"""
if len(sample_prediction.shape) == 3: # noqa: PLR2004
num_output_ch = sample_prediction.shape[-1]
canvas_cum_shape_ = (*tuple(canvas_shape), num_output_ch)
canvas_count_shape_ = (*tuple(canvas_shape), 1)
add_singleton_dim = num_output_ch == 1
else:
canvas_cum_shape_ = (*tuple(canvas_shape), 1)
canvas_count_shape_ = (*tuple(canvas_shape), 1)
add_singleton_dim = True
return canvas_cum_shape_, canvas_count_shape_, add_singleton_dim
def _prepare_save_output(
save_path: str | Path,
cache_count_path: str | Path,
canvas_cum_shape_: tuple[int, ...],
canvas_count_shape_: tuple[int, ...],
) -> tuple:
"""Prepares for saving the cached output."""
if save_path is not None:
save_path = Path(save_path)
cache_count_path = Path(cache_count_path)
if Path.exists(save_path) and Path.exists(cache_count_path):
cum_canvas = np.load(str(save_path), mmap_mode="r+")
count_canvas = np.load(str(cache_count_path), mmap_mode="r+")
if canvas_cum_shape_ != cum_canvas.shape:
msg = "Existing image shape in `save_path` does not match."
raise ValueError(msg)
if canvas_count_shape_ != count_canvas.shape:
msg = "Existing image shape in `cache_count_path` does not match."
raise ValueError(
msg,
)
else:
cum_canvas = np.lib.format.open_memmap(
save_path,
mode="w+",
shape=canvas_cum_shape_,
dtype=np.float32,
)
# assuming no more than 255 overlapping times
count_canvas = np.lib.format.open_memmap(
cache_count_path,
mode="w+",
shape=canvas_count_shape_,
dtype=np.uint8,
)
# flush fill
count_canvas[:] = 0
is_on_drive = True
else:
is_on_drive = False
cum_canvas = np.zeros(
shape=canvas_cum_shape_,
dtype=np.float32,
)
# for pixel occurrence counting
count_canvas = np.zeros(canvas_count_shape_, dtype=np.float32)
return is_on_drive, count_canvas, cum_canvas
class IOSegmentorConfig(IOConfigABC):
"""Contain semantic segmentor input and output information.
Args:
        input_resolutions (list):
            Resolution of each input head of the model during
            inference; must be in the same order as the target
            `model.forward()`.
        output_resolutions (list):
            Resolution of each output head from the model during
            inference; must be in the same order as the target
            `model.infer_batch()`.
patch_input_shape (:class:`numpy.ndarray`, list(int)):
Shape of the largest input in (height, width).
patch_output_shape (:class:`numpy.ndarray`, list(int)):
Shape of the largest output in (height, width).
save_resolution (dict):
Resolution to save all output.
Examples:
>>> # Defining io for a network having 1 input and 1 output at the
>>> # same resolution
>>> ioconfig = IOSegmentorConfig(
... input_resolutions=[{"units": "baseline", "resolution": 1.0}],
... output_resolutions=[{"units": "baseline", "resolution": 1.0}],
... patch_input_shape=[2048, 2048],
... patch_output_shape=[1024, 1024],
... stride_shape=[512, 512],
... )
Examples:
        >>> # Defining io for a network having 3 inputs and 2 outputs
>>> # at the same resolution, the output is then merged at a
>>> # different resolution.
>>> ioconfig = IOSegmentorConfig(
... input_resolutions=[
... {"units": "mpp", "resolution": 0.25},
... {"units": "mpp", "resolution": 0.50},
... {"units": "mpp", "resolution": 0.75},
... ],
... output_resolutions=[
... {"units": "mpp", "resolution": 0.25},
... {"units": "mpp", "resolution": 0.50},
... ],
... patch_input_shape=[2048, 2048],
... patch_output_shape=[1024, 1024],
... stride_shape=[512, 512],
... save_resolution={"units": "mpp", "resolution": 4.0},
... )
"""
    # Pre-defined to satisfy the ABC attribute enforcement; the actual
    # initialisation happens in `__init__`
input_resolutions = None
output_resolutions = None
def __init__(
self: IOSegmentorConfig,
input_resolutions: list[dict],
output_resolutions: list[dict],
patch_input_shape: IntPair,
patch_output_shape: IntPair,
save_resolution: dict | None = None,
**kwargs: dict,
) -> None:
"""Initialize :class:`IOSegmentorConfig`."""
self._kwargs = kwargs
self.patch_input_shape = patch_input_shape
self.patch_output_shape = patch_output_shape
self.stride_shape = None
self.input_resolutions = input_resolutions
self.output_resolutions = output_resolutions
self.resolution_unit = input_resolutions[0]["units"]
self.save_resolution = save_resolution
for variable, value in kwargs.items():
self.__setattr__(variable, value)
self._validate()
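        # for `mpp`, a smaller value means a higher resolution, hence
        # `min` below; for `power` and `baseline`, a larger value does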
if self.resolution_unit == "mpp":
self.highest_input_resolution = min(
self.input_resolutions,
key=lambda x: x["resolution"],
)
else:
self.highest_input_resolution = max(
self.input_resolutions,
key=lambda x: x["resolution"],
)
def _validate(self: IOSegmentorConfig) -> None:
"""Validate the data format."""
resolutions = self.input_resolutions + self.output_resolutions
units = [v["units"] for v in resolutions]
units = np.unique(units)
if len(units) != 1 or units[0] not in [
"power",
"baseline",
"mpp",
]:
msg = f"Invalid resolution units `{units[0]}`."
raise ValueError(msg)
@staticmethod
def scale_to_highest(resolutions: list[dict], units: Units) -> np.ndarray:
"""Get the scaling factor from input resolutions.
This will convert resolutions to a scaling factor with respect to
the highest resolution found in the input resolutions list.
Args:
resolutions (list):
A list of resolutions where one is defined as
                `{'resolution': value, 'units': value}`.
units (Units):
Units that the resolutions are at.
Returns:
:class:`numpy.ndarray`:
A 1D array of scaling factors having the same length as
                `resolutions`.
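        Examples:
            >>> # a minimal sketch of the `mpp` case: the smallest mpp
            >>> # value is the highest resolution and scales to 1.0
            >>> IOSegmentorConfig.scale_to_highest(
            ...     [{"units": "mpp", "resolution": 0.25},
            ...      {"units": "mpp", "resolution": 0.5}],
            ...     "mpp",
            ... )
            array([1. , 0.5])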
"""
old_val = [v["resolution"] for v in resolutions]
if units not in ["baseline", "mpp", "power"]:
msg = (
f"Unknown units `{units}`. "
f"Units should be one of 'baseline', 'mpp' or 'power'."
)
raise ValueError(
msg,
)
if units == "baseline":
return old_val
if units == "mpp":
return np.min(old_val) / np.array(old_val)
return np.array(old_val) / np.max(old_val)
def to_baseline(self: IOSegmentorConfig) -> IOSegmentorConfig:
"""Return a new config object converted to baseline form.
This will return a new :class:`IOSegmentorConfig` where
resolutions have been converted to baseline format with the
highest possible resolution found in both input and output as
reference.
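        Example:
            >>> # a minimal sketch: mpp resolutions rescaled to baseline,
            >>> # with the highest (0.25 mpp) input mapping to 1.0
            >>> ioconfig = IOSegmentorConfig(
            ...     input_resolutions=[
            ...         {"units": "mpp", "resolution": 0.25},
            ...         {"units": "mpp", "resolution": 0.5},
            ...     ],
            ...     output_resolutions=[{"units": "mpp", "resolution": 0.5}],
            ...     patch_input_shape=[2048, 2048],
            ...     patch_output_shape=[1024, 1024],
            ...     stride_shape=[512, 512],
            ... )
            >>> baseline_config = ioconfig.to_baseline()
            >>> # input resolutions are now baseline 1.0 and 0.5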
"""
resolutions = self.input_resolutions + self.output_resolutions
if self.save_resolution is not None:
resolutions.append(self.save_resolution)
scale_factors = self.scale_to_highest(resolutions, self.resolution_unit)
num_input_resolutions = len(self.input_resolutions)
num_output_resolutions = len(self.output_resolutions)
end_idx = num_input_resolutions
input_resolutions = [
{"units": "baseline", "resolution": v} for v in scale_factors[:end_idx]
]
end_idx = num_input_resolutions + num_output_resolutions
output_resolutions = [
{"units": "baseline", "resolution": v}
for v in scale_factors[num_input_resolutions:end_idx]
]
save_resolution = None
if self.save_resolution is not None:
save_resolution = {"units": "baseline", "resolution": scale_factors[-1]}
return IOSegmentorConfig(
input_resolutions=input_resolutions,
output_resolutions=output_resolutions,
patch_input_shape=self.patch_input_shape,
patch_output_shape=self.patch_output_shape,
save_resolution=save_resolution,
**self._kwargs,
)
class WSIStreamDataset(torch_data.Dataset):
"""Reading a wsi in parallel mode with persistent workers.
    To speed up the inference process for multiple WSIs, the
    `torch.utils.data.DataLoader` is set to run in persistent mode.
    Normally, this prevents workers from altering their initial states
    (such as the provided input). To sidestep this, we use a shared
    parallel workspace context manager to send data and signals from
    the main thread, thus allowing each worker to load a new WSI as
    well as the corresponding patch information.
Args:
mp_shared_space (:class:`Namespace`):
A shared multiprocessing space, must be from
`torch.multiprocessing`.
ioconfig (:class:`IOSegmentorConfig`):
An object which contains I/O placement for patches.
wsi_paths (list): List of paths pointing to a WSI or tiles.
preproc (Callable):
Pre-processing function to be applied to a patch.
mode (str):
Either `"wsi"` or `"tile"` to indicate the format of images
in `wsi_paths`.
Examples:
>>> ioconfig = IOSegmentorConfig(
... input_resolutions=[{"units": "baseline", "resolution": 1.0}],
... output_resolutions=[{"units": "baseline", "resolution": 1.0}],
... patch_input_shape=[2048, 2048],
... patch_output_shape=[1024, 1024],
... stride_shape=[512, 512],
... )
>>> mp_manager = torch_mp.Manager()
>>> mp_shared_space = mp_manager.Namespace()
>>> mp_shared_space.signal = 1 # adding variable to the shared space
>>> wsi_paths = ['A.svs', 'B.svs']
>>> wsi_dataset = WSIStreamDataset(ioconfig, wsi_paths, mp_shared_space)
"""
def __init__(
self: WSIStreamDataset,
ioconfig: IOSegmentorConfig,
wsi_paths: list[str | Path],
mp_shared_space: Namespace,
preproc: Callable[[np.ndarray], np.ndarray] | None = None,
mode: str = "wsi",
) -> None:
"""Initialize :class:`WSIStreamDataset`."""
super().__init__()
self.mode = mode
self.preproc = preproc
self.ioconfig = copy.deepcopy(ioconfig)
if mode == "tile":
logger.warning(
"WSIPatchDataset only reads image tile at "
'`units="baseline"`. Resolutions will be converted '
"to baseline value.",
stacklevel=2,
)
self.ioconfig = self.ioconfig.to_baseline()
self.mp_shared_space = mp_shared_space
self.wsi_paths = wsi_paths
self.wsi_idx = None # to be received externally via thread communication
self.reader = None
def _get_reader(self: WSIStreamDataset, img_path: str | Path) -> WSIReader:
"""Get appropriate reader for input path."""
img_path = Path(img_path)
if self.mode == "wsi":
return WSIReader.open(img_path)
img = imread(img_path)
# initialise metadata for VirtualWSIReader.
# here, we simulate a whole-slide image, but with a single level.
metadata = WSIMeta(
mpp=np.array([1.0, 1.0]),
objective_power=10,
axes="YXS",
slide_dimensions=np.array(img.shape[:2][::-1]),
level_downsamples=[1.0],
level_dimensions=[np.array(img.shape[:2][::-1])],
)
return VirtualWSIReader(
img,
info=metadata,
)
def __len__(self: WSIStreamDataset) -> int:
"""Return the length of the instance attributes."""
return len(self.mp_shared_space.patch_inputs)
@staticmethod
def collate_fn(batch: list | np.ndarray) -> torch.Tensor:
"""Prototype to handle reading exception.
This will exclude any sample with `None` from the batch. As
such, wrapping `__getitem__` with try-catch and return `None`
upon exceptions will prevent crashing the entire program. But as
a side effect, the batch may not have the size as defined.
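        Example:
            >>> # samples that failed to load (i.e. `None`) are dropped
            >>> batch = [np.zeros((2, 2)), None, np.ones((2, 2))]
            >>> WSIStreamDataset.collate_fn(batch).shape
            torch.Size([2, 2, 2])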
"""
batch = [v for v in batch if v is not None]
return torch.utils.data.dataloader.default_collate(batch)
def __getitem__(self: WSIStreamDataset, idx: int) -> tuple:
"""Get an item from the dataset."""
# ! no need to lock as we do not modify source value in shared space
if self.wsi_idx != self.mp_shared_space.wsi_idx:
self.wsi_idx = int(self.mp_shared_space.wsi_idx.item())
self.reader = self._get_reader(self.wsi_paths[self.wsi_idx])
# this is in XY and at requested resolution (not baseline)
bounds = self.mp_shared_space.patch_inputs[idx]
bounds = bounds.numpy() # expected to be a torch.Tensor
        # read size should equal bounds br - tl, unless the bounds are floats
patch_data_ = []
scale_factors = self.ioconfig.scale_to_highest(
self.ioconfig.input_resolutions,
self.ioconfig.resolution_unit,
)
for idy, resolution in enumerate(self.ioconfig.input_resolutions):
resolution_bounds = np.round(bounds * scale_factors[idy])
patch_data = self.reader.read_bounds(
resolution_bounds.astype(np.int32),
coord_space="resolution",
pad_constant_values=0, # expose this ?
**resolution,
)
if self.preproc is not None:
patch_data = patch_data.copy()
patch_data = self.preproc(patch_data)
patch_data_.append(patch_data)
if len(patch_data_) == 1:
patch_data_ = patch_data_[0]
bound = self.mp_shared_space.patch_outputs[idx]
return patch_data_, bound
class SemanticSegmentor:
"""Pixel-wise segmentation predictor.
The tiatoolbox model should produce the following results on the BCSS dataset
using fcn_resnet50_unet-bcss.
.. list-table:: Semantic segmentation performance on the BCSS dataset
:widths: 15 15 15 15 15 15 15
:header-rows: 1
* -
- Tumour
- Stroma
- Inflammatory
- Necrosis
- Other
- All
* - Amgad et al.
- 0.851
- 0.800
- 0.712
- 0.723
- 0.666
- 0.750
* - TIAToolbox
- 0.885
- 0.825
- 0.761
- 0.765
- 0.581
- 0.763
    Note that if `model` is supplied in the arguments, the
    `pretrained_model` and `pretrained_weights` arguments are ignored.
Args:
model (nn.Module):
Use externally defined PyTorch model for prediction with
weights already loaded. Default is `None`. If provided,
`pretrained_model` argument is ignored.
        pretrained_model (str):
            Name of an existing model supported by tiatoolbox for
processing the data. For a full list of pretrained models,
refer to the `docs
<https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
By default, the corresponding pretrained weights will also
be downloaded. However, you can override with your own set
of weights via the `pretrained_weights` argument. Argument
is case-insensitive.
pretrained_weights (str):
Path to the weight of the corresponding `pretrained_model`.
batch_size (int):
Number of images fed into the model each time.
num_loader_workers (int):
Number of workers to load the data. Take note that they will
also perform preprocessing.
num_postproc_workers (int):
This value is there to maintain input compatibility with
`tiatoolbox.models.classification` and is not used.
verbose (bool):
Whether to output logging information.
dataset_class (obj):
Dataset class to be used instead of default.
        auto_generate_mask (bool):
            Whether to automatically generate a tile/WSI tissue mask if
            one is not provided.
Attributes:
process_prediction_per_batch (bool):
A flag to denote whether post-processing for inference
output is applied after each batch or after finishing an entire
tile or WSI.
Examples:
>>> # Sample output of a network
>>> wsis = ['A/wsi.svs', 'B/wsi.svs']
        >>> predictor = SemanticSegmentor(pretrained_model='fcn-tissue_mask')
        >>> output = predictor.predict(wsis, mode='wsi')
        >>> output
        [('A/wsi.svs', 'output/0.raw'), ('B/wsi.svs', 'output/1.raw')]
        >>> # if a network has 2 output heads, each head output of 'A/wsi.svs'
        >>> # will be respectively stored in 'output/0.raw.0', 'output/0.raw.1'
"""
def __init__(
self: SemanticSegmentor,
batch_size: int = 8,
num_loader_workers: int = 0,
num_postproc_workers: int = 0,
model: torch.nn.Module | None = None,
pretrained_model: str | None = None,
pretrained_weights: str | None = None,
dataset_class: Callable = WSIStreamDataset,
*,
verbose: bool = True,
auto_generate_mask: bool = False,
) -> None:
"""Initialize :class:`SemanticSegmentor`."""
super().__init__()
if model is None and pretrained_model is None:
msg = "Must provide either of `model` or `pretrained_model`"
raise ValueError(msg)
if model is not None:
self.model = model
# template ioconfig, usually coming from pretrained
self.ioconfig = None
else:
model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights)
self.ioconfig = ioconfig
self.model = model
        # local variable flagging mode within the class; subclasses may
        # overwrite this to alter some specific behaviour
self.process_prediction_per_batch = True
# for runtime, such as after wrapping with nn.DataParallel
self._cache_dir = None
self._loader = None
self._model = None
self._on_gpu = None
        self._mp_shared_space = None
        self._postproc_workers = None
        self._futures = None
        self._outputs = []
        self.imgs = None
        self.masks = None
        self.dataset_class: WSIStreamDataset = dataset_class
        self.model = model  # original copy
        self.pretrained_model = pretrained_model
        self.batch_size = batch_size
        self.num_loader_workers = num_loader_workers
        # kept for API compatibility with the patch predictor; this class
        # does not use post-processing workers, hence `None` here
        self.num_postproc_workers = None
self.verbose = verbose
self.auto_generate_mask = auto_generate_mask
@staticmethod
def get_coordinates(
image_shape: list[int] | np.ndarray,
ioconfig: IOSegmentorConfig,
) -> tuple[list, list]:
"""Calculate patch tiling coordinates.
By default, internally, it will call the
`PatchExtractor.get_coordinates`. To use your own approach,
either subclass to overwrite or directly assign your own
        function to this name. In either case, the function must obey
the API defined here.
Args:
image_shape (tuple(int), :class:`numpy.ndarray`):
                This argument specifies the shape of the mother image (the
image we want to extract patches from) at requested
`resolution` and `units` and it is expected to be in
(width, height) format.
ioconfig (:class:`IOSegmentorConfig`):
Object that contains information about input and output
placement of patches. Check `IOSegmentorConfig` for
details about available attributes.
Returns:
tuple:
List of patch inputs and outputs
            - :py:obj:`list` - patch_inputs:
                A list of coordinates in `[start_x, start_y, end_x,
                end_y]` format indicating the read location of the
                patch in the mother image.
            - :py:obj:`list` - patch_outputs:
                A list of coordinates in `[start_x, start_y, end_x,
                end_y]` format indicating the write location of the
                patch in the mother image.
Examples:
>>> # API of function expected to overwrite `get_coordinates`
>>> def func(image_shape, ioconfig):
... patch_inputs = np.array([[0, 0, 256, 256]])
... patch_outputs = np.array([[0, 0, 256, 256]])
... return patch_inputs, patch_outputs
            >>> segmentor = SemanticSegmentor(pretrained_model='unet')
>>> segmentor.get_coordinates = func
"""
(patch_inputs, patch_outputs) = PatchExtractor.get_coordinates(
image_shape=image_shape,
patch_input_shape=ioconfig.patch_input_shape,
patch_output_shape=ioconfig.patch_output_shape,
stride_shape=ioconfig.stride_shape,
)
return patch_inputs, patch_outputs
@staticmethod
def filter_coordinates(
mask_reader: VirtualWSIReader,
bounds: np.ndarray,
resolution: Resolution | None = None,
units: Units | None = None,
) -> np.ndarray:
"""Indicates which coordinate is valid basing on the mask.
To use your own approaches, either subclass to overwrite or
        directly assign your own function to this name. In either case,
the function must obey the API defined here.
Args:
            mask_reader (:class:`.VirtualWSIReader`):
A virtual pyramidal reader of the mask related to the
WSI from which we want to extract the patches.
            bounds (:class:`numpy.ndarray`):
                Integer coordinates to be checked. They must be at the
                same `resolution` and `units` as requested. The shape
                of `bounds` is (N, 4), where N is the number of
                bounding boxes, each in `[start_x, start_y, end_x,
                end_y]` format.
resolution (Resolution):
Resolution of the requested patch.
units (Units):
Units of the requested patch.
Returns:
:class:`numpy.ndarray`:
List of flags to indicate which coordinate is valid.
Examples:
>>> # API of function expected to overwrite `filter_coordinates`
>>> def func(reader, bounds, resolution, units):
... # as example, only select first bound
... return np.array([1, 0])
>>> coords = [[0, 0, 256, 256], [128, 128, 384, 384]]
            >>> segmentor = SemanticSegmentor(pretrained_model='unet')
>>> segmentor.filter_coordinates = func
"""
if not isinstance(mask_reader, VirtualWSIReader):
msg = "`mask_reader` should be VirtualWSIReader."
raise TypeError(msg)
if not isinstance(bounds, np.ndarray) or not np.issubdtype(
bounds.dtype,
np.integer,
):
msg = "`coordinates` should be ndarray of integer type."
raise ValueError(msg)
mask_real_shape = mask_reader.img.shape[:2]
mask_resolution_shape = mask_reader.slide_dimensions(
resolution=resolution,
units=units,
)[::-1]
mask_real_shape = np.array(mask_real_shape)
mask_resolution_shape = np.array(mask_resolution_shape)
scale_factor = mask_real_shape / mask_resolution_shape
scale_factor = scale_factor[0] # what if ratio x != y
def sel_func(coord: np.ndarray) -> bool:
"""Accept coord as long as its box contains part of mask."""
coord_in_real_mask = np.ceil(scale_factor * coord).astype(np.int32)
start_x, start_y, end_x, end_y = coord_in_real_mask
roi = mask_reader.img[start_y:end_y, start_x:end_x]
return np.sum(roi > 0) > 0
flags = [sel_func(bound) for bound in bounds]
return np.array(flags)
@staticmethod
def get_reader(
img_path: str | Path,
mask_path: str | Path,
mode: str,
*,
auto_get_mask: bool,
) -> tuple[WSIReader, WSIReader]:
"""Define how to get reader for mask and source image."""
img_path = Path(img_path)
reader = WSIReader.open(img_path)
mask_reader = None
if mask_path is not None:
mask_path = Path(mask_path)
if not Path.is_file(mask_path):
msg = "`mask_path` must be a valid file path."
raise ValueError(msg)
            mask = imread(mask_path)  # read as RGB; collapsed to gray below
mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
mask = np.array(mask > 0, dtype=np.uint8)
mask_reader = VirtualWSIReader(mask)
mask_reader.info = reader.info
elif auto_get_mask and mode == "wsi" and mask_path is None:
# if no mask provided and `wsi` mode, generate basic tissue
# mask on the fly
mask_reader = reader.tissue_mask(resolution=1.25, units="power")
mask_reader.info = reader.info
return reader, mask_reader
def _predict_one_wsi(
self: SemanticSegmentor,
wsi_idx: int,
ioconfig: IOSegmentorConfig,
save_path: str,
mode: str,
) -> None:
"""Make a prediction on tile/wsi.
Args:
wsi_idx (int):
Index of the tile/wsi to be processed within `self`.
ioconfig (:class:`IOSegmentorConfig`):
Object which defines I/O placement during inference and
when assembling back to full tile/wsi.
save_path (str):
Location to save output prediction as well as possible
intermediate results.
mode (str):
Either `"tile"` or `"wsi"` to indicate run mode.
"""
cache_dir = self._cache_dir / str(wsi_idx)
cache_dir.mkdir(parents=True)
wsi_path = self.imgs[wsi_idx]
mask_path = None if self.masks is None else self.masks[wsi_idx]
wsi_reader, mask_reader = self.get_reader(
wsi_path,
mask_path,
mode,
auto_get_mask=self.auto_generate_mask,
)
# assume ioconfig has already been converted to `baseline` for `tile` mode
resolution = ioconfig.highest_input_resolution
wsi_proc_shape = wsi_reader.slide_dimensions(**resolution)
# * retrieve patch and tile placement
# this is in XY
(patch_inputs, patch_outputs) = self.get_coordinates(wsi_proc_shape, ioconfig)
if mask_reader is not None:
sel = self.filter_coordinates(mask_reader, patch_outputs, **resolution)
patch_outputs = patch_outputs[sel]
patch_inputs = patch_inputs[sel]
# modify the shared space so that we can update worker info
# without needing to re-create the worker. There should be no
# race-condition because only the following enumerate loop
# triggers the parallelism, and this portion is still in
# sequential execution order
patch_inputs = torch.from_numpy(patch_inputs).share_memory_()
patch_outputs = torch.from_numpy(patch_outputs).share_memory_()
self._mp_shared_space.patch_inputs = patch_inputs
self._mp_shared_space.patch_outputs = patch_outputs
self._mp_shared_space.wsi_idx = torch.Tensor([wsi_idx]).share_memory_()
pbar_desc = "Process Batch: "
pbar = tqdm.tqdm(
desc=pbar_desc,
leave=True,
total=int(len(self._loader)),
ncols=80,
ascii=True,
position=0,
)
cum_output = []
for _, batch_data in enumerate(self._loader):
sample_datas, sample_infos = batch_data
batch_size = sample_infos.shape[0]
# ! depending on the protocol of the output within infer_batch
# ! this may change, how to enforce/document/expose this in a
# ! sensible way?
# assume to return a list of L output,
# each of shape N x etc. (N=batch size)
sample_outputs = self.model.infer_batch(
self._model,
sample_datas,
on_gpu=self._on_gpu,
)
# repackage so that it's an N list, each contains
# L x etc. output
sample_outputs = [np.split(v, batch_size, axis=0) for v in sample_outputs]
sample_outputs = list(zip(*sample_outputs))
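            # e.g. with L=2 heads and batch size N=4: [head0(4, ...),
            # head1(4, ...)] becomes [(head0_i, head1_i) for i in range(4)]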
# tensor to numpy, costly?
sample_infos = sample_infos.numpy()
sample_infos = np.split(sample_infos, batch_size, axis=0)
sample_outputs = list(zip(sample_infos, sample_outputs))
if self.process_prediction_per_batch:
self._process_predictions(
sample_outputs,
wsi_reader,
ioconfig,
save_path,
cache_dir,
)
else:
cum_output.extend(sample_outputs)
pbar.update()
pbar.close()
self._process_predictions(
cum_output,
wsi_reader,
ioconfig,
save_path,
cache_dir,
)
# clean up the cache directories
shutil.rmtree(cache_dir)
def _process_predictions(
self: SemanticSegmentor,
cum_batch_predictions: list,
wsi_reader: WSIReader,
ioconfig: IOSegmentorConfig,
save_path: str,
cache_dir: str,
) -> None:
"""Define how the aggregated predictions are processed.
This includes merging the prediction if necessary and also saving afterwards.
Note that items within `cum_batch_predictions` will be consumed during
the operation.
Args:
cum_batch_predictions (list):
List of batch predictions. Each item within the list
should be of (location, patch_predictions).
wsi_reader (:class:`WSIReader`):
A reader for the image where the predictions come from.
ioconfig (:class:`IOSegmentorConfig`):
A configuration object contains input and output
information.
save_path (str):
Root path to save current WSI predictions.
cache_dir (str):
Root path to cache current WSI data.
"""
if len(cum_batch_predictions) == 0:
return
# assume predictions is N, each item has L output element
locations, predictions = list(zip(*cum_batch_predictions))
        # locations is Nx4 (N x [tl_x, tl_y, br_x, br_y]) and denotes the
        # position of each output patch; this can exceed the image bounds
        # at the requested resolution. Remove the singleton due to split.
locations = np.array([v[0] for v in locations])
for index, output_resolution in enumerate(ioconfig.output_resolutions):
# assume resolution index to be in the same order as L
merged_resolution = ioconfig.highest_input_resolution
merged_locations = locations
# ! location is w.r.t the highest resolution, hence still need conversion
if ioconfig.save_resolution is not None:
merged_resolution = ioconfig.save_resolution
output_shape = wsi_reader.slide_dimensions(**output_resolution)
merged_shape = wsi_reader.slide_dimensions(**merged_resolution)
fx = merged_shape[0] / output_shape[0]
merged_locations = np.ceil(locations * fx).astype(np.int64)
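                # e.g. an output read at 0.25 mpp merged onto a 4.0 mpp
                # save canvas gives fx = 1/16, shrinking the locations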
merged_shape = wsi_reader.slide_dimensions(**merged_resolution)
# 0 idx is to remove singleton without removing other axes singleton
to_merge_predictions = [v[index][0] for v in predictions]
sub_save_path = f"{save_path}.raw.{index}.npy"
sub_count_path = f"{cache_dir}/count.{index}.npy"
self.merge_prediction(
merged_shape[::-1], # XY to YX
to_merge_predictions,
merged_locations,
save_path=sub_save_path,
cache_count_path=sub_count_path,
)
@staticmethod
def merge_prediction(
canvas_shape: tuple[int] | list[int] | np.ndarray,
predictions: list[np.ndarray],
locations: list | np.ndarray,
save_path: str | Path | None = None,
cache_count_path: str | Path | None = None,
) -> np.ndarray:
"""Merge patch-level predictions to form a 2-dimensional prediction map.
        When accumulating the raw prediction onto the same canvas (via
        calling the function multiple times), `save_path` and
        `cache_count_path` must be the same. If either of these two
        files does not exist, the function will create new ones.
        However, if `save_path` is `None`, the function will perform
        the accumulation using CPU RAM as storage.
Args:
canvas_shape (:class:`numpy.ndarray`):
HW of the supposed assembled image.
predictions (list):
                List of :class:`np.ndarray`, each item being a patch
                prediction assumed to be of shape HWC.
            locations (list):
                List of :class:`np.ndarray`, each item being the
                location of the patch at the same index within
                `predictions`. The location is in the to-be-assembled
                canvas and of the form `(top_left_x, top_left_y,
                bottom_right_x, bottom_right_y)`.
save_path (str):
Location to save the assembled image.
cache_count_path (str):
                Location to store the canvas used for counting how many
                times each pixel gets overlapped when assembling.
Returns:
:class:`numpy.ndarray`:
                An image containing the merged data.
Examples:
            >>> output = SemanticSegmentor.merge_prediction(
            ...     canvas_shape=[4, 4],
            ...     predictions=[
            ...         np.full((2, 2), 1),
            ...         np.full((2, 2), 2)],
            ...     locations=[
            ...         [0, 0, 2, 2],
            ...         [2, 2, 4, 4]],
            ...     save_path=None,
            ... )
            >>> # output is, up to floating point division error:
            >>> # array([[1, 1, 0, 0],
            >>> #        [1, 1, 0, 0],
            >>> #        [0, 0, 2, 2],
            >>> #        [0, 0, 2, 2]])
"""
canvas_shape = np.array(canvas_shape)
sample_prediction = predictions[0]
if len(sample_prediction.shape) not in (2, 3):
msg = f"Prediction is no HW or HWC: {sample_prediction.shape}."
raise ValueError(msg)
(
canvas_cum_shape_,
canvas_count_shape_,
add_singleton_dim,
) = _estimate_canvas_parameters(sample_prediction, canvas_shape)
is_on_drive, count_canvas, cum_canvas = _prepare_save_output(
save_path,
cache_count_path,
canvas_cum_shape_,
canvas_count_shape_,
)
def index(arr: np.ndarray, tl: np.ndarray, br: np.ndarray) -> np.ndarray:
"""Helper to shorten indexing."""
return arr[tl[0] : br[0], tl[1] : br[1]]
patch_infos = list(zip(locations, predictions))
for _, patch_info in enumerate(patch_infos):
# position is assumed to be in XY coordinate
(bound_in_wsi, prediction) = patch_info
            # convert from XY to YX, and split into tl, br
tl_in_wsi = np.array(bound_in_wsi[:2][::-1])
br_in_wsi = np.array(bound_in_wsi[2:][::-1])
old_tl_in_wsi = tl_in_wsi.copy()
            # shape of the patch in wsi space
            patch_shape_in_wsi = tuple(br_in_wsi - tl_in_wsi)
            # conversion to float32 to make cv2.resize happy
prediction = prediction.astype(np.float32)
prediction = cv2.resize(prediction, patch_shape_in_wsi[::-1])
# ! cv2 resize will remove singleton !
if add_singleton_dim:
prediction = prediction[..., None]
sel = tl_in_wsi < 0
tl_in_wsi[sel] = 0
if np.any(tl_in_wsi >= canvas_shape):
continue
sel = br_in_wsi > canvas_shape
br_in_wsi[sel] = canvas_shape[sel]
            # re-calibrate the position in case the patch passes the image bound
br_in_patch = br_in_wsi - old_tl_in_wsi
patch_actual_shape = br_in_wsi - tl_in_wsi
tl_in_patch = br_in_patch - patch_actual_shape
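            # e.g. a patch spanning y in [-10, 90) clipped to [0, 90):
            # br_in_patch = 100 and patch_actual_shape = 90, so
            # tl_in_patch = 10 skips the 10 clipped rows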
# now cropping the prediction region
patch_pred = prediction[
tl_in_patch[0] : br_in_patch[0],
tl_in_patch[1] : br_in_patch[1],
]
patch_count = np.ones(patch_pred.shape[:2])[..., None]
if not is_on_drive:
index(cum_canvas, tl_in_wsi, br_in_wsi)[:] += patch_pred
index(count_canvas, tl_in_wsi, br_in_wsi)[:] += patch_count
else:
old_avg_pred = np.array(index(cum_canvas, tl_in_wsi, br_in_wsi))
old_count = np.array(index(count_canvas, tl_in_wsi, br_in_wsi))
# ! there will be precision error, but we have to live with this
new_count = old_count + patch_count
                # recover the raw (summed) probabilities from the stored average
old_raw_pred = old_avg_pred * old_count
new_avg_pred = (old_raw_pred + patch_pred) / new_count
index(cum_canvas, tl_in_wsi, br_in_wsi)[:] = new_avg_pred
index(count_canvas, tl_in_wsi, br_in_wsi)[:] = new_count
if not is_on_drive:
cum_canvas /= count_canvas + 1.0e-6
return cum_canvas
@staticmethod
def _prepare_save_dir(save_dir: str | Path | None) -> tuple[Path, Path]:
"""Prepare save directory and cache."""
if save_dir is None:
            logger.warning(
                "No `save_dir` provided. All subsequent output will be "
                "saved to the current runtime location under the folder "
                "'output'. Overwriting may happen!",
                stacklevel=2,
            )
save_dir = Path.cwd() / "output"
save_dir = Path(save_dir).resolve()
if save_dir.is_dir():
msg = f"`save_dir` already exists! {save_dir}"
raise ValueError(msg)
save_dir.mkdir(parents=True)
cache_dir = Path(f"{save_dir}/cache")
Path.mkdir(cache_dir, parents=True)
return save_dir, cache_dir
def _update_ioconfig(
self: SemanticSegmentor,
ioconfig: IOSegmentorConfig,
mode: str,
patch_input_shape: IntPair,
patch_output_shape: IntPair,
stride_shape: IntPair,
resolution: Resolution,
units: Units,
) -> IOSegmentorConfig:
"""Update ioconfig according to input parameters.
Args:
ioconfig (:class:`IOSegmentorConfig`):
Object defines information about input and output
placement of patches. When provided,
`patch_input_shape`, `patch_output_shape`,
`stride_shape`, `resolution`, and `units` arguments are
ignored. Otherwise, those arguments will be internally
converted to a :class:`IOSegmentorConfig` object.
mode (str):
Type of input to process. Choose from either `tile` or
`wsi`.
patch_input_shape (tuple):
Size of patches input to the model. The values
are at requested read resolution and must be positive.
patch_output_shape (tuple):
Size of patches output by the model. The values are at
the requested read resolution and must be positive.
            stride_shape (tuple):
                Stride used during tile and WSI processing. The values
                are at requested read resolution and must be positive.
                If not provided, `stride_shape=patch_output_shape` is
                used.
resolution (Resolution):
Resolution used for reading the image.
units (Units):
Units of resolution used for reading the image.
Returns:
:class:`IOSegmentorConfig`:
Updated ioconfig.
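        Example:
            >>> # a hypothetical sketch; `segmentor` is assumed to be an
            >>> # existing SemanticSegmentor instance
            >>> ioconfig = segmentor._update_ioconfig(
            ...     ioconfig=None,
            ...     mode="wsi",
            ...     patch_input_shape=[1024, 1024],
            ...     patch_output_shape=[512, 512],
            ...     stride_shape=None,  # defaults to patch_output_shape
            ...     resolution=0.5,
            ...     units="mpp",
            ... )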
"""
if patch_output_shape is None:
patch_output_shape = patch_input_shape
if stride_shape is None:
stride_shape = patch_output_shape
if ioconfig is None and patch_input_shape is None:
if self.ioconfig is None:
msg = (
"Must provide either `ioconfig` or `patch_input_shape` "
"and `patch_output_shape`"
)
raise ValueError(
msg,
)
ioconfig = copy.deepcopy(self.ioconfig)
elif ioconfig is None:
ioconfig = IOSegmentorConfig(
input_resolutions=[{"resolution": resolution, "units": units}],
output_resolutions=[{"resolution": resolution, "units": units}],
patch_input_shape=patch_input_shape,
patch_output_shape=patch_output_shape,
stride_shape=stride_shape,
)
if mode == "tile":
logger.warning(
"WSIPatchDataset only reads image tile at "
'`units="baseline"`. Resolutions will be converted '
"to baseline value.",
stacklevel=2,
)
return ioconfig.to_baseline()
return ioconfig
def _prepare_workers(self: SemanticSegmentor) -> None:
"""Prepare number of workers."""
self._postproc_workers = None
if self.num_postproc_workers is not None:
self._postproc_workers = ProcessPoolExecutor(
max_workers=self.num_postproc_workers,
)
def _memory_cleanup(self: SemanticSegmentor) -> None:
"""Memory clean up."""
self.imgs = None
self.masks = None
self._cache_dir = None
self._model = None
self._loader = None
self._on_gpu = None
self._futures = None
self._mp_shared_space = None
if self._postproc_workers is not None:
self._postproc_workers.shutdown()
self._postproc_workers = None
def _predict_wsi_handle_exception(
self: SemanticSegmentor,
imgs: list,
wsi_idx: int,
img_path: str | Path,
mode: str,
ioconfig: IOSegmentorConfig,
save_dir: str | Path,
*,
crash_on_exception: bool,
) -> None:
"""Predict on multiple WSIs.
Args:
imgs (list, ndarray):
List of inputs to process. When using `"patch"` mode,
the input must be either a list of images, a list of
image file paths or a numpy array of an image list. When
using `"tile"` or `"wsi"` mode, the input must be a list
of file paths.
            wsi_idx (int):
                Index of the current WSI being processed.
            img_path (str or Path):
                Path to the current image.
mode (str):
Type of input to process. Choose from either `tile` or
`wsi`.
ioconfig (:class:`IOSegmentorConfig`):
Object defines information about input and output
placement of patches. When provided,
`patch_input_shape`, `patch_output_shape`,
`stride_shape`, `resolution`, and `units` arguments are
ignored. Otherwise, those arguments will be internally
converted to a :class:`IOSegmentorConfig` object.
save_dir (str or Path):
Output directory when processing multiple tiles and
whole-slide images. By default, it is folder `output`
where the running script is invoked.
crash_on_exception (bool):
If `True`, the running loop will crash if there is any
error during processing a WSI. Otherwise, the loop will
move on to the next wsi for processing.
Returns:
list:
A list of tuple(input_path, save_path) where
`input_path` is the path of the input wsi while
`save_path` corresponds to the output predictions.
"""
try:
wsi_save_path = save_dir / f"{wsi_idx}"
self._predict_one_wsi(wsi_idx, ioconfig, str(wsi_save_path), mode)
            # Do not use a dict with the file name as key, because it can
            # be overwritten. It may be the user's intention to provide
            # files with the same name multiple times (maybe they have
            # different root paths)
self._outputs.append([str(img_path), str(wsi_save_path)])
            # ? will this corrupt the old version if Ctrl+C is hit midway?
map_file_path = save_dir / "file_map.dat"
# backup old version first
if Path.exists(map_file_path):
old_map_file_path = save_dir / "file_map_old.dat"
shutil.copy(map_file_path, old_map_file_path)
joblib.dump(self._outputs, map_file_path)
            # log progress
            logging.info("Finish: %d/%d", wsi_idx + 1, len(imgs))
logging.info("--Input: %s", str(img_path))
logging.info("--Output: %s", str(wsi_save_path))
        # a broad exception is caught deliberately: the error is either
        # re-raised or logged, and the loop moves on to the next WSI
except Exception as err: # skipcq: PYL-W0703
wsi_save_path = save_dir.joinpath(f"{wsi_idx}")
if crash_on_exception:
raise err # noqa: TRY201
logging.exception("Crashed on %s", wsi_save_path)
def predict( # noqa: PLR0913
self: SemanticSegmentor,
imgs: list,
masks: list | None = None,
mode: str = "tile",
ioconfig: IOSegmentorConfig = None,
patch_input_shape: IntPair = None,
patch_output_shape: IntPair = None,
stride_shape: IntPair = None,
resolution: Resolution = 1.0,
units: Units = "baseline",
save_dir: str | Path | None = None,
*,
on_gpu: bool = True,
crash_on_exception: bool = False,
) -> list[tuple[Path, Path]]:
"""Make a prediction for a list of input data.
        By default, if the model given at instantiation time is a
        pretrained model from the toolbox and `patch_input_shape`,
        `patch_output_shape`, `stride_shape`, `resolution`, `units` and
        `ioconfig` are all `None`, the method will use the `ioconfig`
        retrieved together with the pretrained model. Otherwise, either
        `ioconfig` or the combination of `patch_input_shape`,
        `patch_output_shape`, `stride_shape`, `resolution` and `units`
        must be set, else a `ValueError` will be raised.
Args:
imgs (list, ndarray):
List of inputs to process. When using `"patch"` mode,
the input must be either a list of images, a list of
image file paths or a numpy array of an image list. When
using `"tile"` or `"wsi"` mode, the input must be a list
of file paths.
masks (list):
List of masks. Only utilised when processing image tiles
and whole-slide images. Patches are only processed if
they are within a masked area. If not provided, then a
tissue mask will be automatically generated for
whole-slide images or the entire image is processed for
image tiles.
mode (str):
Type of input to process. Choose from either `tile` or
`wsi`.
ioconfig (:class:`IOSegmentorConfig`):
Object defines information about input and output
placement of patches. When provided,
`patch_input_shape`, `patch_output_shape`,
`stride_shape`, `resolution`, and `units` arguments are
ignored. Otherwise, those arguments will be internally
converted to a :class:`IOSegmentorConfig` object.
on_gpu (bool):
Whether to run the model on the GPU.
patch_input_shape (tuple):
Size of patches input to the model. The values
are at requested read resolution and must be positive.
patch_output_shape (tuple):
Size of patches output by the model. The values are at
the requested read resolution and must be positive.
            stride_shape (tuple):
                Stride used during tile and WSI processing. The values
                are at requested read resolution and must be positive.
                If not provided, `stride_shape=patch_output_shape` is
                used.
resolution (float):
Resolution used for reading the image.
            units (Units):
                Units of resolution used for reading the image. Choose
                from either `"baseline"`, `"power"` or `"mpp"`.
save_dir (str or pathlib.Path):
Output directory when processing multiple tiles and
whole-slide images. By default, it is folder `output`
where the running script is invoked.
crash_on_exception (bool):
If `True`, the running loop will crash if there is any
error during processing a WSI. Otherwise, the loop will
move on to the next wsi for processing.
Returns:
list:
A list of tuple(input_path, save_path) where
`input_path` is the path of the input wsi while
`save_path` corresponds to the output predictions.
Examples:
>>> # Sample output of a network
>>> wsis = ['A/wsi.svs', 'B/wsi.svs']
            >>> predictor = SemanticSegmentor(pretrained_model='fcn-tissue_mask')
            >>> output = predictor.predict(wsis, mode='wsi')
            >>> output
            [('A/wsi.svs', 'output/0.raw'), ('B/wsi.svs', 'output/1.raw')]
            >>> # if a network has 2 output heads, each head output of 'A/wsi.svs'
            >>> # will be respectively stored in 'output/0.raw.0', 'output/0.raw.1'
"""
if mode not in ["wsi", "tile"]:
msg = f"{mode} is not a valid mode. Use either `tile` or `wsi`."
raise ValueError(msg)
save_dir, self._cache_dir = self._prepare_save_dir(save_dir)
ioconfig = self._update_ioconfig(
ioconfig,
mode,
patch_input_shape,
patch_output_shape,
stride_shape,
resolution,
units,
)
# use external for testing
self._on_gpu = on_gpu
self._model = misc.model_to(model=self.model, on_gpu=on_gpu)
        # max_workers must be > 0 else ProcessPoolExecutor raises a ValueError
self._prepare_workers()
mp_manager = torch_mp.Manager()
mp_shared_space = mp_manager.Namespace()
self._mp_shared_space = mp_shared_space
ds = self.dataset_class(
ioconfig=ioconfig,
preproc=self.model.preproc_func,
wsi_paths=imgs,
mp_shared_space=mp_shared_space,
mode=mode,
)
loader = torch_data.DataLoader(
ds,
drop_last=False,
batch_size=self.batch_size,
num_workers=self.num_loader_workers,
persistent_workers=self.num_loader_workers > 0,
)
self._loader = loader
self.imgs = imgs
self.masks = masks
# contain input / output prediction mapping
self._outputs = []
# ? what will happen if this crash midway?
# => may not be able to retrieve the result dict
for wsi_idx, img_path in enumerate(imgs):
self._predict_wsi_handle_exception(
imgs=imgs,
wsi_idx=wsi_idx,
img_path=img_path,
mode=mode,
ioconfig=ioconfig,
save_dir=save_dir,
crash_on_exception=crash_on_exception,
)
# clean up the cache directories
try:
shutil.rmtree(self._cache_dir)
except PermissionError: # pragma: no cover
logger.warning("Unable to remove %s", self._cache_dir)
self._memory_cleanup()
return self._outputs