# ***** BEGIN GPL LICENSE BLOCK *****
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
# All rights reserved.
# ***** END GPL LICENSE BLOCK *****
import math
from collections import OrderedDict
from typing import List
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import binary_fill_holes
from scipy.ndimage import label as ndi_label
from skimage.morphology import remove_small_objects
from skimage.segmentation import watershed
from tiatoolbox.models.abc import ModelABC
from tiatoolbox.models.architecture.utils import (
UpSample2x,
centre_crop,
centre_crop_to_shape,
)
from tiatoolbox.utils import misc
from tiatoolbox.utils.misc import get_bounding_box


class TFSamepaddingLayer(nn.Module):
"""To align with tensorflow `same` padding.
Putting this before any conv layer that needs padding. Here,
we assume kernel has same height and width for simplicity.
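
    Example:
        A minimal sketch pairing the layer with a 7x7 convolution
        (sizes here are illustrative only):
        >>> import torch
        >>> import torch.nn as nn
        >>> pad = TFSamepaddingLayer(ksize=7, stride=1)
        >>> conv = nn.Conv2d(3, 8, 7, stride=1, padding=0)
        >>> out = conv(pad(torch.rand(1, 3, 64, 64)))
        >>> out.shape  # spatial size preserved, as with tensorflow `same`
        torch.Size([1, 8, 64, 64])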
"""
def __init__(self, ksize: int, stride: int):
super().__init__()
self.ksize = ksize
self.stride = stride

    def forward(self, x: torch.Tensor):
"""Logic for using layers defined in init."""
if x.shape[2] % self.stride == 0:
pad = max(self.ksize - self.stride, 0)
else:
pad = max(self.ksize - (x.shape[2] % self.stride), 0)
if pad % 2 == 0:
pad_val = pad // 2
padding = (pad_val, pad_val, pad_val, pad_val)
else:
pad_val_start = pad // 2
pad_val_end = pad - pad_val_start
padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end)
x = F.pad(x, padding, "constant", 0)
return x


class DenseBlock(nn.Module):
"""Dense Convolutional Block.
This convolutional block supports only `valid` padding.
References:
Huang, Gao, et al. "Densely connected convolutional networks."
Proceedings of the IEEE conference on computer vision and
pattern recognition. 2017.
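
    Example:
        A minimal sketch, mirroring the decoder configuration used by
        HoVerNet below (the input size is illustrative only):
        >>> import torch
        >>> block = DenseBlock(256, [1, 5], [128, 32], 8, split=4)
        >>> out = block(torch.rand(1, 256, 80, 80))
        >>> out.shape  # channels grow to 256 + 8 * 32; each valid 5x5 conv trims 4 pixels
        torch.Size([1, 512, 48, 48])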
"""
def __init__(
self,
in_ch: int,
unit_ksizes: List[int],
unit_chs: List[int],
unit_count: int,
split: int = 1,
):
super().__init__()
if len(unit_ksizes) != len(unit_chs):
raise ValueError("Unbalance Unit Info")
self.nr_unit = unit_count
self.in_ch = in_ch
        # weight values may not match the tensorflow version due to
        # different default initialization schemes in torch and tensorflow
def get_unit_block(unit_in_ch):
"""Helper function to make it less long."""
layers = OrderedDict(
[
("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
("preact_bna/relu", nn.ReLU(inplace=True)),
(
"conv1",
nn.Conv2d(
unit_in_ch,
unit_chs[0],
unit_ksizes[0],
stride=1,
padding=0,
bias=False,
),
),
("conv1/bn", nn.BatchNorm2d(unit_chs[0], eps=1e-5)),
("conv1/relu", nn.ReLU(inplace=True)),
(
"conv2",
nn.Conv2d(
unit_chs[0],
unit_chs[1],
unit_ksizes[1],
groups=split,
stride=1,
padding=0,
bias=False,
),
),
]
)
return nn.Sequential(layers)
unit_in_ch = in_ch
self.units = nn.ModuleList()
for _ in range(unit_count):
self.units.append(get_unit_block(unit_in_ch))
unit_in_ch += unit_chs[1]
self.blk_bna = nn.Sequential(
OrderedDict(
[
("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
("relu", nn.ReLU(inplace=True)),
]
)
)

    def forward(self, prev_feat: torch.Tensor):
"""Logic for using layers defined in init."""
for idx in range(self.nr_unit):
new_feat = self.units[idx](prev_feat)
prev_feat = centre_crop_to_shape(prev_feat, new_feat)
prev_feat = torch.cat([prev_feat, new_feat], dim=1)
prev_feat = self.blk_bna(prev_feat)
return prev_feat


class ResidualBlock(nn.Module):
"""Residual block.
References:
He, Kaiming, et al. "Deep residual learning for image recognition."
Proceedings of the IEEE conference on computer vision and
pattern recognition. 2016.
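
    Example:
        A minimal sketch, mirroring the first encoder block (`d0`) of
        HoVerNet below (the input size is illustrative only):
        >>> import torch
        >>> block = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1)
        >>> out = block(torch.rand(1, 64, 128, 128))
        >>> out.shape  # `same` padding keeps the spatial size at stride 1
        torch.Size([1, 256, 128, 128])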
"""
def __init__(
self,
in_ch: int,
unit_ksizes: List[int],
unit_chs: List[int],
unit_count: int,
stride: int = 1,
):
super().__init__()
if len(unit_ksizes) != len(unit_chs):
raise ValueError("Unbalance Unit Info")
self.nr_unit = unit_count
self.in_ch = in_ch
# ! For inference only so init values for batchnorm may not match tensorflow
unit_in_ch = in_ch
self.units = nn.ModuleList()
for idx in range(unit_count):
unit_layer = [
("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
("preact/relu", nn.ReLU(inplace=True)),
(
"conv1",
nn.Conv2d(
unit_in_ch,
unit_chs[0],
unit_ksizes[0],
stride=1,
padding=0,
bias=False,
),
),
("conv1/bn", nn.BatchNorm2d(unit_chs[0], eps=1e-5)),
("conv1/relu", nn.ReLU(inplace=True)),
(
"conv2/pad",
TFSamepaddingLayer(
ksize=unit_ksizes[1], stride=stride if idx == 0 else 1
),
),
(
"conv2",
nn.Conv2d(
unit_chs[0],
unit_chs[1],
unit_ksizes[1],
stride=stride if idx == 0 else 1,
padding=0,
bias=False,
),
),
("conv2/bn", nn.BatchNorm2d(unit_chs[1], eps=1e-5)),
("conv2/relu", nn.ReLU(inplace=True)),
(
"conv3",
nn.Conv2d(
unit_chs[1],
unit_chs[2],
unit_ksizes[2],
stride=1,
padding=0,
bias=False,
),
),
]
            # the previous block ends with BatchNorm-Activation layers,
            # so the pre-activation layers must be skipped for the first
            # unit of this block
unit_layer = unit_layer if idx != 0 else unit_layer[2:]
self.units.append(nn.Sequential(OrderedDict(unit_layer)))
unit_in_ch = unit_chs[-1]
if in_ch != unit_chs[-1] or stride != 1:
self.shortcut = nn.Conv2d(in_ch, unit_chs[-1], 1, stride=stride, bias=False)
else:
self.shortcut = None
self.blk_bna = nn.Sequential(
OrderedDict(
[
("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
("relu", nn.ReLU(inplace=True)),
]
)
)

    def forward(self, prev_feat: torch.Tensor):
"""Logic for using layers defined in init."""
if self.shortcut is None:
shortcut = prev_feat
else:
shortcut = self.shortcut(prev_feat)
for idx in range(0, len(self.units)):
new_feat = prev_feat
new_feat = self.units[idx](new_feat)
prev_feat = new_feat + shortcut
shortcut = prev_feat
feat = self.blk_bna(prev_feat)
return feat


class HoVerNet(ModelABC):
"""HoVer-Net Architecture.
Args:
num_input_channels (int): Number of channels in input.
num_types (int): Number of nuclei types within the predictions.
            Once defined, a dedicated branch for typing is created.
            By default, no typing (`num_types=None`) is used.
        mode (str): Whether to use the architecture defined in the
            original paper (`original`) or the one used in the
            PanNuke paper (`fast`).
References:
Graham, Simon, et al. "Hover-net: Simultaneous segmentation and
classification of nuclei in multi-tissue histology images."
Medical Image Analysis 58 (2019): 101563.
Gamper, Jevgenij, et al. "PanNuke dataset extension, insights and baselines."
arXiv preprint arXiv:2003.10778 (2020).
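
    Example:
        A minimal sketch with randomly initialised weights, assuming
        `fast` mode and a 256x256 input patch:
        >>> import torch
        >>> model = HoVerNet(num_types=6, mode="fast")
        >>> out = model(torch.rand(1, 3, 256, 256))
        >>> out["np"].shape  # `fast` mode maps 256x256 inputs to 164x164 outputs
        torch.Size([1, 2, 164, 164])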
"""
def __init__(
self, num_input_channels: int = 3, num_types: int = None, mode: str = "original"
):
super().__init__()
self.mode = mode
self.num_types = num_types
if mode not in ["original", "fast"]:
raise ValueError(
f"Invalid mode {mode} for HoVerNet. "
"Only support `original` or `fast`."
)
modules = [
(
"/",
nn.Conv2d(num_input_channels, 64, 7, stride=1, padding=0, bias=False),
),
("bn", nn.BatchNorm2d(64, eps=1e-5)),
("relu", nn.ReLU(inplace=True)),
]
        # prepend the padding for `fast` mode
if mode == "fast":
modules = [("pad", TFSamepaddingLayer(ksize=7, stride=1)), *modules]
self.conv0 = nn.Sequential(OrderedDict(modules))
self.d0 = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1)
self.d1 = ResidualBlock(256, [1, 3, 1], [128, 128, 512], 4, stride=2)
self.d2 = ResidualBlock(512, [1, 3, 1], [256, 256, 1024], 6, stride=2)
self.d3 = ResidualBlock(1024, [1, 3, 1], [512, 512, 2048], 3, stride=2)
self.conv_bot = nn.Conv2d(2048, 1024, 1, stride=1, padding=0, bias=False)
ksize = 5 if mode == "original" else 3
if num_types is None:
self.decoder = nn.ModuleDict(
OrderedDict(
[
("np", HoVerNet._create_decoder_branch(ksize=ksize, out_ch=2)),
("hv", HoVerNet._create_decoder_branch(ksize=ksize, out_ch=2)),
]
)
)
else:
self.decoder = nn.ModuleDict(
OrderedDict(
[
(
"tp",
HoVerNet._create_decoder_branch(
ksize=ksize, out_ch=num_types
),
),
("np", HoVerNet._create_decoder_branch(ksize=ksize, out_ch=2)),
("hv", HoVerNet._create_decoder_branch(ksize=ksize, out_ch=2)),
]
)
)
self.upsample2x = UpSample2x()

    # skipcq: PYL-W0221
    def forward(self, imgs: torch.Tensor):
"""Logic for using layers defined in init.
This method defines how layers are used in forward operation.
Args:
            imgs (torch.Tensor): Input images, a tensor in NCHW shape.

        Returns:
            output (dict): A dictionary containing the inference output.
                The expected format is {decoder_name: prediction}.
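
        Example:
            A minimal sketch, assuming no typing branch (`num_types=None`):
            >>> import torch
            >>> model = HoVerNet(num_types=None, mode="fast")
            >>> out = model(torch.rand(1, 3, 256, 256))
            >>> sorted(out.keys())
            ['hv', 'np']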
"""
        imgs = imgs / 255.0  # scale to the 0-1 range
d0 = self.conv0(imgs)
d0 = self.d0(d0)
d1 = self.d1(d0)
d2 = self.d2(d1)
d3 = self.d3(d2)
d3 = self.conv_bot(d3)
d = [d0, d1, d2, d3]
if self.mode == "original":
d[0] = centre_crop(d[0], [184, 184])
d[1] = centre_crop(d[1], [72, 72])
else:
d[0] = centre_crop(d[0], [92, 92])
d[1] = centre_crop(d[1], [36, 36])
out_dict = OrderedDict()
for branch_name, branch_desc in self.decoder.items():
u3 = self.upsample2x(d[-1]) + d[-2]
u3 = branch_desc[0](u3)
u2 = self.upsample2x(u3) + d[-3]
u2 = branch_desc[1](u2)
u1 = self.upsample2x(u2) + d[-4]
u1 = branch_desc[2](u1)
u0 = branch_desc[3](u1)
out_dict[branch_name] = u0
return out_dict

    @staticmethod
def _create_decoder_branch(out_ch=2, ksize=5):
"""Helper to create a decoder branch."""
modules = [
("conva", nn.Conv2d(1024, 256, ksize, stride=1, padding=0, bias=False)),
("dense", DenseBlock(256, [1, ksize], [128, 32], 8, split=4)),
(
"convf",
nn.Conv2d(512, 512, 1, stride=1, padding=0, bias=False),
),
]
u3 = nn.Sequential(OrderedDict(modules))
modules = [
("conva", nn.Conv2d(512, 128, ksize, stride=1, padding=0, bias=False)),
("dense", DenseBlock(128, [1, ksize], [128, 32], 4, split=4)),
(
"convf",
nn.Conv2d(256, 256, 1, stride=1, padding=0, bias=False),
),
]
u2 = nn.Sequential(OrderedDict(modules))
modules = [
("conva/pad", TFSamepaddingLayer(ksize=ksize, stride=1)),
(
"conva",
nn.Conv2d(256, 64, ksize, stride=1, padding=0, bias=False),
),
]
u1 = nn.Sequential(OrderedDict(modules))
modules = [
("bn", nn.BatchNorm2d(64, eps=1e-5)),
("relu", nn.ReLU(inplace=True)),
(
"conv",
nn.Conv2d(64, out_ch, 1, stride=1, padding=0, bias=True),
),
]
u0 = nn.Sequential(OrderedDict(modules))
decoder = nn.Sequential(
OrderedDict([("u3", u3), ("u2", u2), ("u1", u1), ("u0", u0)])
)
return decoder

    @staticmethod
def _proc_np_hv(np_map: np.ndarray, hv_map: np.ndarray, fx: float = 1):
"""Extract Nuclei Instance with NP and HV Map.
Sobel will be applied on horizontal and vertical channel in
`hv_map` to derive a energy landscape which highligh possible
nuclei instance boundaries. Afterward, watershed with markers
is applied on the above energy map using the `np_map` as filter
to remove background regions.
Args:
np_map (np.ndarray): An image of shape (heigh, width, 1) which
contains the probabilities of a pixel being a nuclei.
hv_map (np.ndarray): An array of shape (heigh, width, 2) which
contains the horizontal (channel 0) and vertical (channel 1)
of possible instances exist withint the images.
fx (float): The scale factor for processing nuclei. The scale
assumes an image of resolution 0.25 microns per pixel. Default
is therefore 1 for HoVer-Net.
Returns:
An np.ndarray of shape (height, width) where each non-zero values
within the array correspond to one detected nuclei instances.
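
        Example:
            A minimal sketch with random maps; in practice the maps come
            from `infer_batch` via `postproc`:
            >>> import numpy as np
            >>> np_map = np.random.rand(164, 164, 1)
            >>> hv_map = np.random.rand(164, 164, 2) * 2 - 1
            >>> inst_map = HoVerNet._proc_np_hv(np_map, hv_map)
            >>> inst_map.shape
            (164, 164)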
"""
blb_raw = np_map[..., 0]
h_dir_raw = hv_map[..., 0]
v_dir_raw = hv_map[..., 1]
# processing
blb = np.array(blb_raw >= 0.5, dtype=np.int32)
        blb = ndi_label(blb)[0]
blb = remove_small_objects(blb, min_size=10)
blb[blb > 0] = 1 # background is 0 already
h_dir = cv2.normalize(
h_dir_raw,
None,
alpha=0,
beta=1,
norm_type=cv2.NORM_MINMAX,
dtype=cv2.CV_32F,
)
v_dir = cv2.normalize(
v_dir_raw,
None,
alpha=0,
beta=1,
norm_type=cv2.NORM_MINMAX,
dtype=cv2.CV_32F,
)
ksize = int((20 * fx) + 1)
obj_size = math.ceil(10 * (fx ** 2))
# Get resolution specific filters etc.
sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)
sobelh = 1 - (
cv2.normalize(
sobelh,
None,
alpha=0,
beta=1,
norm_type=cv2.NORM_MINMAX,
dtype=cv2.CV_32F,
)
)
sobelv = 1 - (
cv2.normalize(
sobelv,
None,
alpha=0,
beta=1,
norm_type=cv2.NORM_MINMAX,
dtype=cv2.CV_32F,
)
)
overall = np.maximum(sobelh, sobelv)
overall = overall - (1 - blb)
overall[overall < 0] = 0
dist = (1.0 - overall) * blb
        # * nuclei values form mountains, so invert to get basins
dist = -cv2.GaussianBlur(dist, (3, 3), 0)
overall = np.array(overall >= 0.4, dtype=np.int32)
marker = blb - overall
marker[marker < 0] = 0
marker = binary_fill_holes(marker).astype("uint8")
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
        marker = ndi_label(marker)[0]
marker = remove_small_objects(marker, min_size=obj_size)
proced_pred = watershed(dist, markers=marker, mask=blb)
return proced_pred

    @staticmethod
def _get_instance_info(pred_inst, pred_type=None):
"""To collect instance information and store it within a dictionary.
Args:
pred_inst (np.ndarray): An image of shape (heigh, width) which
contains the probabilities of a pixel being a nuclei.
pred_type (np.ndarray): An image of shape (heigh, width, 1) which
contains the probabilities of a pixel being a certain type of nuclei.
Returns:
inst_info_dict (dict): A dictionary containing a mapping of each instance
within `pred_inst` instance information. It has following form
inst_info = {
box: number[],
centroids: number[],
contour: number[][],
type: number,
prob: number,
}
inst_info_dict = {[inst_uid: number] : inst_info}
and `inst_uid` is an integer corresponds to the instance
having the same pixel value within `pred_inst`.
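
        Example:
            A minimal sketch with a single square instance (the input is
            illustrative only):
            >>> import numpy as np
            >>> pred_inst = np.zeros((64, 64), dtype=np.int32)
            >>> pred_inst[10:20, 10:20] = 1
            >>> info = HoVerNet._get_instance_info(pred_inst)
            >>> info[1]["centroid"]
            array([14.5, 14.5])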
"""
inst_id_list = np.unique(pred_inst)[1:] # exclude background
inst_info_dict = {}
for inst_id in inst_id_list:
inst_map = pred_inst == inst_id
inst_box = get_bounding_box(inst_map)
inst_box_tl = inst_box[:2]
inst_map = inst_map[inst_box[1] : inst_box[3], inst_box[0] : inst_box[2]]
inst_map = inst_map.astype(np.uint8)
inst_moment = cv2.moments(inst_map)
inst_contour = cv2.findContours(
inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
            # * opencv return format may change between versions
inst_contour = inst_contour[0][0].astype(np.int32)
inst_contour = np.squeeze(inst_contour)
            # a contour needs at least 3 points; fewer points are likely
            # artifacts that became too small after contour approximation
if inst_contour.shape[0] < 3: # pragma: no cover
continue
# ! check for trickery shape
if len(inst_contour.shape) != 2: # pragma: no cover
continue
inst_centroid = [
(inst_moment["m10"] / inst_moment["m00"]),
(inst_moment["m01"] / inst_moment["m00"]),
]
inst_centroid = np.array(inst_centroid)
inst_contour += inst_box_tl[None]
            inst_centroid += inst_box_tl  # shift centroid to image coordinates
inst_info_dict[inst_id] = { # inst_id should start at 1
"box": inst_box,
"centroid": inst_centroid,
"contour": inst_contour,
"prob": None,
"type": None,
}
if pred_type is not None:
            # * get the class of each instance id
for inst_id in list(inst_info_dict.keys()):
cmin, rmin, cmax, rmax = inst_info_dict[inst_id]["box"]
inst_map_crop = pred_inst[rmin:rmax, cmin:cmax]
inst_type_crop = pred_type[rmin:rmax, cmin:cmax]
inst_map_crop = inst_map_crop == inst_id
inst_type = inst_type_crop[inst_map_crop]
(type_list, type_pixels) = np.unique(inst_type, return_counts=True)
type_list = list(zip(type_list, type_pixels))
type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
inst_type = type_list[0][0]
# ! pick the 2nd most dominant if it exists
if inst_type == 0 and len(type_list) > 1: # pragma: no cover
inst_type = type_list[1][0]
type_dict = {v[0]: v[1] for v in type_list}
type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6)
inst_info_dict[inst_id]["type"] = int(inst_type)
inst_info_dict[inst_id]["prob"] = float(type_prob)
return inst_info_dict

    @staticmethod
    # skipcq: PYL-W0221
    def postproc(raw_maps: List[np.ndarray]):
"""Post processing script for image tiles.
Args:
            raw_maps (list(ndarray)): List of prediction outputs of each
                head, assumed to be in the order [np, hv, tp] (matching
                the output of `infer_batch`).

        Returns:
            inst_map (ndarray): Pixel-wise nuclear instance segmentation
                prediction.
            inst_dict (dict): A dictionary mapping each instance within
                `inst_map` to its instance information. It has the
                following form::

                    inst_info = {
                        box: number[],
                        centroid: number[],
                        contour: number[][],
                        type: number,
                        prob: number,
                    }
                    inst_dict = {[inst_uid: number]: inst_info}

                where `inst_uid` is an integer corresponding to the
                instance having the same pixel value within `inst_map`.

Examples:
>>> from tiatoolbox.models.architecture.hovernet import HoVerNet
>>> import torch
>>> import numpy as np
>>> batch = torch.from_numpy(image_patch)[None]
>>> # image_patch is a 256x256x3 numpy array
>>> weights_path = "A/weights.pth"
>>> pretrained = torch.load(weights_path)
>>> model = HoVerNet(num_types=6, mode="fast")
>>> model.load_state_dict(pretrained)
>>> output = model.infer_batch(model, batch, on_gpu=False)
>>> output = [v[0] for v in output]
>>> output = model.postproc(output)
"""
if len(raw_maps) == 3:
np_map, hv_map, tp_map = raw_maps
else:
tp_map = None
np_map, hv_map = raw_maps
pred_type = tp_map
pred_inst = HoVerNet._proc_np_hv(np_map, hv_map)
nuc_inst_info_dict = HoVerNet._get_instance_info(pred_inst, pred_type)
return pred_inst, nuc_inst_info_dict

    @staticmethod
    def infer_batch(model, batch_data, on_gpu):
"""Run inference on an input batch.

        This contains the logic for the forward operation as well as
        batch I/O aggregation.

        Args:
            model (nn.Module): PyTorch defined model.
            batch_data (torch.Tensor): A batch of data generated by
                torch.utils.data.DataLoader.
            on_gpu (bool): Whether to run inference on a GPU.

        Returns:
            List of output from each head, where each head is expected
            to contain N predictions for N input patches. There are two
            cases: one with 2 heads (Nuclei Pixels `np` and HoVer `hv`)
            and one with 3 heads (`np`, `hv`, and Nuclei Types `tp`).
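
        Example:
            A minimal sketch with randomly initialised weights and
            random patches in NHWC order:
            >>> import torch
            >>> model = HoVerNet(num_types=6, mode="fast")
            >>> batch = torch.rand(4, 256, 256, 3) * 255
            >>> np_map, hv_map, tp_map = HoVerNet.infer_batch(model, batch, on_gpu=False)
            >>> np_map.shape  # one probability channel per patch
            (4, 164, 164, 1)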
"""
patch_imgs = batch_data
device = misc.select_device(on_gpu)
        patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)
        patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
model.eval() # infer mode
# --------------------------------------------------------------
with torch.inference_mode():
pred_dict = model(patch_imgs_gpu)
pred_dict = OrderedDict(
[[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()]
)
pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:]
if "tp" in pred_dict:
type_map = F.softmax(pred_dict["tp"], dim=-1)
type_map = torch.argmax(type_map, dim=-1, keepdim=True)
type_map = type_map.type(torch.float32)
pred_dict["tp"] = type_map
pred_dict = {k: v.cpu().numpy() for k, v in pred_dict.items()}
if "tp" in pred_dict:
return pred_dict["np"], pred_dict["hv"], pred_dict["tp"]
return pred_dict["np"], pred_dict["hv"]