Source code for tiatoolbox.models.engine.nucleus_instance_segmentor

"""This module enables nucleus instance segmentation."""

from __future__ import annotations

import uuid
from collections import deque
from typing import Callable

# replace with the SQL database once the PR is in place
import joblib
import numpy as np
import torch
import tqdm
from shapely.geometry import box as shapely_box
from shapely.strtree import STRtree

from tiatoolbox.models.engine.semantic_segmentor import (
    IOSegmentorConfig,
    SemanticSegmentor,
    WSIStreamDataset,
)
from tiatoolbox.tools.patchextraction import PatchExtractor


def _process_instance_predictions(
    inst_dict: dict,
    ioconfig: IOSegmentorConfig,
    tile_shape: list,
    tile_flag: list,
    tile_mode: int,
    tile_tl: tuple,
    ref_inst_dict: dict,
) -> list | tuple:
    """Function to merge new tile prediction with existing prediction.

    Args:
        inst_dict (dict): Dictionary containing instance information.
        ioconfig (:class:`IOSegmentorConfig`): Object defines information
            about input and output placement of patches.
        tile_shape (list): A list of the tile shape.
        tile_flag (list): A list of flag to indicate if instances within
            an area extended from each side (by `ioconfig.margin`) of
            the tile should be replaced by those within the same spatial
            region in the accumulated output this run. The format is
            [top, bottom, left, right], 1 indicates removal while 0 is not.
            For example, [1, 1, 0, 0] denotes replacing top and bottom instances
            within `ref_inst_dict` with new ones after this processing.
        tile_mode (int): A flag to indicate the type of this tile. There
            are 4 flags:
            - 0: A tile from tile grid without any overlapping, it is not
                an overlapping tile from tile generation. The predicted
                instances are immediately added to accumulated output.
            - 1: Vertical tile strip that stands between two normal tiles
                (flag 0). It has the same height as normal tile but
                less width (hence vertical strip).
            - 2: Horizontal tile strip that stands between two normal tiles
                (flag 0). It has the same width as normal tile but
                less height (hence horizontal strip).
            - 3: tile strip stands at the cross-section of four normal tiles
                (flag 0).
        tile_tl (tuple): Top left coordinates of the current tile.
        ref_inst_dict (dict): Dictionary contains accumulated output. The
            expected format is {instance_id: {type: int,
            contour: List[List[int]], centroid:List[float], box:List[int]}.

    Returns:
        new_inst_dict (dict): A dictionary contain new instances to be accumulated.
            The expected format is {instance_id: {type: int,
            contour: List[List[int]], centroid:List[float], box:List[int]}.
        remove_insts_in_orig (list): List of instance id within `ref_inst_dict`
            to be removed to prevent overlapping predictions. These instances
            are those get cutoff at the boundary due to the tiling process.

    """
    # should be rare, no nuclei detected in input images
    if len(inst_dict) == 0:
        return {}, []

    # !
    m = ioconfig.margin
    w, h = tile_shape
    inst_boxes = [v["box"] for v in inst_dict.values()]
    inst_boxes = np.array(inst_boxes)

    geometries = [shapely_box(*bounds) for bounds in inst_boxes]
    tile_rtree = STRtree(geometries)
    # !

    # create margin bounding box, ordering should match with
    # created tile info flag (top, bottom, left, right)
    boundary_lines = [
        shapely_box(0, 0, w, 1),  # top egde
        shapely_box(0, h - 1, w, h),  # bottom edge
        shapely_box(0, 0, 1, h),  # left
        shapely_box(w - 1, 0, w, h),  # right
    ]
    margin_boxes = [
        shapely_box(0, 0, w, m),  # top egde
        shapely_box(0, h - m, w, h),  # bottom edge
        shapely_box(0, 0, m, h),  # left
        shapely_box(w - m, 0, w, h),  # right
    ]
    # ! this is wrt to WSI coord space, not tile
    margin_lines = [
        [[m, m], [w - m, m]],  # top egde
        [[m, h - m], [w - m, h - m]],  # bottom edge
        [[m, m], [m, h - m]],  # left
        [[w - m, m], [w - m, h - m]],  # right
    ]
    margin_lines = np.array(margin_lines) + tile_tl[None, None]
    margin_lines = [shapely_box(*v.flatten().tolist()) for v in margin_lines]

    # the ids within this match with those within `inst_map`, not UUID
    sel_indices = []
    if tile_mode in [0, 3]:
        # for `full grid` tiles `cross section` tiles
        # -- extend from the boundary by the margin size, remove
        #    nuclei whose entire contours lie within the margin area
        sel_boxes = [
            box
            for idx, box in enumerate(margin_boxes)
            if tile_flag[idx] or tile_mode == 3  # noqa: PLR2004
        ]

        sel_indices = [
            geo
            for bounds in sel_boxes
            for geo in tile_rtree.query(bounds)
            if bounds.contains(geometries[geo])
        ]
    elif tile_mode in [1, 2]:
        # for `horizontal/vertical strip` tiles
        # -- extend from the marked edges (top/bot or left/right) by
        #    the margin size, remove all nuclei lie within the margin
        #    area (including on the margin line)
        # -- remove all nuclei on the boundary also

        sel_boxes = [
            margin_boxes[idx] if flag else boundary_lines[idx]
            for idx, flag in enumerate(tile_flag)
        ]

        sel_indices = [geo for bounds in sel_boxes for geo in tile_rtree.query(bounds)]
    else:
        msg = f"Unknown tile mode {tile_mode}."
        raise ValueError(msg)

    def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list:
        """Helper to retrieved selected instance uids."""
        if len(sel_indices) > 0:
            # not sure how costly this is in large dict
            inst_uids = list(inst_dict.keys())
        return [inst_uids[idx] for idx in sel_indices]

    remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict)

    # external removal only for tile at cross-sections
    # this one should contain UUID with the reference database
    remove_insts_in_orig = []
    if tile_mode == 3:  # noqa: PLR2004
        inst_boxes = [v["box"] for v in ref_inst_dict.values()]
        inst_boxes = np.array(inst_boxes)

        geometries = [shapely_box(*bounds) for bounds in inst_boxes]
        ref_inst_rtree = STRtree(geometries)
        sel_indices = [
            geo for bounds in margin_lines for geo in ref_inst_rtree.query(bounds)
        ]

        remove_insts_in_orig = retrieve_sel_uids(sel_indices, ref_inst_dict)

    # move inst position from tile space back to WSI space
    # an also generate universal uid as replacement for storage
    new_inst_dict = {}
    for inst_uid, inst_info in inst_dict.items():
        if inst_uid not in remove_insts_in_tile:
            inst_info["box"] += np.concatenate([tile_tl] * 2)
            inst_info["centroid"] += tile_tl
            inst_info["contour"] += tile_tl
            inst_uuid = uuid.uuid4().hex
            new_inst_dict[inst_uuid] = inst_info
    return new_inst_dict, remove_insts_in_orig


# Python is yet to be able to natively pickle Object method/static
# method. Only top-level function is passable to multiprocessing as
# caller. May need 3rd party libraries to use method/static method
# otherwise.
def _process_tile_predictions(
    ioconfig: IOSegmentorConfig,
    tile_bounds: np.ndarray,
    tile_flag: list,
    tile_mode: int,
    tile_output: list,
    # this would be replaced by annotation store
    # in the future
    ref_inst_dict: dict,
    postproc: Callable,
    merge_predictions: Callable,
) -> tuple[dict, list]:
    """Assemble a tile from patch predictions and merge its instances.

    The patch predictions are stitched into one raw map per output head,
    post-processed into instances, and then filtered/translated by
    :func:`_process_instance_predictions`.

    Args:
        ioconfig (:class:`IOSegmentorConfig`):
            Object defines information about input and output placement
            of patches. It is converted with `to_baseline()` before use.
        tile_bounds (:class:`numpy.array`):
            Boundary of the current tile, defined as `(top_left_x,
            top_left_y, bottom_x, bottom_y)`.
        tile_flag (list):
            A list of flag to indicate if instances within an area
            extended from each side (by `ioconfig.margin`) of the tile
            should be replaced by those within the same spatial region
            in the accumulated output this run. The format is `[top,
            bottom, left, right]`, 1 indicates removal while 0 is not.
            For example, `[1, 1, 0, 0]` denotes replacing top and bottom
            instances within `ref_inst_dict` with new ones after this
            processing.
        tile_mode (int):
            A flag to indicate the type of this tile. There are 4 flags:
            - 0: A tile from tile grid without any overlapping, it is
                 not an overlapping tile from tile generation. The
                 predicted instances are immediately added to
                 accumulated output.
            - 1: Vertical tile strip that stands between two normal
                 tiles (flag 0). It has the same height as normal tile
                 but less width (hence vertical strip).
            - 2: Horizontal tile strip that stands between two normal
                 tiles (flag 0). It has the same width as normal tile
                 but less height (hence horizontal strip).
            - 3: Tile strip stands at the cross-section of four normal
                 tiles (flag 0).
        tile_output (list):
            A list of `(location, prediction)` pairs for the patches
            that lie within this tile, to be merged and processed.
        ref_inst_dict (dict):
            Dictionary contains accumulated output. The expected format
            is `{instance_id: {type: int, contour: List[List[int]],
            centroid:List[float], box:List[int]}`.
        postproc (callable):
            Function to post-process the raw assembled tile. Its second
            return value is used as the instance dictionary.
        merge_predictions (callable):
            Function to merge the `tile_output` into raw tile
            prediction.

    Returns:
        tuple:
            - :py:obj:`dict` - New instances dictionary:
                A dictionary contain new instances to be accumulated.
                The expected format is `{instance_id: {type: int,
                contour: List[List[int]], centroid:List[float],
                box:List[int]}`.
            - :py:obj:`list` - Instances IDs to remove:
                List of instance IDs within `ref_inst_dict` to be
                removed to prevent overlapping predictions. These
                instances are those get cut off at the boundary due to
                the tiling process.

    """
    patch_locations, patch_predictions = list(zip(*tile_output))

    # shift every patch placement from WSI space into tile space
    tile_tl, tile_br = tile_bounds[:2], tile_bounds[2:]
    corner_pairs = [np.reshape(location, (2, -1)) for location in patch_locations]
    locations_in_tile = np.array(
        [(corners - tile_tl[None]).flatten() for corners in corner_pairs],
    )

    tile_shape = tile_br - tile_tl  # in width height

    # Placement is computed at the highest input resolution, so each
    # output head must be rescaled by its own factor before stitching.
    ioconfig = ioconfig.to_baseline()
    scale_factors = [spec["resolution"] for spec in ioconfig.output_resolutions]

    head_raws = []
    for head_idx, scale in enumerate(scale_factors):
        scaled_tile_shape = np.ceil(tile_shape * scale).astype(np.int32)
        scaled_locations = np.ceil(locations_in_tile * scale).astype(np.int32)
        per_patch_head = [prediction[head_idx][0] for prediction in patch_predictions]
        head_raws.append(
            merge_predictions(
                scaled_tile_shape[::-1],
                per_patch_head,
                scaled_locations,
            ),
        )
    _, tile_inst_dict = postproc(head_raws)

    accumulated, stale_ids = _process_instance_predictions(
        tile_inst_dict,
        ioconfig,
        tile_shape,
        tile_flag,
        tile_mode,
        tile_tl,
        ref_inst_dict,
    )
    return accumulated, stale_ids


class NucleusInstanceSegmentor(SemanticSegmentor):
    """An engine specifically designed to handle tiles or WSIs inference.

    Note, if `model` is supplied in the arguments, it will ignore the
    `pretrained_model` and `pretrained_weights` arguments. Additionally,
    unlike `SemanticSegmentor`, this engine assumes each input model
    will ultimately predict one single target: the nucleus instance
    within the tiles/WSIs. Each WSI prediction will be stored under a
    `.dat` file which contains a dictionary of form:

    .. code-block:: yaml

        inst_uid:
            # top left and bottom right of bounding box
            box: (start_x, start_y, end_x, end_y)
            # centroid coordinates
            centroid: (x, y)
            # array/list of points
            contour: [(x1, y1), (x2, y2), ...]
            # the type of nuclei
            type: int
            # the probabilities of being this nuclei type
            prob: float

    Args:
        model (nn.Module):
            Use externally defined PyTorch model for prediction with.
            weights already loaded. Default is `None`. If provided,
            `pretrained_model` argument is ignored.
        pretrained_model (str):
            Name of the existing models support by tiatoolbox for
            processing the data. For a full list of pretrained models,
            refer to the `docs
            <https://tia-toolbox.readthedocs.io/en/latest/pretrained.html>`_.
            By default, the corresponding pretrained weights will also
            be downloaded. However, you can override with your own set
            of weights via the `pretrained_weights` argument. Argument
            is case insensitive.
        pretrained_weights (str):
            Path to the weight of the corresponding `pretrained_model`.
        batch_size (int):
            Number of images fed into the model each time.
        num_loader_workers (int):
            Number of workers to load the data. Take note that they will
            also perform preprocessing.
        num_postproc_workers (int):
            Number of workers to post-process predictions.
        verbose (bool):
            Whether to output logging information.
        dataset_class (obj):
            Dataset class to be used instead of default.
        auto_generate_mask (bool):
            To automatically generate tile/WSI tissue mask if is not
            provided.

    Examples:
        >>> # Sample output of a network
        >>> wsis = ['A/wsi.svs', 'B/wsi.svs']
        >>> predictor = NucleusInstanceSegmentor(
        ...     pretrained_model='hovernet_fast-pannuke',
        ... )
        >>> output = predictor.predict(wsis, mode='wsi')
        >>> list(output.keys())
        [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')]
        >>> # The instances of 'A/wsi.svs' are stored in 'output/0.dat'
        >>> # and those of 'B/wsi.svs' in 'output/1.dat'

    """

    def __init__(
        self: NucleusInstanceSegmentor,
        batch_size: int = 8,
        num_loader_workers: int = 0,
        num_postproc_workers: int = 0,
        model: torch.nn.Module | None = None,
        pretrained_model: str | None = None,
        pretrained_weights: str | None = None,
        dataset_class: Callable = WSIStreamDataset,
        *,
        verbose: bool = True,
        auto_generate_mask: bool = False,
    ) -> None:
        """Initialize :class:`NucleusInstanceSegmentor`."""
        super().__init__(
            batch_size=batch_size,
            num_loader_workers=num_loader_workers,
            num_postproc_workers=num_postproc_workers,
            model=model,
            pretrained_model=pretrained_model,
            pretrained_weights=pretrained_weights,
            verbose=verbose,
            auto_generate_mask=auto_generate_mask,
            dataset_class=dataset_class,
        )
        # default is None in base class and is un-settable
        # hence we redefine the namespace here
        self.num_postproc_workers = (
            num_postproc_workers if num_postproc_workers > 0 else None
        )

        # adding more runtime placeholder
        self._wsi_inst_info = None
        self._futures = []

    @staticmethod
    def _get_tile_info(
        image_shape: list[int] | np.ndarray,
        ioconfig: IOSegmentorConfig,
    ) -> list[list, ...]:
        """Generating tile information.

        To avoid out of memory problem when processing WSI-scale in
        general, the predictor will perform the inference and assemble
        on a large image tiles (each may have size of 4000x4000 compared
        to patch output of 256x256) first before stitching every tiles
        by the end to complete the WSI output. For nuclei instance
        segmentation, the stitching process will require removal of
        predictions within some bounding areas. This function generates
        both the tile placement and the flag to indicate how the removal
        should be done to achieve the above goal.

        Args:
            image_shape (:class:`numpy.ndarray`, list(int)):
                The shape of WSI to extract the tile from, assumed to be
                in `[width, height]`.
            ioconfig (:obj:IOSegmentorConfig):
                The input and output configuration objects.

        Returns:
            list:
                A sequence of `[tiles, flags]` pairs. When the image is
                no larger than one tile, a plain list holding only the
                grid-tile pair is returned; otherwise a
                :class:`collections.deque` with the four pairs below, in
                this order:

                - :py:obj:`list` - Tiles and flags
                    - :class:`numpy.ndarray` - Grid tiles
                    - :class:`numpy.ndarray` - Removal flags
                - :py:obj:`list` - Tiles and flags
                    - :class:`numpy.ndarray` - Vertical strip tiles
                    - :class:`numpy.ndarray` - Removal flags
                - :py:obj:`list` - Tiles and flags
                    - :class:`numpy.ndarray` - Horizontal strip tiles
                    - :class:`numpy.ndarray` - Removal flags
                - :py:obj:`list` - Tiles and flags
                    - :class:`numpy.ndarray` - Cross-section tiles
                    - :class:`numpy.ndarray` - Removal flags

        """
        margin = np.array(ioconfig.margin)
        tile_shape = np.array(ioconfig.tile_shape)
        # round the tile shape down to a multiple of the patch output shape
        tile_shape = (
            np.floor(tile_shape / ioconfig.patch_output_shape)
            * ioconfig.patch_output_shape
        ).astype(np.int32)
        image_shape = np.array(image_shape)
        (_, tile_outputs) = PatchExtractor.get_coordinates(
            image_shape=image_shape,
            patch_input_shape=tile_shape,
            patch_output_shape=tile_shape,
            stride_shape=tile_shape,
        )

        # * === Now generating the flags to indicate which side should
        # * === be removed in postproc callback
        boxes = tile_outputs

        # This saves computation time if the image is smaller than the expected tile
        if np.all(image_shape <= tile_shape):
            flag = np.zeros([boxes.shape[0], 4], dtype=np.int32)
            return [[boxes, flag]]

        w, h = image_shape

        # * remove all sides for boxes
        # unset for those lie within the selection
        def unset_removal_flag(boxes: tuple, removal_flag: np.ndarray) -> np.ndarray:
            """Unset removal flags for tiles intersecting image boundaries."""
            sel_boxes = [
                shapely_box(0, 0, w, 0),  # top edge
                shapely_box(0, h, w, h),  # bottom edge
                shapely_box(0, 0, 0, h),  # left
                shapely_box(w, 0, w, h),  # right
            ]
            geometries = [shapely_box(*bounds) for bounds in boxes]
            spatial_indexer = STRtree(geometries)

            for idx, sel_box in enumerate(sel_boxes):
                sel_indices = list(spatial_indexer.query(sel_box))
                removal_flag[sel_indices, idx] = 0
            return removal_flag

        # expand to full four corners
        boxes_br = boxes[:, 2:]
        boxes_tr = np.dstack([boxes[:, 2], boxes[:, 1]])[0]
        boxes_bl = np.dstack([boxes[:, 0], boxes[:, 3]])[0]

        # * remove edges on all sides, excluding edges at on WSI boundary
        flag = np.ones([boxes.shape[0], 4], dtype=np.int32)
        flag = unset_removal_flag(boxes, flag)
        info = deque([[boxes, flag]])

        # * create vertical boxes at tile boundary and
        # * flag top and bottom removal, excluding those
        # * on the WSI boundary
        # -------------------
        # |    =|=   =|=    |
        # |    =|=   =|=    |
        # |   >=|=  >=|=    |
        # -------------------
        # |   >=|=  >=|=    |
        # |    =|=   =|=    |
        # |   >=|=  >=|=    |
        # -------------------
        # |   >=|=  >=|=    |
        # |    =|=   =|=    |
        # |    =|=   =|=    |
        # -------------------
        # only select boxes having right edges removed
        sel_indices = np.nonzero(flag[..., 3])
        _boxes = np.concatenate(
            [
                boxes_tr[sel_indices] - np.array([margin, 0])[None],
                boxes_br[sel_indices] + np.array([margin, 0])[None],
            ],
            axis=-1,
        )
        _flag = np.full([_boxes.shape[0], 4], 0, dtype=np.int32)
        _flag[:, [0, 1]] = 1
        _flag = unset_removal_flag(_boxes, _flag)
        info.append([_boxes, _flag])

        # * create horizontal boxes at tile boundary and
        # * flag left and right removal, excluding those
        # * on the WSI boundary
        # -------------
        # |   |   |   |
        # |  v|v v|v  |
        # |===|===|===|
        # -------------
        # |===|===|===|
        # |   |   |   |
        # |   |   |   |
        # -------------
        # only select boxes having bottom edges removed
        sel_indices = np.nonzero(flag[..., 1])
        # top bottom left right
        _boxes = np.concatenate(
            [
                boxes_bl[sel_indices] - np.array([0, margin])[None],
                boxes_br[sel_indices] + np.array([0, margin])[None],
            ],
            axis=-1,
        )
        _flag = np.full([_boxes.shape[0], 4], 0, dtype=np.int32)
        _flag[:, [2, 3]] = 1
        _flag = unset_removal_flag(_boxes, _flag)
        info.append([_boxes, _flag])

        # * create boxes at tile cross-section and all sides
        # ------------------------
        # |     |     |    |     |
        # |    v|     |    |     |
        # |  > =|=   =|=  =|=    |
        # -----=-=---=-=---=-=----
        # |    =|=   =|=  =|=    |
        # |     |     |    |     |
        # |    =|=   =|=  =|=    |
        # -----=-=---=-=---=-=----
        # |    =|=   =|=  =|=    |
        # |     |     |    |     |
        # |     |     |    |     |
        # ------------------------
        # only select boxes having both right and bottom edges removed
        sel_indices = np.nonzero(np.prod(flag[:, [1, 3]], axis=-1))
        _boxes = np.concatenate(
            [
                boxes_br[sel_indices] - np.array([2 * margin, 2 * margin])[None],
                boxes_br[sel_indices] + np.array([2 * margin, 2 * margin])[None],
            ],
            axis=-1,
        )
        flag = np.full([_boxes.shape[0], 4], 1, dtype=np.int32)
        info.append([_boxes, flag])

        return info

    def _to_shared_space(
        self: NucleusInstanceSegmentor,
        wsi_idx: int,
        patch_inputs: list,
        patch_outputs: list,
    ) -> None:
        """Helper functions to transfer variable to shared space.

        We modify the shared space so that we can update worker info
        without needing to re-create the worker. There should be no
        race-condition because only by looping `self._loader` in main
        thread will trigger querying new data from each worker, and this
        portion should still be in sequential execution order in the
        main thread.

        Args:
            wsi_idx (int):
                The index of the WSI to be processed. This is used to
                retrieve the file path.
            patch_inputs (list):
                A list of coordinates in `[start_x, start_y, end_x,
                end_y]` format indicating the read location of the patch
                in the WSI image. The coordinates are in the highest
                resolution defined in `self.ioconfig`.
            patch_outputs (list):
                A list of coordinates in `[start_x, start_y, end_x,
                end_y]` format indicating the write location of the
                patch in the WSI image. The coordinates are in the
                highest resolution defined in `self.ioconfig`.

        """
        patch_inputs = torch.from_numpy(patch_inputs).share_memory_()
        patch_outputs = torch.from_numpy(patch_outputs).share_memory_()
        self._mp_shared_space.patch_inputs = patch_inputs
        self._mp_shared_space.patch_outputs = patch_outputs
        self._mp_shared_space.wsi_idx = torch.Tensor([wsi_idx]).share_memory_()

    def _infer_once(self: NucleusInstanceSegmentor) -> list:
        """Running the inference only once for the currently active dataloader.

        Returns:
            list:
                One `(sample_info, sample_output)` pair per sample, where
                `sample_output` holds that sample's slice of every model
                output.

        """
        num_steps = len(self._loader)

        pbar_desc = "Process Batch: "
        cum_output = []
        # the context manager closes the progress bar even if inference fails
        with tqdm.tqdm(
            desc=pbar_desc,
            leave=True,
            total=int(num_steps),
            ncols=80,
            ascii=True,
            position=0,
        ) as pbar:
            for sample_datas, sample_infos in self._loader:
                batch_size = sample_infos.shape[0]
                # ! depending on the protocol of the output within infer_batch
                # ! this may change, how to enforce/document/expose this in a
                # ! sensible way?

                # assume to return a list of L output,
                # each of shape N x etc. (N=batch size)
                sample_outputs = self.model.infer_batch(
                    self._model,
                    sample_datas,
                    on_gpu=self._on_gpu,
                )
                # repackage so that it's a N list, each contains
                # L x etc. output
                sample_outputs = [
                    np.split(v, batch_size, axis=0) for v in sample_outputs
                ]
                sample_outputs = list(zip(*sample_outputs))

                # tensor to numpy, costly?
                sample_infos = sample_infos.numpy()
                sample_infos = np.split(sample_infos, batch_size, axis=0)

                sample_outputs = list(zip(sample_infos, sample_outputs))
                cum_output.extend(sample_outputs)
                pbar.update()
        return cum_output

    def _predict_one_wsi(
        self: NucleusInstanceSegmentor,
        wsi_idx: int,
        ioconfig: IOSegmentorConfig,
        save_path: str,
        mode: str,
    ) -> None:
        """Make a prediction on tile/wsi.

        Args:
            wsi_idx (int):
                Index of the tile/wsi to be processed within `self`.
            ioconfig (IOSegmentorConfig):
                Object which defines I/O placement during inference and
                when assembling back to full tile/wsi.
            save_path (str):
                Location to save output prediction as well as possible
                intermediate results. The instance dictionary is dumped
                to `f"{save_path}.dat"`.
            mode (str):
                `tile` or `wsi` to indicate run mode.

        """
        wsi_path = self.imgs[wsi_idx]
        mask_path = None if self.masks is None else self.masks[wsi_idx]
        wsi_reader, mask_reader = self.get_reader(
            wsi_path,
            mask_path,
            mode,
            auto_get_mask=self.auto_generate_mask,
        )

        # assume ioconfig has already been converted to `baseline` for `tile` mode
        resolution = ioconfig.highest_input_resolution
        wsi_proc_shape = wsi_reader.slide_dimensions(**resolution)

        # * retrieve patch placement
        # this is in XY
        (patch_inputs, patch_outputs) = self.get_coordinates(wsi_proc_shape, ioconfig)
        if mask_reader is not None:
            sel = self.filter_coordinates(mask_reader, patch_outputs, **resolution)
            patch_outputs = patch_outputs[sel]
            patch_inputs = patch_inputs[sel]

        # assume to be in [top_left_x, top_left_y, bot_right_x, bot_right_y]
        geometries = [shapely_box(*bounds) for bounds in patch_outputs]
        spatial_indexer = STRtree(geometries)

        # * retrieve tile placement and tile info flag
        # tile shape will always be corrected to be multiple of output
        tile_info_sets = self._get_tile_info(wsi_proc_shape, ioconfig)

        # ! running order of each set matters !
        self._futures = []

        # ! DEPRECATION:
        # !     will be deprecated upon finalization of SQL annotation store
        self._wsi_inst_info = {}
        # !

        # the position of a set within `tile_info_sets` doubles as its tile mode
        for set_idx, (set_bounds, set_flags) in enumerate(tile_info_sets):
            for tile_bounds, tile_flag in zip(set_bounds, set_flags):
                # select any patches that have their output
                # within the current tile
                sel_box = shapely_box(*tile_bounds)
                sel_indices = list(spatial_indexer.query(sel_box))

                # there is nothing in the tile
                # Ignore coverage as the condition is difficult
                # to reproduce on travis.
                if len(sel_indices) == 0:  # pragma: no cover
                    continue

                tile_patch_inputs = patch_inputs[sel_indices]
                tile_patch_outputs = patch_outputs[sel_indices]
                self._to_shared_space(wsi_idx, tile_patch_inputs, tile_patch_outputs)

                tile_infer_output = self._infer_once()

                self._process_tile_predictions(
                    ioconfig,
                    tile_bounds,
                    tile_flag,
                    set_idx,
                    tile_infer_output,
                )

            # merge after every set so that the next set is post-processed
            # against the instances accumulated so far
            self._merge_post_process_results()
        joblib.dump(self._wsi_inst_info, f"{save_path}.dat")
        # may need to chain it with parents
        self._wsi_inst_info = None  # clean up

    def _process_tile_predictions(
        self: NucleusInstanceSegmentor,
        ioconfig: IOSegmentorConfig,
        tile_bounds: np.ndarray,
        tile_flag: list,
        tile_mode: int,
        tile_output: list,
    ) -> None:
        """Function to dispatch parallel post processing.

        The work is submitted to `self._postproc_workers` when it is
        set; otherwise it is run immediately in this process. Either the
        future or the direct result is appended to `self._futures`.

        Args:
            ioconfig (IOSegmentorConfig):
                Object which defines I/O placement of patches.
            tile_bounds (:class:`numpy.ndarray`):
                Boundary of the current tile.
            tile_flag (list):
                Removal flags of the tile in `[top, bottom, left,
                right]` order.
            tile_mode (int):
                Type of the tile (0 to 3).
            tile_output (list):
                Patch predictions that lie within this tile.

        """
        args = [
            ioconfig,
            tile_bounds,
            tile_flag,
            tile_mode,
            tile_output,
            self._wsi_inst_info,
            self.model.postproc_func,
            self.merge_prediction,
        ]
        # the bare name below resolves to the module-level function,
        # not to this method
        if self._postproc_workers is not None:
            future = self._postproc_workers.submit(_process_tile_predictions, *args)
        else:
            future = _process_tile_predictions(*args)
        self._futures.append(future)

    def _merge_post_process_results(self: NucleusInstanceSegmentor) -> None:
        """Helper to aggregate results from parallel workers.

        Every pending entry of `self._futures` is merged into
        `self._wsi_inst_info`, after which the pending list is emptied so
        that a later call does not merge the same results again.

        Raises:
            Exception: Any exception raised by a post-processing worker
                is re-raised here.

        """

        def callback(new_inst_dict: dict, remove_uuid_list: list) -> None:
            """Helper to aggregate worker's results."""
            # ! DEPRECATION:
            # !     will be deprecated upon finalization of SQL annotation store
            self._wsi_inst_info.update(new_inst_dict)
            for inst_uuid in remove_uuid_list:
                self._wsi_inst_info.pop(inst_uuid, None)
            # !

        for future in self._futures:
            # not actually future but the results
            if self._postproc_workers is None:
                callback(*future)
                continue

            # some errors happen, log it and propagate exception
            # ! this will lead to discard a bunch of
            # ! inferred tiles within this current WSI
            if future.exception() is not None:
                raise future.exception()

            # aggregate the result via callback
            result = future.result()
            # manually call the callback rather than
            # attaching it when receiving/creating the future
            callback(*result)

        # everything above has been merged; drop it so the next merge
        # only handles newly dispatched tiles
        self._futures = []