"""Define MicroNet architecture.
Raza, SEA et al., “Micro-Net: A unified model for segmentation of
various objects in microscopy images,” Medical Image Analysis,
Dec. 2018, vol. 52, p. 160-173.
"""
from __future__ import annotations
from collections import OrderedDict
import numpy as np
import torch
from scipy import ndimage
from skimage import morphology
from torch import nn
from torch.nn import functional
from tiatoolbox.models.architecture.hovernet import HoVerNet
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils import misc
[docs]
def group1_forward_branch(
    layer: nn.Module,
    in_tensor: torch.Tensor,
    resized_feat: torch.Tensor,
) -> torch.Tensor:
    """Run a group 1 block: two parallel paths joined along channels.

    Args:
        layer (torch.nn.Module):
            Container indexed by the keys "conv1", "conv2", "pool",
            "conv3" and "conv4".
        in_tensor (torch.Tensor):
            Tensor fed to the conv1 -> conv2 -> pool path.
        resized_feat (torch.Tensor):
            Resized input fed to the conv3 -> conv4 path.

    Returns:
        torch.Tensor:
            Both path outputs concatenated on dimension 1.

    """
    # Main path: two convolutions followed by pooling.
    pooled_path = layer["pool"](layer["conv2"](layer["conv1"](in_tensor)))
    # Side path: two convolutions applied to the resized input.
    resized_path = layer["conv4"](layer["conv3"](resized_feat))
    return torch.cat(tensors=(pooled_path, resized_path), dim=1)
[docs]
def group2_forward_branch(layer: nn.Module, in_tensor: torch.Tensor) -> torch.Tensor:
    """Run a group 2 block: "conv1" followed by "conv2".

    Args:
        layer (torch.nn.Module):
            Container indexed by the keys "conv1" and "conv2".
        in_tensor (torch.Tensor):
            Input tensor.

    Returns:
        torch.Tensor:
            Output of the group 2 block.

    """
    return layer["conv2"](layer["conv1"](in_tensor))
[docs]
def group3_forward_branch(
    layer: nn.Module,
    main_feat: torch.Tensor,
    skip: torch.Tensor,
) -> torch.Tensor:
    """Run a group 3 block: upsample, refine and fuse with a skip input.

    Args:
        layer (torch.nn.Module):
            Container indexed by the keys "up1", "conv1", "conv2",
            "up2", "up3" and "conv3".
        main_feat (torch.Tensor):
            Tensor fed to the up1 -> conv1 -> conv2 -> up2 path.
        skip (torch.Tensor):
            Skip connection, passed through "up3" only.

    Returns:
        torch.Tensor:
            Output of "conv3" applied to both paths concatenated on
            dimension 1.

    """
    upsampled = layer["up1"](main_feat)
    refined = layer["conv2"](layer["conv1"](upsampled))
    # The main path is evaluated before the skip path, then both are joined.
    fused = torch.cat(
        tensors=(layer["up2"](refined), layer["up3"](skip)),
        dim=1,
    )
    return layer["conv3"](fused)
[docs]
def group4_forward_branch(layer: nn.Module, in_tensor: torch.Tensor) -> torch.Tensor:
    """Run a group 4 block: "up1" followed by "conv1".

    Args:
        layer (torch.nn.Module):
            Container indexed by the keys "up1" and "conv1".
        in_tensor (torch.Tensor):
            Input tensor.

    Returns:
        torch.Tensor:
            Output of the group 4 block.

    """
    return layer["conv1"](layer["up1"](in_tensor))
[docs]
def group1_arch_branch(in_ch: int, resized_in_ch: int, out_ch: int) -> nn.ModuleDict:
    """Build the layers of a group 1 branch for MicroNet.

    Args:
        in_ch (int):
            Number of input channels.
        resized_in_ch (int):
            Number of input channels from resized input.
        out_ch (int):
            Number of output channels of every convolution in the branch.

    Returns:
        :class:`torch.nn.ModuleDict`:
            Modules stored, in this order, under the keys "conv1",
            "conv2", "pool", "conv3" and "conv4".

    """

    def conv_tanh(channels_in: int, *, batch_norm: bool) -> nn.Sequential:
        """Unpadded 3x3 convolution + Tanh, optionally followed by BatchNorm."""
        stages = [
            nn.Conv2d(
                channels_in,
                out_ch,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=0,
                bias=True,
            ),
            nn.Tanh(),
        ]
        if batch_norm:
            stages.append(nn.BatchNorm2d(out_ch))
        return nn.Sequential(*stages)

    # Entries are created in the original order so that parameter
    # initialisation draws from the random generator in the same sequence.
    branch = OrderedDict(
        [
            ("conv1", conv_tanh(in_ch, batch_norm=True)),
            ("conv2", conv_tanh(out_ch, batch_norm=False)),
            # NOTE: padding is 0 here; the original author left a
            # "check padding" reminder on this layer.
            ("pool", nn.MaxPool2d(2, padding=0)),
            ("conv3", conv_tanh(resized_in_ch, batch_norm=True)),
            ("conv4", conv_tanh(out_ch, batch_norm=False)),
        ],
    )
    return nn.ModuleDict(branch)
[docs]
def group2_arch_branch(in_ch: int, out_ch: int) -> nn.ModuleDict:
    """Build the layers of a group 2 branch for MicroNet.

    Args:
        in_ch (int):
            Number of input channels.
        out_ch (int):
            Number of output channels.

    Returns:
        torch.nn.ModuleDict:
            Two unpadded 3x3 convolution + Tanh stages stored under the
            keys "conv1" and "conv2".

    """
    # (key, input channels) for each stage; both stages emit `out_ch` channels.
    channel_plan = (("conv1", in_ch), ("conv2", out_ch))
    branch = OrderedDict()
    for key, channels_in in channel_plan:
        branch[key] = nn.Sequential(
            nn.Conv2d(
                channels_in,
                out_ch,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=0,
                bias=True,
            ),
            nn.Tanh(),
        )
    return nn.ModuleDict(branch)
[docs]
def group3_arch_branch(in_ch: int, skip: int, out_ch: int) -> nn.ModuleDict:
    """Build the layers of a group 3 branch for MicroNet.

    Args:
        in_ch (int):
            Number of input channels.
        skip (int):
            Number of channels for the skip connection.
        out_ch (int):
            Number of output channels.

    Returns:
        torch.nn.ModuleDict:
            Modules stored, in this order, under the keys "up1", "conv1",
            "conv2", "up2", "up3" and "conv3".

    """

    def conv_tanh(channels_in: int, kernel: int) -> nn.Sequential:
        """Unpadded square convolution to `out_ch` channels followed by Tanh."""
        return nn.Sequential(
            nn.Conv2d(
                channels_in,
                out_ch,
                kernel_size=(kernel, kernel),
                stride=(1, 1),
                padding=0,
                bias=True,
            ),
            nn.Tanh(),
        )

    def grow_by_kernel5(channels_in: int) -> nn.ConvTranspose2d:
        """Stride-1 transposed convolution with a 5x5 kernel."""
        return nn.ConvTranspose2d(
            channels_in,
            out_ch,
            kernel_size=(5, 5),
            stride=(1, 1),
        )

    # Construction order mirrors the key order so parameter initialisation
    # consumes random numbers in the same sequence as before.
    branch = OrderedDict(
        [
            (
                "up1",
                nn.ConvTranspose2d(
                    in_ch,
                    out_ch,
                    kernel_size=(2, 2),
                    stride=(2, 2),
                ),
            ),
            ("conv1", conv_tanh(out_ch, 3)),
            ("conv2", conv_tanh(out_ch, 3)),
            ("up2", grow_by_kernel5(out_ch)),
            ("up3", grow_by_kernel5(skip)),
            # 1x1 convolution over the two concatenated `out_ch` feature maps.
            ("conv3", conv_tanh(2 * out_ch, 1)),
        ],
    )
    return nn.ModuleDict(branch)
[docs]
def group4_arch_branch(
    in_ch: int,
    out_ch: int,
    up_kernel: tuple[int, int] = (2, 2),
    up_strides: tuple[int, int] = (2, 2),
    activation: str = "tanh",
) -> nn.ModuleDict:
    """Build the layers of a group 4 branch for MicroNet.

    This branch defines architecture for decoder and provides input for
    the auxiliary and main output branch.

    Args:
        in_ch (int):
            Number of input channels.
        out_ch (int):
            Number of output channels.
        up_kernel (tuple of int):
            Kernel size for :class:`torch.nn.ConvTranspose2d`.
        up_strides (tuple of int):
            Stride size for :class:`torch.nn.ConvTranspose2d`.
        activation (str):
            Activation function, default="tanh". Only "relu" selects
            ReLU; every other value selects Tanh.

    Returns:
        torch.nn.ModuleDict:
            A transposed convolution under "up1" and an unpadded 3x3
            convolution + activation under "conv1".

    """
    if activation == "relu":
        act_layer = nn.ReLU()
    else:
        act_layer = nn.Tanh()

    upsample = nn.ConvTranspose2d(
        in_ch,
        out_ch,
        kernel_size=up_kernel,
        stride=up_strides,
    )
    refine = nn.Sequential(
        nn.Conv2d(
            out_ch,
            out_ch,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=0,
            bias=True,
        ),
        act_layer,
    )
    return nn.ModuleDict(OrderedDict([("up1", upsample), ("conv1", refine)]))
[docs]
def out_arch_branch(
    in_ch: int,
    num_output_channels: int = 2,
    activation: str = "softmax",
) -> torch.nn.Sequential:
    """Group5 branch for MicroNet.

    This branch defines architecture for auxiliary and the main output:
    2D dropout (p=0.5), an unpadded 3x3 convolution and an output
    activation.

    Args:
        in_ch (int):
            Number of input channels.
        num_output_channels (int):
            Number of output channels. default=2.
        activation (str):
            Activation function, default="softmax". Only "relu" selects
            ReLU; every other value selects a softmax over the channel
            axis.

    Returns:
        torch.nn.Sequential:
            An output of type :class:`torch.nn.Sequential`

    """
    if activation == "relu":
        out_activation = nn.ReLU()
    else:
        # Name the softmax axis explicitly: a bare `nn.Softmax()` relies on
        # the deprecated implicit-dimension rule and warns on every call.
        # -3 is the channel axis for both (N, C, H, W) and (C, H, W) inputs,
        # which is the axis the implicit rule selected for those ranks.
        out_activation = nn.Softmax(dim=-3)

    return nn.Sequential(
        nn.Dropout2d(p=0.5),
        nn.Conv2d(
            in_ch,
            num_output_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=0,
            bias=True,
        ),
        out_activation,
    )
[docs]
class MicroNet(ModelABC):
    """Initialize MicroNet [1].

    The following models have been included in tiatoolbox:

    1. `micronet-consep`:
        This is trained on `CoNSeP dataset
        <https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/>`_ The
        model is retrained in torch as the original model with results
        on CoNSeP [2] was trained in TensorFlow.

    The tiatoolbox model should produce the following results on the CoNSeP dataset:

    .. list-table:: MicroNet performance
       :widths: 15 15 15 15 15 15 15
       :header-rows: 1

       * - Model name
         - Data set
         - DICE
         - AJI
         - DQ
         - SQ
         - PQ
       * - micronet-consep
         - CoNSeP
         - 0.80
         - 0.49
         - 0.62
         - 0.75
         - 0.47

    Args:
        num_input_channels (int):
            Number of channels in input. default=3.
        num_output_channels (int):
            Number of output channels. default=2.
        out_activation (str):
            Activation to use at the output. MapDe inherits MicroNet
            but uses ReLU activation.

    References:
        [1] Raza, Shan E Ahmed, et al. "Micro-Net: A unified model for
        segmentation of various objects in microscopy images."
        Medical image analysis 52 (2019): 160-173.

        [2] Graham, Simon, et al. "Hover-net: Simultaneous segmentation
        and classification of nuclei in multi-tissue histology images."
        Medical Image Analysis 58 (2019): 101563.

    """

    def __init__(
        self: MicroNet,
        num_input_channels: int = 3,
        num_output_channels: int = 2,
        out_activation: str = "softmax",
    ) -> None:
        """Initialize :class:`MicroNet`.

        Raises:
            ValueError:
                If `num_output_channels` is smaller than 2.

        """
        super().__init__()
        if num_output_channels < 2:  # noqa: PLR2004
            msg = "Number of classes should be >=2."
            raise ValueError(msg)

        self.__num_output_channels = num_output_channels
        self.in_ch = num_input_channels

        module_dict = OrderedDict()

        # Group 1 blocks. Each one concatenates two `out_ch` feature maps,
        # so the next block receives twice the previous `out_ch`.
        module_dict["b1"] = group1_arch_branch(
            num_input_channels,
            num_input_channels,
            64,
        )
        module_dict["b2"] = group1_arch_branch(128, num_input_channels, 128)
        module_dict["b3"] = group1_arch_branch(256, num_input_channels, 256)
        module_dict["b4"] = group1_arch_branch(512, num_input_channels, 512)

        # Group 2 block.
        module_dict["b5"] = group2_arch_branch(1024, 2048)

        # Group 3 blocks; the middle argument is the channel count of the
        # skip tensor each block receives in `forward`.
        module_dict["b6"] = group3_arch_branch(2048, 1024, 1024)
        module_dict["b7"] = group3_arch_branch(1024, 512, 512)
        module_dict["b8"] = group3_arch_branch(512, 256, 256)
        module_dict["b9"] = group3_arch_branch(256, 128, 128)

        # Group 4 blocks, applied in `forward` to b9, b8 and b7 with
        # upsampling factors 2, 4 and 8 respectively.
        module_dict["fm1"] = group4_arch_branch(
            128,
            64,
            (2, 2),
            (2, 2),
            activation=out_activation,
        )
        module_dict["fm2"] = group4_arch_branch(
            256,
            128,
            (4, 4),
            (4, 4),
            activation=out_activation,
        )
        module_dict["fm3"] = group4_arch_branch(
            512,
            256,
            (8, 8),
            (8, 8),
            activation=out_activation,
        )

        # Auxiliary heads keep the default activation of `out_arch_branch`;
        # only the main head receives `out_activation`.
        module_dict["aux_out1"] = out_arch_branch(
            64,
            num_output_channels=self.__num_output_channels,
        )
        module_dict["aux_out2"] = out_arch_branch(
            128,
            num_output_channels=self.__num_output_channels,
        )
        module_dict["aux_out3"] = out_arch_branch(
            256,
            num_output_channels=self.__num_output_channels,
        )
        module_dict["out"] = out_arch_branch(
            64 + 128 + 256,
            num_output_channels=self.__num_output_channels,
            activation=out_activation,
        )

        self.layer = nn.ModuleDict(module_dict)

    def forward(  # skipcq: PYL-W0221
        self: MicroNet,
        input_tensor: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Logic for using layers defined in init.

        This method defines how layers are used in forward operation.

        Args:
            input_tensor (torch.Tensor):
                Input images, the tensor is in the shape of NCHW.

        Returns:
            list:
                A list of main and auxiliary outputs. The expected
                format is `[main_output, aux1, aux2, aux3]`.

        """
        # The resize targets below are fixed. With the unpadded 3x3
        # convolutions and 2x pooling of the group 1 blocks they line up
        # for a 252x252 input.
        # NOTE(review): other input sizes look unsupported - confirm.
        b1 = group1_forward_branch(
            self.layer["b1"],
            input_tensor,
            functional.interpolate(input_tensor, size=(128, 128), mode="bicubic"),
        )
        b2 = group1_forward_branch(
            self.layer["b2"],
            b1,
            functional.interpolate(input_tensor, size=(64, 64), mode="bicubic"),
        )
        b3 = group1_forward_branch(
            self.layer["b3"],
            b2,
            functional.interpolate(input_tensor, size=(32, 32), mode="bicubic"),
        )
        b4 = group1_forward_branch(
            self.layer["b4"],
            b3,
            functional.interpolate(input_tensor, size=(16, 16), mode="bicubic"),
        )

        b5 = group2_forward_branch(self.layer["b5"], b4)

        # Each group 3 block takes the previous output plus the matching
        # group 1 output as its skip connection.
        b6 = group3_forward_branch(self.layer["b6"], b5, b4)
        b7 = group3_forward_branch(self.layer["b7"], b6, b3)
        b8 = group3_forward_branch(self.layer["b8"], b7, b2)
        b9 = group3_forward_branch(self.layer["b9"], b8, b1)

        fm1 = group4_forward_branch(self.layer["fm1"], b9)
        fm2 = group4_forward_branch(self.layer["fm2"], b8)
        fm3 = group4_forward_branch(self.layer["fm3"], b7)

        aux1 = self.layer["aux_out1"](fm1)
        aux2 = self.layer["aux_out2"](fm2)
        aux3 = self.layer["aux_out3"](fm3)

        # The main head sees all three feature maps stacked on channels.
        out = torch.cat(tensors=(fm1, fm2, fm3), dim=1)
        out = self.layer["out"](out)

        return [out, aux1, aux2, aux3]

    @staticmethod
    def postproc(image: np.ndarray) -> tuple[np.ndarray, dict]:
        """Post-processing script for MicroNet.

        Args:
            image (ndarray):
                Input image of type numpy array. Only `image[0]` is
                used, and the class is chosen by an argmax over its
                axis 2.

        Returns:
            tuple:
                - :class:`numpy.ndarray` - Pixel-wise nuclear instance
                  segmentation prediction.
                - :class:`dict` - Instance information returned by
                  :func:`HoVerNet.get_instance_info` for that prediction.

        """
        pred_bin = np.argmax(image[0], axis=2)
        pred_inst = ndimage.label(pred_bin)[0]
        pred_inst = morphology.remove_small_objects(pred_inst, min_size=50)
        canvas = np.zeros(pred_inst.shape[:2], dtype=np.int32)
        # Fill holes one instance at a time so that each filled region keeps
        # its own label.
        for inst_id in range(1, np.max(pred_inst) + 1):
            inst_map = np.array(pred_inst == inst_id, dtype=np.uint8)
            inst_map = ndimage.binary_fill_holes(inst_map)
            canvas[inst_map > 0] = inst_id
        nuc_inst_info_dict = HoVerNet.get_instance_info(canvas)
        return canvas, nuc_inst_info_dict

    @staticmethod
    def preproc(image: np.ndarray) -> np.ndarray:
        """Preprocessing function for MicroNet.

        Performs per image standardization: values are divided by 255,
        then the mean over all elements is subtracted and the result is
        divided by the standard deviation, which is floored at
        `1 / sqrt(number of elements)`.

        Args:
            image (:class:`numpy.ndarray`):
                Input image of type numpy array. The axes are permuted
                with `(2, 0, 1)` on entry and `(1, 2, 0)` on exit, so a
                three-dimensional array is required.

        Returns:
            :class:`numpy.ndarray`:
                Pre-processed numpy array.

        """
        image = np.transpose(image, axes=(2, 0, 1))
        image = image / 255.0
        image = torch.from_numpy(image)

        image_mean = torch.mean(image, dim=(-1, -2, -3))
        stddev = torch.std(image, dim=(-1, -2, -3))
        num_pixels = torch.tensor(torch.numel(image), dtype=torch.float32)
        # Lower bound on the divisor so near-constant images do not blow up.
        min_stddev = torch.rsqrt(num_pixels)
        adjusted_stddev = torch.max(stddev, min_stddev)

        image -= image_mean
        image = torch.div(image, adjusted_stddev)
        return np.transpose(image.numpy(), axes=(1, 2, 0))

    @staticmethod
    def infer_batch(
        model: torch.nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray]:
        """Run inference on an input batch.

        This contains logic for forward operation as well as batch I/O
        aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (:class:`torch.Tensor`):
                A batch of data generated by
                `torch.utils.data.DataLoader`. It is permuted with
                `(0, 3, 1, 2)` before the forward pass, i.e. the channel
                axis is expected last.
            on_gpu (bool):
                Whether to run inference on a GPU.

        Returns:
            list(np.ndarray):
                Single-element list holding the main output as a numpy
                array, permuted back so the channel axis is last.

        """
        patch_imgs = batch_data

        device = misc.select_device(on_gpu=on_gpu)
        patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
        patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()

        model.eval()  # infer mode

        with torch.inference_mode():
            # Auxiliary outputs are discarded at inference time.
            pred, _, _, _ = model(patch_imgs_gpu)

        pred = pred.permute(0, 2, 3, 1).contiguous()
        pred = pred.cpu().numpy()

        return [
            pred,
        ]