Source code for tiatoolbox.models.engine.patch_predictor

"""This module implements patch level prediction."""

from __future__ import annotations

import copy
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Callable

import numpy as np
import torch
import tqdm

from tiatoolbox import logger
from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset
from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig
from tiatoolbox.utils import misc, save_as_json
from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader

if TYPE_CHECKING:  # pragma: no cover
    from tiatoolbox.typing import IntPair, Resolution, Units


[docs] class IOPatchPredictorConfig(IOSegmentorConfig): """Contains patch predictor input and output information.""" def __init__( self: IOPatchPredictorConfig, patch_input_shape: IntPair = None, input_resolutions: Resolution = None, stride_shape: IntPair = None, **kwargs: dict, ) -> None: """Initialize :class:`IOPatchPredictorConfig`.""" stride_shape = patch_input_shape if stride_shape is None else stride_shape super().__init__( input_resolutions=input_resolutions, output_resolutions=[], stride_shape=stride_shape, patch_input_shape=patch_input_shape, patch_output_shape=patch_input_shape, save_resolution=None, **kwargs, )
[docs] class PatchPredictor: r"""Patch level predictor. The models provided by tiatoolbox should give the following results: .. list-table:: PatchPredictor performance on the Kather100K dataset [1] :widths: 15 15 :header-rows: 1 * - Model name - F\ :sub:`1`\ score * - alexnet-kather100k - 0.965 * - resnet18-kather100k - 0.990 * - resnet34-kather100k - 0.991 * - resnet50-kather100k - 0.989 * - resnet101-kather100k - 0.989 * - resnext50_32x4d-kather100k - 0.992 * - resnext101_32x8d-kather100k - 0.991 * - wide_resnet50_2-kather100k - 0.989 * - wide_resnet101_2-kather100k - 0.990 * - densenet121-kather100k - 0.993 * - densenet161-kather100k - 0.992 * - densenet169-kather100k - 0.992 * - densenet201-kather100k - 0.991 * - mobilenet_v2-kather100k - 0.990 * - mobilenet_v3_large-kather100k - 0.991 * - mobilenet_v3_small-kather100k - 0.992 * - googlenet-kather100k - 0.992 .. list-table:: PatchPredictor performance on the PCam dataset [2] :widths: 15 15 :header-rows: 1 * - Model name - F\ :sub:`1`\ score * - alexnet-pcam - 0.840 * - resnet18-pcam - 0.888 * - resnet34-pcam - 0.889 * - resnet50-pcam - 0.892 * - resnet101-pcam - 0.888 * - resnext50_32x4d-pcam - 0.900 * - resnext101_32x8d-pcam - 0.892 * - wide_resnet50_2-pcam - 0.901 * - wide_resnet101_2-pcam - 0.898 * - densenet121-pcam - 0.897 * - densenet161-pcam - 0.893 * - densenet169-pcam - 0.895 * - densenet201-pcam - 0.891 * - mobilenet_v2-pcam - 0.899 * - mobilenet_v3_large-pcam - 0.895 * - mobilenet_v3_small-pcam - 0.890 * - googlenet-pcam - 0.867 Args: model (nn.Module): Use externally defined PyTorch model for prediction with. weights already loaded. Default is `None`. If provided, `pretrained_model` argument is ignored. pretrained_model (str): Name of the existing models support by tiatoolbox for processing the data. For a full list of pretrained models, refer to the `docs <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case-insensitive. pretrained_weights (str): Path to the weight of the corresponding `pretrained_model`. >>> predictor = PatchPredictor( ... pretrained_model="resnet18-kather100k", ... pretrained_weights="resnet18_local_weight") batch_size (int): Number of images fed into the model each time. num_loader_workers (int): Number of workers to load the data. Take note that they will also perform preprocessing. verbose (bool): Whether to output logging information. Attributes: img (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): A HWC image or a path to WSI. mode (str): Type of input to process. Choose from either `patch`, `tile` or `wsi`. model (nn.Module): Defined PyTorch model. pretrained_model (str): Name of the existing models support by tiatoolbox for processing the data. For a full list of pretrained models, refer to the `docs <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case insensitive. batch_size (int): Number of images fed into the model each time. num_loader_worker (int): Number of workers used in torch.utils.data.DataLoader. verbose (bool): Whether to output logging information. Examples: >>> # list of 2 image patches as input >>> data = [img1, img2] >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") >>> output = predictor.predict(data, mode='patch') >>> # array of list of 2 image patches as input >>> data = np.array([img1, img2]) >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") >>> output = predictor.predict(data, mode='patch') >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") >>> output = predictor.predict(data, mode='patch') >>> # list of 2 image tile files as input >>> tile_file = ['path/tile1.png', 'path/tile2.png'] >>> predictor = PatchPredictor(pretraind_model="resnet18-kather100k") >>> output = predictor.predict(tile_file, mode='tile') >>> # list of 2 wsi files as input >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] >>> predictor = PatchPredictor(pretraind_model="resnet18-kather100k") >>> output = predictor.predict(wsi_file, mode='wsi') References: [1] Kather, Jakob Nikolas, et al. "Predicting survival from colorectal cancer histology slides using deep learning: A retrospective multicenter study." PLoS medicine 16.1 (2019): e1002730. [2] Veeling, Bastiaan S., et al. "Rotation equivariant CNNs for digital pathology." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2018. """ def __init__( self: PatchPredictor, batch_size: int = 8, num_loader_workers: int = 0, model: torch.nn.Module = None, pretrained_model: str | None = None, pretrained_weights: str | None = None, *, verbose: bool = True, ) -> None: """Initialize :class:`PatchPredictor`.""" super().__init__() self.imgs = None self.mode = None if model is None and pretrained_model is None: msg = "Must provide either `model` or `pretrained_model`." raise ValueError(msg) if model is not None: self.model = model ioconfig = None # retrieve iostate from provided model ? else: model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) self.ioconfig = ioconfig # for storing original self._ioconfig = None # for storing runtime self.model = model # for runtime, such as after wrapping with nn.DataParallel self.pretrained_model = pretrained_model self.batch_size = batch_size self.num_loader_worker = num_loader_workers self.verbose = verbose
[docs] @staticmethod def merge_predictions( img: str | Path | np.ndarray, output: dict, resolution: Resolution | None = None, units: Units | None = None, postproc_func: Callable | None = None, *, return_raw: bool = False, ) -> np.ndarray: """Merge patch level predictions to form a 2-dimensional prediction map. #! Improve how the below reads. The prediction map will contain values from 0 to N, where N is the number of classes. Here, 0 is the background which has not been processed by the model and N is the number of classes predicted by the model. Args: img (:obj:`str` or :obj:`pathlib.Path` or :class:`numpy.ndarray`): A HWC image or a path to WSI. output (dict): Output generated by the model. resolution (Resolution): Resolution of merged predictions. units (Units): Units of resolution used when merging predictions. This must be the same `units` used when processing the data. postproc_func (callable): A function to post-process raw prediction from model. By default, internal code uses the `np.argmax` function. return_raw (bool): Return raw result without applying the `postproc_func` on the assembled image. Returns: :class:`numpy.ndarray`: Merged predictions as a 2D array. Examples: >>> # pseudo output dict from model with 2 patches >>> output = { ... 'resolution': 1.0, ... 'units': 'baseline', ... 'probabilities': [[0.45, 0.55], [0.90, 0.10]], ... 'predictions': [1, 0], ... 'coordinates': [[0, 0, 2, 2], [2, 2, 4, 4]], ... } >>> merged = PatchPredictor.merge_predictions( ... np.zeros([4, 4]), ... output, ... resolution=1.0, ... units='baseline' ... ) >>> merged ... array([[2, 2, 0, 0], ... [2, 2, 0, 0], ... [0, 0, 1, 1], ... [0, 0, 1, 1]]) """ reader = WSIReader.open(img) if isinstance(reader, VirtualWSIReader): logger.warning( "Image is not pyramidal hence read is forced to be " "at `units='baseline'` and `resolution=1.0`.", stacklevel=2, ) resolution = 1.0 units = "baseline" canvas_shape = reader.slide_dimensions(resolution=resolution, units=units) canvas_shape = canvas_shape[::-1] # XY to YX # may crash here, do we need to deal with this ? output_shape = reader.slide_dimensions( resolution=output["resolution"], units=output["units"], ) output_shape = output_shape[::-1] # XY to YX fx = np.array(canvas_shape) / np.array(output_shape) if "probabilities" not in output: coordinates = output["coordinates"] predictions = output["predictions"] denominator = None output = np.zeros(list(canvas_shape), dtype=np.float32) else: coordinates = output["coordinates"] predictions = output["probabilities"] num_class = np.array(predictions[0]).shape[0] denominator = np.zeros(canvas_shape) output = np.zeros([*list(canvas_shape), num_class], dtype=np.float32) for idx, bound in enumerate(coordinates): prediction = predictions[idx] # assumed to be in XY # top-left for output placement tl = np.ceil(np.array(bound[:2]) * fx).astype(np.int32) # bot-right for output placement br = np.ceil(np.array(bound[2:]) * fx).astype(np.int32) output[tl[1] : br[1], tl[0] : br[0]] += prediction if denominator is not None: denominator[tl[1] : br[1], tl[0] : br[0]] += 1 # deal with overlapping regions if denominator is not None: output = output / (np.expand_dims(denominator, -1) + 1.0e-8) if not return_raw: # convert raw probabilities to predictions if postproc_func is not None: output = postproc_func(output) else: output = np.argmax(output, axis=-1) # to make sure background is 0 while class will be 1...N output[denominator > 0] += 1 return output
def _predict_engine( self: PatchPredictor, dataset: torch.utils.data.Dataset, *, return_probabilities: bool = False, return_labels: bool = False, return_coordinates: bool = False, on_gpu: bool = True, ) -> np.ndarray: """Make a prediction on a dataset. The dataset may be mutated. Args: dataset (torch.utils.data.Dataset): PyTorch dataset object created using `tiatoolbox.models.data.classification.Patch_Dataset`. return_probabilities (bool): Whether to return per-class probabilities. return_labels (bool): Whether to return labels. return_coordinates (bool): Whether to return patch coordinates. on_gpu (bool): Whether to run model on the GPU. Returns: :class:`numpy.ndarray`: Model predictions of the input dataset """ dataset.preproc_func = self.model.preproc_func # preprocessing must be defined with the dataset dataloader = torch.utils.data.DataLoader( dataset, num_workers=self.num_loader_worker, batch_size=self.batch_size, drop_last=False, shuffle=False, ) if self.verbose: pbar = tqdm.tqdm( total=int(len(dataloader)), leave=True, ncols=80, ascii=True, position=0, ) # use external for testing model = misc.model_to(model=self.model, on_gpu=on_gpu) cum_output = { "probabilities": [], "predictions": [], "coordinates": [], "labels": [], } for _, batch_data in enumerate(dataloader): batch_output_probabilities = self.model.infer_batch( model, batch_data["image"], on_gpu=on_gpu, ) # We get the index of the class with the maximum probability batch_output_predictions = self.model.postproc_func( batch_output_probabilities, ) # tolist might be very expensive cum_output["probabilities"].extend(batch_output_probabilities.tolist()) cum_output["predictions"].extend(batch_output_predictions.tolist()) if return_coordinates: cum_output["coordinates"].extend(batch_data["coords"].tolist()) if return_labels: # be careful of `s` # We do not use tolist here because label may be of mixed types # and hence collated as list by torch cum_output["labels"].extend(list(batch_data["label"])) if self.verbose: pbar.update() if self.verbose: pbar.close() if not return_probabilities: cum_output.pop("probabilities") if not return_labels: cum_output.pop("labels") if not return_coordinates: cum_output.pop("coordinates") return cum_output def _update_ioconfig( self: PatchPredictor, ioconfig: IOPatchPredictorConfig, patch_input_shape: IntPair, stride_shape: IntPair, resolution: Resolution, units: Units, ) -> IOPatchPredictorConfig: """Updates the ioconfig. Args: ioconfig (IOPatchPredictorConfig): Input ioconfig for PatchPredictor. patch_input_shape (IntPair): Size of patches input to the model. Patches are at requested read resolution, not with respect to level 0, and must be positive. stride_shape (IntPair): Stride using during tile and WSI processing. Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, `stride_shape=patch_input_shape`. resolution (Resolution): Resolution used for reading the image. Please see :obj:`WSIReader` for details. units (Units): Units of resolution used for reading the image. Returns: IOPatchPredictorConfig: Updated Patch Predictor IO configuration. """ config_flag = ( patch_input_shape is None, resolution is None, units is None, ) if ioconfig: return ioconfig if self.ioconfig is None and any(config_flag): msg = ( "Must provide either " "`ioconfig` or `patch_input_shape`, `resolution`, and `units`." ) raise ValueError( msg, ) if stride_shape is None: stride_shape = patch_input_shape if self.ioconfig: ioconfig = copy.deepcopy(self.ioconfig) # ! not sure if there is a nicer way to set this if patch_input_shape is not None: ioconfig.patch_input_shape = patch_input_shape if stride_shape is not None: ioconfig.stride_shape = stride_shape if resolution is not None: ioconfig.input_resolutions[0]["resolution"] = resolution if units is not None: ioconfig.input_resolutions[0]["units"] = units return ioconfig return IOPatchPredictorConfig( input_resolutions=[{"resolution": resolution, "units": units}], patch_input_shape=patch_input_shape, stride_shape=stride_shape, ) @staticmethod def _prepare_save_dir(save_dir: str | Path, imgs: list | np.ndarray) -> Path: """Create directory if not defined and number of images is more than 1. Args: save_dir (str or Path): Path to output directory. imgs (list, ndarray): List of inputs to process. Returns: :class:`Path`: Path to output directory. """ if save_dir is None and len(imgs) > 1: logger.warning( "More than 1 WSIs detected but there is no save directory set." "All subsequent output will be saved to current runtime" "location under folder 'output'. Overwriting may happen!", stacklevel=2, ) save_dir = Path.cwd() / "output" elif save_dir is not None and len(imgs) > 1: logger.warning( "When providing multiple whole-slide images / tiles, " "we save the outputs and return the locations " "to the corresponding files.", stacklevel=2, ) if save_dir is not None: save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=False) return save_dir def _predict_patch( self: PatchPredictor, imgs: list | np.ndarray, labels: list, *, return_probabilities: bool, return_labels: bool, on_gpu: bool, ) -> np.ndarray: """Process patch mode. Args: imgs (list, ndarray): List of inputs to process. when using `patch` mode, the input must be either a list of images, a list of image file paths or a numpy array of an image list. When using `tile` or `wsi` mode, the input must be a list of file paths. labels (list): List of labels. If using `tile` or `wsi` mode, then only a single label per image tile or whole-slide image is supported. return_probabilities (bool): Whether to return per-class probabilities. return_labels (bool): Whether to return the labels with the predictions. on_gpu (bool): Whether to run model on the GPU. Returns: :class:`numpy.ndarray`: Model predictions of the input dataset """ if labels: # if a labels is provided, then return with the prediction return_labels = bool(labels) if labels and len(labels) != len(imgs): msg = f"len(labels) != len(imgs) : {len(labels)} != {len(imgs)}" raise ValueError( msg, ) # don't return coordinates if patches are already extracted return_coordinates = False dataset = PatchDataset(imgs, labels) return self._predict_engine( dataset, return_probabilities=return_probabilities, return_labels=return_labels, return_coordinates=return_coordinates, on_gpu=on_gpu, ) def _predict_tile_wsi( # noqa: PLR0913 self: PatchPredictor, imgs: list, masks: list | None, labels: list, mode: str, ioconfig: IOPatchPredictorConfig, save_dir: str | Path, highest_input_resolution: list[dict], *, save_output: bool, return_probabilities: bool, merge_predictions: bool, on_gpu: bool, ) -> list | dict: """Predict on Tile and WSIs. Args: imgs (list, ndarray): List of inputs to process. when using `patch` mode, the input must be either a list of images, a list of image file paths or a numpy array of an image list. When using `tile` or `wsi` mode, the input must be a list of file paths. masks (list): List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if they are within a masked area. If not provided, then a tissue mask will be automatically generated for whole-slide images or the entire image is processed for image tiles. labels (list): List of labels. If using `tile` or `wsi` mode, then only a single label per image tile or whole-slide image is supported. mode (str): Type of input to process. Choose from either `patch`, `tile` or `wsi`. return_probabilities (bool): Whether to return per-class probabilities. on_gpu (bool): Whether to run model on the GPU. ioconfig (IOPatchPredictorConfig): Patch Predictor IO configuration.. merge_predictions (bool): Whether to merge the predictions to form a 2-dimensional map. This is only applicable for `mode='wsi'` or `mode='tile'`. save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` where the running script is invoked. save_output (bool): Whether to save output for a single file. default=False highest_input_resolution (list(dict)): Highest available input resolution. Returns: dict: Results are saved to `save_dir` and a dictionary indicating save location for each input is returned. The dict is in the following format: - img_path: path of the input image. - raw: path to save location for raw prediction, saved in .json. - merged: path to .npy contain merged predictions if `merge_predictions` is `True`. """ # return coordinates of patches processed within a tile / whole-slide image return_coordinates = True input_is_path_like = isinstance(imgs[0], (str, Path)) default_save_dir = ( imgs[0].parent / "output" if input_is_path_like else Path.cwd() ) save_dir = default_save_dir if save_dir is None else Path(save_dir) # None if no output outputs = None self._ioconfig = ioconfig # generate a list of output file paths if number of input images > 1 file_dict = OrderedDict() if len(imgs) > 1: save_output = True for idx, img_path in enumerate(imgs): img_path_ = Path(img_path) img_label = None if labels is None else labels[idx] img_mask = None if masks is None else masks[idx] dataset = WSIPatchDataset( img_path_, mode=mode, mask_path=img_mask, patch_input_shape=ioconfig.patch_input_shape, stride_shape=ioconfig.stride_shape, resolution=ioconfig.input_resolutions[0]["resolution"], units=ioconfig.input_resolutions[0]["units"], ) output_model = self._predict_engine( dataset, return_labels=False, return_probabilities=return_probabilities, return_coordinates=return_coordinates, on_gpu=on_gpu, ) output_model["label"] = img_label # add extra information useful for downstream analysis output_model["pretrained_model"] = self.pretrained_model output_model["resolution"] = highest_input_resolution["resolution"] output_model["units"] = highest_input_resolution["units"] outputs = [output_model] # assign to a list merged_prediction = None if merge_predictions: merged_prediction = self.merge_predictions( img_path_, output_model, resolution=output_model["resolution"], units=output_model["units"], postproc_func=self.model.postproc, ) outputs.append(merged_prediction) if save_output: # dynamic 0 padding img_code = f"{idx:0{len(str(len(imgs)))}d}" save_info = {} save_path = save_dir / img_code raw_save_path = f"{save_path}.raw.json" save_info["raw"] = raw_save_path save_as_json(output_model, raw_save_path) if merge_predictions: merged_file_path = f"{save_path}.merged.npy" np.save(merged_file_path, merged_prediction) save_info["merged"] = merged_file_path file_dict[str(img_path_)] = save_info return file_dict if save_output else outputs
[docs] def predict( # noqa: PLR0913 self: PatchPredictor, imgs: list, masks: list | None = None, labels: list | None = None, mode: str = "patch", ioconfig: IOPatchPredictorConfig | None = None, patch_input_shape: tuple[int, int] | None = None, stride_shape: tuple[int, int] | None = None, resolution: Resolution | None = None, units: Units = None, *, return_probabilities: bool = False, return_labels: bool = False, on_gpu: bool = True, merge_predictions: bool = False, save_dir: str | Path | None = None, save_output: bool = False, ) -> np.ndarray | list | dict: """Make a prediction for a list of input data. Args: imgs (list, ndarray): List of inputs to process. when using `patch` mode, the input must be either a list of images, a list of image file paths or a numpy array of an image list. When using `tile` or `wsi` mode, the input must be a list of file paths. masks (list): List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if they are within a masked area. If not provided, then a tissue mask will be automatically generated for whole-slide images or the entire image is processed for image tiles. labels: List of labels. If using `tile` or `wsi` mode, then only a single label per image tile or whole-slide image is supported. mode (str): Type of input to process. Choose from either `patch`, `tile` or `wsi`. return_probabilities (bool): Whether to return per-class probabilities. return_labels (bool): Whether to return the labels with the predictions. on_gpu (bool): Whether to run model on the GPU. ioconfig (IOPatchPredictorConfig): Patch Predictor IO configuration. patch_input_shape (tuple): Size of patches input to the model. Patches are at requested read resolution, not with respect to level 0, and must be positive. stride_shape (tuple): Stride using during tile and WSI processing. Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, `stride_shape=patch_input_shape`. resolution (Resolution): Resolution used for reading the image. Please see :obj:`WSIReader` for details. units (Units): Units of resolution used for reading the image. Choose from either `level`, `power` or `mpp`. Please see :obj:`WSIReader` for details. merge_predictions (bool): Whether to merge the predictions to form a 2-dimensional map. This is only applicable for `mode='wsi'` or `mode='tile'`. save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` where the running script is invoked. save_output (bool): Whether to save output for a single file. default=False Returns: (:class:`numpy.ndarray` or list or dict): Model predictions of the input dataset. If multiple image tiles or whole-slide images are provided as input, or save_output is True, then results are saved to `save_dir` and a dictionary indicating save location for each input is returned. The dict has the following format: - img_path: path of the input image. - raw: path to save location for raw prediction, saved in .json. - merged: path to .npy contain merged predictions if `merge_predictions` is `True`. Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] >>> predictor = PatchPredictor( ... pretrained_model="resnet18-kather100k") >>> output = predictor.predict(wsis, mode="wsi") >>> output.keys() ... ['wsi1.svs', 'wsi2.svs'] >>> output['wsi1.svs'] ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} >>> output['wsi2.svs'] ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} """ if mode not in ["patch", "wsi", "tile"]: msg = f"{mode} is not a valid mode. Use either `patch`, `tile` or `wsi`" raise ValueError( msg, ) if mode == "patch": return self._predict_patch( imgs, labels, return_probabilities=return_probabilities, return_labels=return_labels, on_gpu=on_gpu, ) if not isinstance(imgs, list): msg = "Input to `tile` and `wsi` mode must be a list of file paths." raise TypeError( msg, ) if mode == "wsi" and masks is not None and len(masks) != len(imgs): msg = f"len(masks) != len(imgs) : {len(masks)} != {len(imgs)}" raise ValueError( msg, ) ioconfig = self._update_ioconfig( ioconfig, patch_input_shape, stride_shape, resolution, units, ) if mode == "tile": logger.warning( "WSIPatchDataset only reads image tile at " '`units="baseline"`. Resolutions will be converted ' "to baseline value.", stacklevel=2, ) ioconfig = ioconfig.to_baseline() fx_list = ioconfig.scale_to_highest( ioconfig.input_resolutions, ioconfig.input_resolutions[0]["units"], ) fx_list = zip(fx_list, ioconfig.input_resolutions) fx_list = sorted(fx_list, key=lambda x: x[0]) highest_input_resolution = fx_list[0][1] save_dir = self._prepare_save_dir(save_dir, imgs) return self._predict_tile_wsi( imgs=imgs, masks=masks, labels=labels, mode=mode, return_probabilities=return_probabilities, on_gpu=on_gpu, ioconfig=ioconfig, merge_predictions=merge_predictions, save_dir=save_dir, save_output=save_output, highest_input_resolution=highest_input_resolution, )