Source code for tiatoolbox.models.architecture.hovernet

"""Define HoVerNet architecture."""

from __future__ import annotations

import math
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
from scipy import ndimage
from skimage.morphology import remove_small_objects
from skimage.segmentation import watershed
from torch import nn

from tiatoolbox.models.architecture.utils import (
    UpSample2x,
    centre_crop,
    centre_crop_to_shape,
)
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils import misc
from tiatoolbox.utils.misc import get_bounding_box


class TFSamepaddingLayer(nn.Module):
    """To align with tensorflow `same` padding.

    Put this before any conv layer that needs padding. Here, we assume
    the kernel has the same height and width for simplicity.

    """

    def __init__(self: TFSamepaddingLayer, ksize: int, stride: int) -> None:
        """Initialize :class:`TFSamepaddingLayer`."""
        super().__init__()
        self.ksize = ksize
        self.stride = stride
    def forward(self: TFSamepaddingLayer, x: torch.Tensor) -> torch.Tensor:
        """Logic for using layers defined in init."""
        if x.shape[2] % self.stride == 0:
            pad = max(self.ksize - self.stride, 0)
        else:
            pad = max(self.ksize - (x.shape[2] % self.stride), 0)

        if pad % 2 == 0:
            pad_val = pad // 2
            padding = (pad_val, pad_val, pad_val, pad_val)
        else:
            pad_val_start = pad // 2
            pad_val_end = pad - pad_val_start
            padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end)

        return F.pad(x, padding, "constant", 0)
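
# A minimal usage sketch for TFSamepaddingLayer, assuming only the module-level
# imports above. Placed before a `valid` (padding=0) convolution, the layer
# reproduces tensorflow `same` behaviour, so a stride-1 convolution preserves
# the spatial size. The shapes below follow from the padding arithmetic in
# `forward` and are illustrative only.
#
#   >>> pad = TFSamepaddingLayer(ksize=7, stride=1)
#   >>> conv = nn.Conv2d(3, 8, kernel_size=7, stride=1, padding=0)
#   >>> x = torch.rand(1, 3, 164, 164)
#   >>> conv(pad(x)).shape  # padded to 170x170, then 7x7 valid conv -> 164x164
#   torch.Size([1, 8, 164, 164])
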
class DenseBlock(nn.Module):
    """Dense Convolutional Block.

    This convolutional block supports only `valid` padding.

    References:
        Huang, Gao, et al. "Densely connected convolutional networks."
        Proceedings of the IEEE conference on computer vision and
        pattern recognition. 2017.

    """

    def __init__(
        self: DenseBlock,
        in_ch: int,
        unit_ksizes: list[int],
        unit_chs: list[int],
        unit_count: int,
        split: int = 1,
    ) -> None:
        """Initialize :class:`DenseBlock`."""
        super().__init__()
        if len(unit_ksizes) != len(unit_chs):
            msg = "Unbalance Unit Info."
            raise ValueError(msg)

        self.nr_unit = unit_count
        self.in_ch = in_ch

        # weights value may not match with tensorflow version
        # due to different default initialization scheme between
        # torch and tensorflow
        def get_unit_block(unit_in_ch: int) -> nn.Sequential:
            """Helper function to make it less long."""
            layers = OrderedDict(
                [
                    ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("preact_bna/relu", nn.ReLU(inplace=True)),
                    (
                        "conv1",
                        nn.Conv2d(
                            unit_in_ch,
                            unit_chs[0],
                            unit_ksizes[0],
                            stride=1,
                            padding=0,
                            bias=False,
                        ),
                    ),
                    ("conv1/bn", nn.BatchNorm2d(unit_chs[0], eps=1e-5)),
                    ("conv1/relu", nn.ReLU(inplace=True)),
                    (
                        "conv2",
                        nn.Conv2d(
                            unit_chs[0],
                            unit_chs[1],
                            unit_ksizes[1],
                            groups=split,
                            stride=1,
                            padding=0,
                            bias=False,
                        ),
                    ),
                ],
            )
            return nn.Sequential(layers)

        unit_in_ch = in_ch
        self.units = nn.ModuleList()
        for _ in range(unit_count):
            self.units.append(get_unit_block(unit_in_ch))
            unit_in_ch += unit_chs[1]

        self.blk_bna = nn.Sequential(
            OrderedDict(
                [
                    ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("relu", nn.ReLU(inplace=True)),
                ],
            ),
        )
    def forward(self: DenseBlock, prev_feat: torch.Tensor) -> torch.Tensor:
        """Logic for using layers defined in init."""
        for idx in range(self.nr_unit):
            new_feat = self.units[idx](prev_feat)
            prev_feat = centre_crop_to_shape(prev_feat, new_feat)
            prev_feat = torch.cat([prev_feat, new_feat], dim=1)
        return self.blk_bna(prev_feat)
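
# A minimal usage sketch for DenseBlock, assuming only the module-level imports
# above. Because only `valid` padding is used, each unit trims ksize - 1 pixels
# from the border, while concatenation grows the channels by unit_chs[1] per
# unit. The configuration mirrors the first dense stage in
# `_create_decoder_branch` (ksize=3, as in `fast` mode); the 80x80 input size
# is arbitrary and illustrative only.
#
#   >>> blk = DenseBlock(256, [1, 3], [128, 32], unit_count=8, split=4)
#   >>> feat = torch.rand(1, 256, 80, 80)
#   >>> blk(feat).shape  # 256 + 8 * 32 = 512 channels; 80 - 8 * 2 = 64 pixels
#   torch.Size([1, 512, 64, 64])
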
class ResidualBlock(nn.Module):
    """Residual block.

    References:
        He, Kaiming, et al. "Deep residual learning for image
        recognition." Proceedings of the IEEE conference on computer
        vision and pattern recognition. 2016.

    """

    def __init__(
        self: ResidualBlock,
        in_ch: int,
        unit_ksizes: list[int],
        unit_chs: list[int],
        unit_count: int,
        stride: int = 1,
    ) -> None:
        """Initialize :class:`ResidualBlock`."""
        super().__init__()
        if len(unit_ksizes) != len(unit_chs):
            msg = "Unbalance Unit Info."
            raise ValueError(msg)

        self.nr_unit = unit_count
        self.in_ch = in_ch

        # ! For inference only so init values for batch norm may not match tensorflow
        unit_in_ch = in_ch
        self.units = nn.ModuleList()
        for idx in range(unit_count):
            unit_layer = [
                ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                ("preact/relu", nn.ReLU(inplace=True)),
                (
                    "conv1",
                    nn.Conv2d(
                        unit_in_ch,
                        unit_chs[0],
                        unit_ksizes[0],
                        stride=1,
                        padding=0,
                        bias=False,
                    ),
                ),
                ("conv1/bn", nn.BatchNorm2d(unit_chs[0], eps=1e-5)),
                ("conv1/relu", nn.ReLU(inplace=True)),
                (
                    "conv2/pad",
                    TFSamepaddingLayer(
                        ksize=unit_ksizes[1],
                        stride=stride if idx == 0 else 1,
                    ),
                ),
                (
                    "conv2",
                    nn.Conv2d(
                        unit_chs[0],
                        unit_chs[1],
                        unit_ksizes[1],
                        stride=stride if idx == 0 else 1,
                        padding=0,
                        bias=False,
                    ),
                ),
                ("conv2/bn", nn.BatchNorm2d(unit_chs[1], eps=1e-5)),
                ("conv2/relu", nn.ReLU(inplace=True)),
                (
                    "conv3",
                    nn.Conv2d(
                        unit_chs[1],
                        unit_chs[2],
                        unit_ksizes[2],
                        stride=1,
                        padding=0,
                        bias=False,
                    ),
                ),
            ]
            # has BatchNorm-Activation layers to conclude each
            # previous block so must not put pre activation for the first
            # unit of this block
            unit_layer = unit_layer if idx != 0 else unit_layer[2:]
            self.units.append(nn.Sequential(OrderedDict(unit_layer)))
            unit_in_ch = unit_chs[-1]

        if in_ch != unit_chs[-1] or stride != 1:
            self.shortcut = nn.Conv2d(in_ch, unit_chs[-1], 1, stride=stride, bias=False)
        else:
            self.shortcut = None

        self.blk_bna = nn.Sequential(
            OrderedDict(
                [
                    ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("relu", nn.ReLU(inplace=True)),
                ],
            ),
        )
    def forward(self: ResidualBlock, prev_feat: torch.Tensor) -> torch.Tensor:
        """Logic for using layers defined in init."""
        shortcut = prev_feat if self.shortcut is None else self.shortcut(prev_feat)

        for _, unit in enumerate(self.units):
            new_feat = prev_feat
            new_feat = unit(new_feat)
            prev_feat = new_feat + shortcut
            shortcut = prev_feat
        return self.blk_bna(prev_feat)
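
# A minimal usage sketch for ResidualBlock, assuming only the module-level
# imports above. The configuration matches the `d1` encoder stage defined in
# :class:`HoVerNet` below; the 128x128 input size is arbitrary and illustrative
# only. Only the first unit applies the requested stride, so the block halves
# the spatial size once and projects the shortcut with a 1x1 convolution.
#
#   >>> blk = ResidualBlock(256, [1, 3, 1], [128, 128, 512], unit_count=4, stride=2)
#   >>> feat = torch.rand(1, 256, 128, 128)
#   >>> blk(feat).shape
#   torch.Size([1, 512, 64, 64])
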
class HoVerNet(ModelABC):
    """Initialise HoVerNet [1].

    The tiatoolbox models should produce the following results:

    .. list-table:: HoVerNet segmentation performance on the CoNSeP dataset [1]
       :widths: 15 15 15 15 15 15 15
       :header-rows: 1

       * - Model name
         - Data set
         - DICE
         - AJI
         - DQ
         - SQ
         - PQ
       * - hovernet-original-consep
         - CoNSeP
         - 0.85
         - 0.57
         - 0.70
         - 0.78
         - 0.55

    .. list-table:: HoVerNet segmentation performance on the Kumar dataset [2]
       :widths: 15 15 15 15 15 15 15
       :header-rows: 1

       * - Model name
         - Data set
         - DICE
         - AJI
         - DQ
         - SQ
         - PQ
       * - hovernet-original-kumar
         - Kumar
         - 0.83
         - 0.62
         - 0.77
         - 0.77
         - 0.60

    Args:
        num_input_channels (int):
            Number of channels in input.
        num_types (int):
            Number of nuclei types within the predictions. Once
            defined, a branch dedicated for typing is created. By
            default, no typing (`num_types=None`) is used.
        mode (str):
            Whether to use the architecture defined in the original
            paper (`original`) or the one used in the PanNuke paper
            (`fast`).

    References:
        [1] Graham, Simon, et al. "HoVerNet: Simultaneous segmentation and
        classification of nuclei in multi-tissue histology images."
        Medical Image Analysis 58 (2019): 101563.

        [2] Kumar, Neeraj, et al. "A dataset and a technique for generalized
        nuclear segmentation for computational pathology." IEEE transactions
        on medical imaging 36.7 (2017): 1550-1560.

    """

    def __init__(
        self: HoVerNet,
        num_input_channels: int = 3,
        num_types: int | None = None,
        mode: str = "original",
        nuc_type_dict: dict | None = None,
    ) -> None:
        """Initialize :class:`HoVerNet`."""
        super().__init__()
        self.mode = mode
        self.num_types = num_types
        self.nuc_type_dict = nuc_type_dict

        if mode not in ["original", "fast"]:
            msg = (
                f"Invalid mode {mode} for HoVerNet. "
                "Only support `original` or `fast`."
            )
            raise ValueError(
                msg,
            )

        modules = [
            (
                "/",
                nn.Conv2d(num_input_channels, 64, 7, stride=1, padding=0, bias=False),
            ),
            ("bn", nn.BatchNorm2d(64, eps=1e-5)),
            ("relu", nn.ReLU(inplace=True)),
        ]
        # pre-pend the padding for `fast` mode
        if mode == "fast":
            modules = [("pad", TFSamepaddingLayer(ksize=7, stride=1)), *modules]

        self.conv0 = nn.Sequential(OrderedDict(modules))
        self.d0 = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1)
        self.d1 = ResidualBlock(256, [1, 3, 1], [128, 128, 512], 4, stride=2)
        self.d2 = ResidualBlock(512, [1, 3, 1], [256, 256, 1024], 6, stride=2)
        self.d3 = ResidualBlock(1024, [1, 3, 1], [512, 512, 2048], 3, stride=2)

        self.conv_bot = nn.Conv2d(2048, 1024, 1, stride=1, padding=0, bias=False)

        ksize = 5 if mode == "original" else 3

        if num_types is None:
            self.decoder = nn.ModuleDict(
                OrderedDict(
                    [
                        ("np", HoVerNet._create_decoder_branch(ksize=ksize, out_ch=2)),
                        ("hv", HoVerNet._create_decoder_branch(ksize=ksize, out_ch=2)),
                    ],
                ),
            )
        else:
            self.decoder = nn.ModuleDict(
                OrderedDict(
                    [
                        (
                            "tp",
                            HoVerNet._create_decoder_branch(
                                ksize=ksize,
                                out_ch=num_types,
                            ),
                        ),
                        ("np", HoVerNet._create_decoder_branch(ksize=ksize, out_ch=2)),
                        ("hv", HoVerNet._create_decoder_branch(ksize=ksize, out_ch=2)),
                    ],
                ),
            )

        self.upsample2x = UpSample2x()
    def forward(  # skipcq: PYL-W0221
        self: HoVerNet,
        input_tensor: torch.Tensor,
    ) -> dict:
        """Logic for using layers defined in init.

        This method defines how layers are used in forward operation.

        Args:
            input_tensor (torch.Tensor):
                Input images, the tensor is in the shape of NCHW.

        Returns:
            dict:
                A dictionary containing the inference output. The
                expected format is {decoder_name: prediction}.

        """
        input_tensor = input_tensor / 255.0  # to 0-1 range to match XY

        d0 = self.conv0(input_tensor)
        d0 = self.d0(d0)
        d1 = self.d1(d0)
        d2 = self.d2(d1)
        d3 = self.d3(d2)
        d3 = self.conv_bot(d3)
        d = [d0, d1, d2, d3]

        if self.mode == "original":
            d[0] = centre_crop(d[0], [184, 184])
            d[1] = centre_crop(d[1], [72, 72])
        else:
            d[0] = centre_crop(d[0], [92, 92])
            d[1] = centre_crop(d[1], [36, 36])

        out_dict = OrderedDict()
        for branch_name, branch_desc in self.decoder.items():
            u3 = self.upsample2x(d[-1]) + d[-2]
            u3 = branch_desc[0](u3)

            u2 = self.upsample2x(u3) + d[-3]
            u2 = branch_desc[1](u2)

            u1 = self.upsample2x(u2) + d[-4]
            u1 = branch_desc[2](u1)

            u0 = branch_desc[3](u1)
            out_dict[branch_name] = u0

        return out_dict
    @staticmethod
    def _create_decoder_branch(out_ch: int = 2, ksize: int = 5) -> nn.Sequential:
        """Helper to create a decoder branch."""
        modules = [
            ("conva", nn.Conv2d(1024, 256, ksize, stride=1, padding=0, bias=False)),
            ("dense", DenseBlock(256, [1, ksize], [128, 32], 8, split=4)),
            (
                "convf",
                nn.Conv2d(512, 512, 1, stride=1, padding=0, bias=False),
            ),
        ]
        u3 = nn.Sequential(OrderedDict(modules))

        modules = [
            ("conva", nn.Conv2d(512, 128, ksize, stride=1, padding=0, bias=False)),
            ("dense", DenseBlock(128, [1, ksize], [128, 32], 4, split=4)),
            (
                "convf",
                nn.Conv2d(256, 256, 1, stride=1, padding=0, bias=False),
            ),
        ]
        u2 = nn.Sequential(OrderedDict(modules))

        modules = [
            ("conva/pad", TFSamepaddingLayer(ksize=ksize, stride=1)),
            (
                "conva",
                nn.Conv2d(256, 64, ksize, stride=1, padding=0, bias=False),
            ),
        ]
        u1 = nn.Sequential(OrderedDict(modules))

        modules = [
            ("bn", nn.BatchNorm2d(64, eps=1e-5)),
            ("relu", nn.ReLU(inplace=True)),
            (
                "conv",
                nn.Conv2d(64, out_ch, 1, stride=1, padding=0, bias=True),
            ),
        ]
        u0 = nn.Sequential(OrderedDict(modules))

        return nn.Sequential(
            OrderedDict([("u3", u3), ("u2", u2), ("u1", u1), ("u0", u0)]),
        )

    @staticmethod
    def _proc_np_hv(
        np_map: np.ndarray,
        hv_map: np.ndarray,
        scale_factor: float = 1,
    ) -> np.ndarray:
        """Extract Nuclei Instance with NP and HV Map.

        Sobel will be applied on the horizontal and vertical channels in
        `hv_map` to derive an energy landscape which highlights possible
        nuclei instance boundaries. Afterwards, watershed with markers
        is applied on the above energy map using the `np_map` as a filter
        to remove background regions.

        Args:
            np_map (np.ndarray):
                An image of shape (height, width, 1) which contains the
                probabilities of a pixel being a nucleus.
            hv_map (np.ndarray):
                An array of shape (height, width, 2) which contains the
                horizontal (channel 0) and vertical (channel 1) maps of
                possible instances within the image.
            scale_factor (float):
                The scale factor for processing nuclei. The scale
                assumes an image of resolution 0.25 microns per pixel.
                Default is therefore 1 for HoVer-Net.

        Returns:
            :class:`numpy.ndarray`:
                An np.ndarray of shape (height, width) where each
                non-zero value within the array corresponds to one
                detected nucleus instance.

        """
        blb_raw = np_map[..., 0]
        h_dir_raw = hv_map[..., 0]
        v_dir_raw = hv_map[..., 1]

        # processing
        blb = np.array(blb_raw >= 0.5, dtype=np.int32)  # noqa: PLR2004

        blb = ndimage.label(blb)[0]
        blb = remove_small_objects(blb, min_size=10)
        blb[blb > 0] = 1  # background is 0 already

        h_dir = cv2.normalize(
            h_dir_raw,
            None,
            alpha=0,
            beta=1,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_32F,
        )
        v_dir = cv2.normalize(
            v_dir_raw,
            None,
            alpha=0,
            beta=1,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_32F,
        )

        # Get resolution specific filters etc.
        ksize = int((20 * scale_factor) + 1)
        obj_size = math.ceil(10 * (scale_factor**2))

        sobel_h = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
        sobel_v = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)

        sobel_h = 1 - (
            cv2.normalize(
                sobel_h,
                None,
                alpha=0,
                beta=1,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_32F,
            )
        )
        sobel_v = 1 - (
            cv2.normalize(
                sobel_v,
                None,
                alpha=0,
                beta=1,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_32F,
            )
        )

        overall = np.maximum(sobel_h, sobel_v)
        overall = overall - (1 - blb)
        overall[overall < 0] = 0

        dist = (1.0 - overall) * blb
        # * nuclei values form mountains so inverse to get basins
        dist = -cv2.GaussianBlur(dist, (3, 3), 0)

        overall = np.array(overall >= 0.4, dtype=np.int32)  # noqa: PLR2004

        marker = blb - overall
        marker[marker < 0] = 0
        marker = ndimage.binary_fill_holes(marker).astype("uint8")
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
        marker = ndimage.label(marker)[0]
        marker = remove_small_objects(marker, min_size=obj_size)

        return watershed(dist, markers=marker, mask=blb)
    @staticmethod
    def get_instance_info(
        pred_inst: np.ndarray,
        pred_type: np.ndarray | None = None,
    ) -> dict:
        """To collect instance information and store it within a dictionary.

        Args:
            pred_inst (:class:`numpy.ndarray`):
                An image of shape (height, width) where each pixel holds
                the ID of the nucleus instance it belongs to (0 for
                background).
            pred_type (:class:`numpy.ndarray`):
                An image of shape (height, width, 1) which contains the
                predicted nucleus type label of each pixel.

        Returns:
            dict:
                A dictionary containing a mapping of each instance
                within `pred_inst` to its instance information. It has
                the following form::

                    {
                        0: {  # Instance ID
                            "box": [
                                x_min,
                                y_min,
                                x_max,
                                y_max,
                            ],
                            "centroid": [x, y],
                            "contour": [
                                [x, y],
                                ...
                            ],
                            "type": integer,
                            "prob": float,
                        },
                        ...
                    }

                where the instance ID is an integer corresponding to the
                instance at the same pixel value within `pred_inst`.

        """
        inst_id_list = np.unique(pred_inst)[1:]  # exclude background
        inst_info_dict = {}
        for inst_id in inst_id_list:
            inst_map = pred_inst == inst_id
            inst_box = get_bounding_box(inst_map)
            inst_box_tl = inst_box[:2]
            inst_map = inst_map[inst_box[1] : inst_box[3], inst_box[0] : inst_box[2]]
            inst_map = inst_map.astype(np.uint8)
            inst_moment = cv2.moments(inst_map)
            inst_contour = cv2.findContours(
                inst_map,
                cv2.RETR_TREE,
                cv2.CHAIN_APPROX_SIMPLE,
            )
            # * opencv protocol format may break
            inst_contour = inst_contour[0][0].astype(np.int32)
            inst_contour = np.squeeze(inst_contour)

            # < 3 points does not make a contour, so skip, likely artifact too
            # as the contours obtained via approximation => too small
            if inst_contour.shape[0] < 3:  # pragma: no cover  # noqa: PLR2004
                continue

            # ! check for trickery shape
            if len(inst_contour.shape) != 2:  # pragma: no cover  # noqa: PLR2004
                continue

            inst_centroid = [
                (inst_moment["m10"] / inst_moment["m00"]),
                (inst_moment["m01"] / inst_moment["m00"]),
            ]
            inst_centroid = np.array(inst_centroid)
            inst_contour += inst_box_tl[None]
            inst_centroid += inst_box_tl  # X
            inst_info_dict[inst_id] = {  # inst_id should start at 1
                "box": inst_box,
                "centroid": inst_centroid,
                "contour": inst_contour,
                "prob": None,
                "type": None,
            }

        if pred_type is not None:
            # * Get class of each instance id, stored at index id-1
            for inst_id in list(inst_info_dict.keys()):
                c_min, r_min, c_max, r_max = inst_info_dict[inst_id]["box"]
                inst_map_crop = pred_inst[r_min:r_max, c_min:c_max]
                inst_type_crop = pred_type[r_min:r_max, c_min:c_max]

                inst_map_crop = inst_map_crop == inst_id
                inst_type = inst_type_crop[inst_map_crop]

                (type_list, type_pixels) = np.unique(inst_type, return_counts=True)
                type_list = list(zip(type_list, type_pixels))
                type_list = sorted(type_list, key=lambda x: x[1], reverse=True)

                inst_type = type_list[0][0]

                # ! pick the 2nd most dominant if it exists
                if inst_type == 0 and len(type_list) > 1:  # pragma: no cover
                    inst_type = type_list[1][0]

                type_dict = {v[0]: v[1] for v in type_list}
                type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6)

                inst_info_dict[inst_id]["type"] = int(inst_type)
                inst_info_dict[inst_id]["prob"] = float(type_prob)

        return inst_info_dict
    @staticmethod
    # skipcq: PYL-W0221  # noqa: ERA001
    def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]:
        """Post-processing script for image tiles.

        Args:
            raw_maps (list(:class:`numpy.ndarray`)):
                A list of prediction outputs of each head, assumed to be
                in the order [np, hv, tp] (matching the output of
                `infer_batch`).

        Returns:
            tuple:
                - :class:`numpy.ndarray` - Instance map:
                    Pixel-wise nuclear instance segmentation prediction.
                - :py:obj:`dict` - Instance dictionary:
                    A dictionary containing a mapping of each instance
                    within `inst_map` to its instance information. It
                    has the following form::

                        {
                            0: {  # Instance ID
                                "box": [
                                    x_min,
                                    y_min,
                                    x_max,
                                    y_max,
                                ],
                                "centroid": [x, y],
                                "contour": [
                                    [x, y],
                                    ...
                                ],
                                "type": 1,
                                "prob": 0.95,
                            },
                            ...
                        }

                    where the instance ID is an integer corresponding to
                    the instance at the same pixel location within the
                    returned instance map.

        Examples:
            >>> from tiatoolbox.models.architecture.hovernet import HoVerNet
            >>> import torch
            >>> import numpy as np
            >>> batch = torch.from_numpy(image_patch)[None]
            >>> # image_patch is a 256x256x3 numpy array
            >>> weights_path = "A/weights.pth"
            >>> pretrained = torch.load(weights_path)
            >>> model = HoVerNet(num_types=6, mode="fast")
            >>> model.load_state_dict(pretrained)
            >>> output = model.infer_batch(model, batch, on_gpu=False)
            >>> output = [v[0] for v in output]
            >>> output = model.postproc(output)

        """
        if len(raw_maps) == 3:  # noqa: PLR2004
            np_map, hv_map, tp_map = raw_maps
            tp_map = np.around(tp_map).astype("uint8")
        else:
            tp_map = None
            np_map, hv_map = raw_maps

        pred_type = tp_map
        pred_inst = HoVerNet._proc_np_hv(np_map, hv_map)
        nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type)

        return pred_inst, nuc_inst_info_dict
    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: np.ndarray,
        *,
        on_gpu: bool,
    ) -> tuple:
        """Run inference on an input batch.

        This contains logic for forward operation as well as batch I/O
        aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (ndarray):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        Returns:
            tuple:
                Output from each head. Each head is expected to contain
                N predictions for N input patches. There are two cases:
                one with 2 heads (Nuclei Pixels `np` and Hover `hv`) and
                one with 3 heads (`np`, `hv`, and Nuclei Types `tp`).

        """
        patch_imgs = batch_data

        device = misc.select_device(on_gpu=on_gpu)
        patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
        patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()

        model.eval()  # infer mode

        # --------------------------------------------------------------
        with torch.inference_mode():
            pred_dict = model(patch_imgs_gpu)
            pred_dict = OrderedDict(
                [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()],
            )
            pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:]
            if "tp" in pred_dict:
                type_map = F.softmax(pred_dict["tp"], dim=-1)
                type_map = torch.argmax(type_map, dim=-1, keepdim=True)
                type_map = type_map.type(torch.float32)
                pred_dict["tp"] = type_map
            pred_dict = {k: v.cpu().numpy() for k, v in pred_dict.items()}

        if "tp" in pred_dict:
            return pred_dict["np"], pred_dict["hv"], pred_dict["tp"]
        return pred_dict["np"], pred_dict["hv"]
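
# A minimal end-to-end sketch, assuming the classes defined above and no
# pretrained weights: with randomly initialised parameters the outputs are
# meaningless and serve only to exercise the pipeline. The 256x256 input and
# `fast` mode follow the `postproc` docstring example; the quoted output shapes
# are derived from the crop/upsample arithmetic in `forward` and are
# illustrative only.
#
#   >>> model = HoVerNet(num_types=6, mode="fast")
#   >>> batch = torch.rand(1, 256, 256, 3) * 255  # NHWC, 0-255 range
#   >>> np_map, hv_map, tp_map = HoVerNet.infer_batch(model, batch, on_gpu=False)
#   >>> np_map.shape, hv_map.shape, tp_map.shape
#   ((1, 164, 164, 1), (1, 164, 164, 2), (1, 164, 164, 1))
#   >>> inst_map, inst_dict = HoVerNet.postproc([np_map[0], hv_map[0], tp_map[0]])
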