"""Define HoVerNetPlus architecture."""
from __future__ import annotations
from collections import OrderedDict
import cv2
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from skimage import morphology
from torch import nn
from tiatoolbox.models.architecture.hovernet import HoVerNet
from tiatoolbox.models.architecture.utils import UpSample2x
from tiatoolbox.utils import misc
class HoVerNetPlus(HoVerNet):
"""Initialise HoVerNet+ [1].
HoVerNet+ takes an RGB input image, and provides the option to
simultaneously segment and classify the nuclei present, as well as
semantically segment different regions or layers in the images. Note that
the HoVerNet+ architecture assumes an image resolution of 0.5 mpp, in
contrast to HoVerNet at 0.25 mpp.
The tiatoolbox model should produce the following results on the datasets
that it was trained on.
.. list-table:: HoVerNet+ Performance for Nuclear Instance Segmentation
:widths: 15 15 15 15 15 15 15
:header-rows: 1
* - Model name
- Data set
- DICE
- AJI
- DQ
- SQ
- PQ
* - hovernetplus-oed
- OED
- 0.84
- 0.69
- 0.86
- 0.80
- 0.69
.. list-table:: HoVerNet+ Mean Performance for Semantic Segmentation
:widths: 15 15 15 15 15 15
:header-rows: 1
* - Model name
- Data set
- F1
- Precision
- Recall
- Accuracy
* - hovernetplus-oed
- OED
- 0.82
- 0.82
- 0.82
- 0.84
Args:
num_input_channels (int):
The number of input channels, default = 3 for RGB.
num_types (int):
The number of types of nuclei present in the images.
num_layers (int):
The number of layers/different region types present.
References:
[1] Shephard, Adam J., et al. "Simultaneous Nuclear Instance and
Layer Segmentation in Oral Epithelial Dysplasia." Proceedings of
the IEEE/CVF International Conference on Computer Vision. 2021.
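Examples:
A minimal construction sketch; the ``num_types`` and ``num_layers``
values below are illustrative rather than prescribed by the
architecture:
>>> from tiatoolbox.models.architecture.hovernetplus import HoVerNetPlus
>>> model = HoVerNetPlus(num_types=3, num_layers=5)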
"""
def __init__(
self: HoVerNetPlus,
num_input_channels: int = 3,
num_types: int | None = None,
num_layers: int | None = None,
nuc_type_dict: dict | None = None,
layer_type_dict: dict | None = None,
) -> None:
"""Initialize :class:`HoVerNetPlus`."""
super().__init__(mode="fast")
self.num_input_channels = num_input_channels
self.num_types = num_types
self.num_layers = num_layers
self.nuc_type_dict = nuc_type_dict
self.layer_type_dict = layer_type_dict
ksize = 3
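# Four decoder branches, one per output head (channel counts follow the
# calls below): "tp" predicts per-pixel nuclear type (num_types channels),
# "np" nuclear presence (2 channels), "hv" horizontal/vertical distance
# maps (2 channels), and "ls" the layer/region segmentation (num_layers
# channels).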
self.decoder = nn.ModuleDict(
OrderedDict(
[
(
"tp",
self._create_decoder_branch(ksize=ksize, out_ch=num_types),
),
(
"np",
self._create_decoder_branch(ksize=ksize, out_ch=2),
),
(
"hv",
self._create_decoder_branch(ksize=ksize, out_ch=2),
),
(
"ls",
self._create_decoder_branch(ksize=ksize, out_ch=num_layers),
),
],
),
)
self.upsample2x = UpSample2x()
@staticmethod
def _proc_ls(ls_map: np.ndarray) -> np.ndarray:
"""Extract Layer Segmentation map with LS Map.
This function takes the layer segmentation map and applies various morphological
operations remove spurious segmentations. Note, this processing is specific to
oral epithelium, where prioirty is given to certain tissue layers.
Args:
ls_map:
The input predicted segmentation map.
Returns:
:class:`numpy.ndarray`:
The processed segmentation map.
"""
ls_map = np.squeeze(ls_map)
ls_map = np.around(ls_map).astype("uint8") # ensure all numbers are integers
min_size = 20000
kernel_size = 20
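# Binary masks derived from the predicted labels: ``epith_all`` marks
# pixels assigned to any epithelial sub-layer (labels >= 2), while
# ``mask`` marks all non-background tissue (labels >= 1).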
epith_all = np.where(ls_map >= 2, 1, 0).astype("uint8") # noqa: PLR2004
mask = np.where(ls_map >= 1, 1, 0).astype("uint8")
epith_all = epith_all > 0
epith_mask = morphology.remove_small_objects(
epith_all,
min_size=min_size,
).astype("uint8")
epith_edited = epith_mask * ls_map
epith_edited = epith_edited.astype("uint8")
epith_edited_open = np.zeros_like(epith_edited).astype("uint8")
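# Clean each epithelial sub-layer (labels 3, 2 and 4) in turn with a
# morphological closing followed by an opening; where the cleaned classes
# overlap, classes later in the list overwrite earlier ones.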
for i in [3, 2, 4]:
tmp = np.where(epith_edited == i, 1, 0).astype("uint8")
ep_open = cv2.morphologyEx(
tmp,
cv2.MORPH_CLOSE,
np.ones((kernel_size, kernel_size)),
)
ep_open = cv2.morphologyEx(
ep_open,
cv2.MORPH_OPEN,
np.ones((kernel_size, kernel_size)),
)
epith_edited_open[ep_open == 1] = i
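# Apply the same closing/opening to the whole-tissue mask.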
mask_open = cv2.morphologyEx(
mask,
cv2.MORPH_CLOSE,
np.ones((kernel_size, kernel_size)),
)
mask_open = cv2.morphologyEx(
mask_open,
cv2.MORPH_OPEN,
np.ones((kernel_size, kernel_size)),
).astype("uint8")
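# Rebuild the final layer map: start from the cleaned whole-tissue mask
# (label 1) and overlay the cleaned epithelial sub-layers (labels 2-4).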
ls_map = mask_open.copy()
for i in range(2, 5):
ls_map[epith_edited_open == i] = i
return ls_map.astype("uint8")
@staticmethod
def _get_layer_info(pred_layer: np.ndarray) -> dict:
"""Transforms image layers/regions into contours to store in dictionary.
Args:
pred_layer (:class:`numpy.ndarray`):
Semantic segmentation map of different layers/regions
following processing.
Returns:
dict:
A dictionary of layer contours. It has the
following form:
.. code-block:: json
{
1: { # Instance ID
"contours": [
[x, y],
...
],
"type": integer,
},
...
}
"""
layer_list = np.unique(pred_layer)
layer_list = np.delete(layer_list, np.where(layer_list == 0))
layer_info_dict = {}
count = 1
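# Each contour of every layer class becomes its own entry in the
# dictionary, with sequentially assigned instance IDs starting at 1.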
for type_class in layer_list:
layer = np.where(pred_layer == type_class, 1, 0).astype("uint8")
contours, _ = cv2.findContours(
layer.astype("uint8"),
cv2.RETR_TREE,
cv2.CHAIN_APPROX_NONE,
)
for contour in contours:
coords = contour[:, 0, :]
layer_info_dict[count] = {
"contours": coords,
"type": type_class,
}
count += 1
return layer_info_dict
@staticmethod
# skipcq: PYL-W0221 # noqa: ERA001
def postproc(raw_maps: list[np.ndarray]) -> tuple:
"""Post-processing script for image tiles.
Args:
raw_maps (list(ndarray)):
A list of prediction outputs from each head, assumed to
be in the order [np, hv, tp, ls] (matching the output
of `infer_batch`).
Returns:
tuple:
- inst_map (ndarray):
Pixel-wise nuclear instance segmentation prediction.
- inst_dict (dict):
A dictionary mapping each instance within `inst_map`
to its instance information. It has the following
form:
.. code-block:: json
{
0: { # Instance ID
"box": [
x_min,
y_min,
x_max,
y_max,
],
"centroid": [x, y],
"contour": [
[x, y],
...
],
"type": integer,
"prob": float,
},
...
}
where the instance ID is an integer corresponding to the
instance at the same pixel value within `inst_map`.
- layer_map (ndarray):
Pixel-wise layer segmentation prediction.
- layer_dict (dict):
A dictionary mapping each segmented layer within
`layer_map` to its contour information. It has the
following form:
.. code-block:: json
{
1: { # Instance ID
"contours": [
[x, y],
...
],
"type": integer,
},
...
}
Examples:
>>> from tiatoolbox.models.architecture.hovernetplus import HoVerNetPlus
>>> import torch
>>> import numpy as np
>>> # image_patch is a 256x256x3 numpy array
>>> batch = torch.from_numpy(image_patch)[None]
>>> weights_path = "A/weights.pth"
>>> pretrained = torch.load(weights_path)
>>> model = HoVerNetPlus(num_types=3, num_layers=5)
>>> model.load_state_dict(pretrained)
>>> output = model.infer_batch(model, batch, on_gpu=False)
>>> output = [v[0] for v in output]
>>> output = model.postproc(output)
"""
np_map, hv_map, tp_map, ls_map = raw_maps
pred_inst = HoVerNetPlus._proc_np_hv(np_map, hv_map, scale_factor=0.5)
# scale_factor=0.5 as nuclear processing is done at 0.5 mpp rather than 0.25 mpp
pred_layer = HoVerNetPlus._proc_ls(ls_map)
pred_type = np.around(tp_map).astype("uint8")
nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type)
layer_info_dict = HoVerNetPlus._get_layer_info(pred_layer)
return pred_inst, nuc_inst_info_dict, pred_layer, layer_info_dict
@staticmethod
def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple:
"""Run inference on an input batch.
This contains the logic for the forward operation as well as batch I/O
aggregation.
Args:
model (nn.Module):
PyTorch defined model.
batch_data (ndarray):
A batch of data generated by
`torch.utils.data.DataLoader`.
on_gpu (bool):
Whether to run inference on a GPU.
Returns:
tuple:
Output of each network head as a :class:`numpy.ndarray`, in the
order (np, hv, tp, ls).
"""
patch_imgs = batch_data
device = misc.select_device(on_gpu=on_gpu)
patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)
patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
model.eval() # infer mode
# --------------------------------------------------------------
with torch.inference_mode():
pred_dict = model(patch_imgs_gpu)
pred_dict = OrderedDict(
[[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()],
)
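# Convert raw head outputs to final per-pixel predictions: keep only the
# foreground probability channel for "np", and take the argmax class label
# for the type ("tp") and layer ("ls") heads.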
pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:]
type_map = F.softmax(pred_dict["tp"], dim=-1)
type_map = torch.argmax(type_map, dim=-1, keepdim=True)
type_map = type_map.type(torch.float32)
pred_dict["tp"] = type_map
layer_map = F.softmax(pred_dict["ls"], dim=-1)
layer_map = torch.argmax(layer_map, dim=-1, keepdim=True)
layer_map = layer_map.type(torch.float32)
pred_dict["ls"] = layer_map
pred_dict = {k: v.cpu().numpy() for k, v in pred_dict.items()}
return pred_dict["np"], pred_dict["hv"], pred_dict["tp"], pred_dict["ls"]