"""Define original NuClick architecture.
Koohbanani, N. A., Jahanifar, M., Tajadin, N. Z., & Rajpoot, N. (2020).
NuClick: a deep learning framework for interactive segmentation of microscopic images.
Medical Image Analysis, 65, 101771.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
import torch
from skimage.morphology import (
disk,
reconstruction,
remove_small_holes,
remove_small_objects,
)
from torch import nn
from tiatoolbox import logger
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils import misc
if TYPE_CHECKING: # pragma: no cover
from tiatoolbox.typing import IntPair
# Channel axis of the NCHW tensors used throughout this module; the
# multiscale block concatenates its branch outputs along this axis.
bn_axis = 1
class ConvBnRelu(nn.Module):
    """Performs Convolution, (optional) Batch Normalization and (optional) ReLU.

    The assembled layer stack is ``Conv2d -> [BatchNorm2d] -> [ReLU]``. The
    convolution uses ``"same"`` zero padding, so the spatial size is preserved.

    Args:
        num_input_channels (int):
            Number of channels in input.
        num_output_channels (int):
            Number of channels in output.
        kernel_size (int or tuple(int, int)):
            Size of the kernel in the convolution layer. An ``int`` is
            expanded to a square kernel.
        strides (int or tuple(int, int)):
            Stride of the convolution layer. PyTorch does not support
            ``padding="same"`` with strides other than 1.
        dilation_rate (int or tuple(int, int)):
            Dilation rate in the convolution layer. An ``int`` is expanded
            to the same rate on both axes.
        activation (str or None):
            Name of the activation function. Only ``"relu"`` adds an
            activation layer; any other value (e.g. ``None``) adds none.
        use_bias (bool):
            Whether to use bias in the convolution layer.
        do_batchnorm (bool):
            Whether to do batch normalization after the convolution layer.

    Attributes:
        conv_bn_relu (torch.nn.Sequential):
            The assembled convolution block.

    """

    def __init__(
        self: ConvBnRelu,
        num_input_channels: int,
        num_output_channels: int,
        kernel_size: int | tuple[int, int] = (3, 3),
        strides: int | tuple[int, int] = (1, 1),
        dilation_rate: int | tuple[int, int] = (1, 1),
        activation: str | None = "relu",
        *,
        use_bias: bool = False,
        do_batchnorm: bool = True,
    ) -> None:
        """Initialize :class:`ConvBnRelu`."""
        super().__init__()
        # Normalise scalar sizes to (h, w) pairs.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(strides, int):
            strides = (strides, strides)
        if isinstance(dilation_rate, int):
            dilation_rate = (dilation_rate, dilation_rate)
        self.conv_bn_relu = self.get_block(
            num_input_channels,
            num_output_channels,
            kernel_size,
            strides,
            dilation_rate,
            activation,
            do_batchnorm=do_batchnorm,
            use_bias=use_bias,
        )

    def forward(self: ConvBnRelu, input_tensor: torch.Tensor) -> torch.Tensor:
        """Logic for using layers defined in init.

        This method defines how layers are used in forward operation.

        Args:
            input_tensor (torch.Tensor):
                Input, the tensor is of the shape NCHW.

        Returns:
            output (torch.Tensor):
                The inference output.

        """
        return self.conv_bn_relu(input_tensor)

    @staticmethod
    def get_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int | IntPair,
        strides: int | IntPair,
        dilation_rate: int | IntPair,
        activation: str | None,
        *,
        do_batchnorm: bool,
        use_bias: bool,
    ) -> torch.nn.Sequential:
        """Function to acquire a convolutional block.

        Args:
            in_channels (int):
                Number of channels in input.
            out_channels (int):
                Number of channels in output.
            kernel_size (int or tuple(int, int)):
                Size of the kernel in the acquired convolution block.
            strides (int or tuple(int, int)):
                Stride of the convolution layer (must be 1 because the
                convolution uses ``padding="same"``).
            dilation_rate (int or tuple(int, int)):
                Dilation rate for the convolution layer.
            activation (str or None):
                ``"relu"`` appends a ReLU; any other value appends nothing.
            do_batchnorm (bool):
                Whether to do batch normalization after the convolution layer.
            use_bias (bool):
                Whether to use bias in the convolution layer.

        Returns:
            torch.nn.Sequential: a pytorch layer

        """
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=strides,
            dilation=dilation_rate,
            bias=use_bias,
            padding="same",
            padding_mode="zeros",
        )
        torch.nn.init.xavier_uniform_(conv.weight)
        layers: list[nn.Module] = [conv]
        if do_batchnorm:
            # eps matches the value used by Keras-style batch normalization.
            layers.append(nn.BatchNorm2d(num_features=out_channels, eps=1.001e-5))
        if activation == "relu":
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)
class MultiscaleConvBlock(nn.Module):
    """Define Multiscale convolution block.

    Four parallel :class:`ConvBnRelu` branches, each with its own kernel size
    and dilation rate, are applied to the same input. Their outputs are
    concatenated along the channel axis, so the block produces
    ``4 * num_output_channels`` channels.

    Args:
        num_input_channels (int):
            Number of channels in input.
        kernel_sizes (int or sequence of 4 ints):
            Kernel size of each of the four branches. A single ``int`` is
            used for every branch.
        dilation_rates (int or sequence of 4 ints):
            Dilation rate of each of the four branches. A single ``int`` is
            used for every branch.
        num_output_channels (int):
            Number of channels output by each branch.
        strides (tuple(int, int)):
            Stride of the convolution layers.
        activation (str):
            Name of the activation function to use.
        use_bias (bool):
            Whether to use bias in the convolution layers.

    Raises:
        ValueError:
            If ``kernel_sizes`` or ``dilation_rates`` is a sequence whose
            length is not 4.

    """

    def __init__(
        self: MultiscaleConvBlock,
        num_input_channels: int,
        kernel_sizes: int | tuple[int, int, int, int],
        dilation_rates: int | tuple[int, int, int, int],
        num_output_channels: int = 32,
        strides: tuple[int, int] | np.ndarray = (1, 1),
        activation: str = "relu",
        *,
        use_bias: bool = False,
    ) -> None:
        """Initialize :class:`MultiscaleConvBlock`."""
        super().__init__()
        kernel_sizes = self._per_branch(kernel_sizes, "kernel_sizes")
        dilation_rates = self._per_branch(dilation_rates, "dilation_rates")

        def make_branch(branch: int) -> ConvBnRelu:
            rate = dilation_rates[branch]
            return ConvBnRelu(
                num_input_channels=num_input_channels,
                num_output_channels=num_output_channels,
                kernel_size=kernel_sizes[branch],
                strides=strides,
                activation=activation,
                use_bias=use_bias,
                dilation_rate=(rate, rate),
            )

        # Explicit attribute names keep state-dict keys (and therefore
        # pretrained weights) compatible.
        self.conv_block_1 = make_branch(0)
        self.conv_block_2 = make_branch(1)
        self.conv_block_3 = make_branch(2)
        self.conv_block_4 = make_branch(3)

    @staticmethod
    def _per_branch(value: int | tuple | np.ndarray, name: str) -> tuple:
        """Expand ``value`` to exactly one entry per branch (four in total)."""
        if isinstance(value, (int, np.integer)):
            return (value,) * 4
        values = tuple(value)
        if len(values) != 4:  # noqa: PLR2004
            msg = (
                f"`{name}` must be an int or a sequence of 4 values, "
                f"got {len(values)} values."
            )
            raise ValueError(msg)
        return values

    def forward(self: MultiscaleConvBlock, input_map: torch.Tensor) -> torch.Tensor:
        """Logic for using layers defined in MultiscaleConvBlock init.

        This method defines how layers are used in forward operation.

        Args:
            input_map (torch.Tensor):
                Input, the tensor is of the shape NCHW.

        Returns:
            output (torch.Tensor):
                The inference output: the four branch outputs concatenated
                along the channel axis.

        """
        branches = (
            self.conv_block_1,
            self.conv_block_2,
            self.conv_block_3,
            self.conv_block_4,
        )
        return torch.cat([branch(input_map) for branch in branches], dim=bn_axis)
class ResidualConv(nn.Module):
    """Residual Convolution block.

    Two stacked :class:`ConvBnRelu` layers without activation. The output of
    the first layer is added to the output of the second, and a ReLU is
    applied to the sum.

    Args:
        num_input_channels (int):
            Number of channels in input.
        num_output_channels (int):
            Number of channels in output.
        kernel_size (tuple(int, int)):
            Size of the kernel in all convolution layers.
        strides (tuple(int, int)):
            Stride of all convolution layers.
        dilation_rate (tuple(int, int)):
            Dilation rate in all convolution layers.
        use_bias (bool):
            Whether to use bias in the convolution layers.

    """

    def __init__(
        self: ResidualConv,
        num_input_channels: int,
        num_output_channels: int = 32,
        kernel_size: tuple[int, int] | np.ndarray = (3, 3),
        strides: tuple[int, int] | np.ndarray = (1, 1),
        dilation_rate: tuple[int, int] | np.ndarray = (1, 1),
        *,
        use_bias: bool = False,
    ) -> None:
        """Initialize :class:`ResidualConv`."""
        super().__init__()
        # Both convolutions are linear (BN but no activation); the single ReLU
        # is applied after the residual sum in ``forward``.
        self.conv_block_1 = ConvBnRelu(
            num_input_channels,
            num_output_channels,
            kernel_size=kernel_size,
            strides=strides,
            activation=None,
            use_bias=use_bias,
            dilation_rate=dilation_rate,
            do_batchnorm=True,
        )
        self.conv_block_2 = ConvBnRelu(
            num_output_channels,
            num_output_channels,
            kernel_size=kernel_size,
            strides=strides,
            activation=None,
            use_bias=use_bias,
            dilation_rate=dilation_rate,
            do_batchnorm=True,
        )
        self.activation = nn.ReLU()

    def forward(self: ResidualConv, input_tensor: torch.Tensor) -> torch.Tensor:
        """Logic for using layers defined in ResidualConv init.

        This method defines how layers are used in forward operation.

        Args:
            input_tensor (torch.Tensor):
                Input, the tensor is of the shape NCHW.

        Returns:
            output (torch.Tensor):
                The inference output.

        """
        conv1 = self.conv_block_1(input_tensor)
        conv2 = self.conv_block_2(conv1)
        # The skip connection starts after the first convolution, not at the
        # block input, because the channel counts may differ.
        out = torch.add(conv1, conv2)
        return self.activation(out)
class NuClick(ModelABC):
    """NuClick Architecture.

    NuClick is used for interactive nuclei segmentation.
    NuClick takes an RGB image patch along with an inclusion and an exclusion
    map (5 input channels in the example below). The network is an
    encoder-decoder with five 2x max-poolings and five 2x transposed
    convolutions joined by skip connections. Its final layer has no
    activation; :meth:`infer_batch` applies a sigmoid.

    Args:
        num_input_channels (int):
            Number of channels in input.
        num_output_channels (int):
            Number of channels in output.

    Examples:
        >>> # instantiate a NuClick model for interactive nucleus segmentation.
        >>> NuClick(num_input_channels = 5, num_output_channels = 1)

    """

    def __init__(
        self: NuClick,
        num_input_channels: int,
        num_output_channels: int,
    ) -> None:
        """Initialize :class:`NuClick`."""
        super().__init__()
        self.net_name = "NuClick"
        self.n_channels = num_input_channels
        self.n_classes = num_output_channels

        # -------------Convolution + Batch Normalization + ReLu blocks------------
        self.conv_block_1 = nn.Sequential(
            ConvBnRelu(
                num_input_channels=self.n_channels,
                num_output_channels=64,
                kernel_size=7,
            ),
            ConvBnRelu(num_input_channels=64, num_output_channels=32, kernel_size=5),
            ConvBnRelu(num_input_channels=32, num_output_channels=32, kernel_size=3),
        )
        self.conv_block_2 = nn.Sequential(
            ConvBnRelu(num_input_channels=64, num_output_channels=64),
            ConvBnRelu(num_input_channels=64, num_output_channels=32),
            ConvBnRelu(num_input_channels=32, num_output_channels=32),
        )
        # Output head: 1x1 convolution with bias, no BN and no activation.
        self.conv_block_3 = ConvBnRelu(
            num_input_channels=32,
            num_output_channels=self.n_classes,
            kernel_size=(1, 1),
            strides=1,
            activation=None,
            use_bias=True,
            do_batchnorm=False,
        )

        # -------------Residual Convolution blocks------------
        self.residual_block_1 = nn.Sequential(
            ResidualConv(num_input_channels=32, num_output_channels=64),
            ResidualConv(num_input_channels=64, num_output_channels=64),
        )
        self.residual_block_2 = ResidualConv(
            num_input_channels=64,
            num_output_channels=128,
        )
        self.residual_block_3 = ResidualConv(
            num_input_channels=128,
            num_output_channels=128,
        )
        self.residual_block_4 = nn.Sequential(
            ResidualConv(num_input_channels=128, num_output_channels=256),
            ResidualConv(num_input_channels=256, num_output_channels=256),
            ResidualConv(num_input_channels=256, num_output_channels=256),
        )
        self.residual_block_5 = nn.Sequential(
            ResidualConv(num_input_channels=256, num_output_channels=512),
            ResidualConv(num_input_channels=512, num_output_channels=512),
            ResidualConv(num_input_channels=512, num_output_channels=512),
        )
        self.residual_block_6 = nn.Sequential(
            ResidualConv(num_input_channels=512, num_output_channels=1024),
            ResidualConv(num_input_channels=1024, num_output_channels=1024),
        )
        self.residual_block_7 = nn.Sequential(
            ResidualConv(num_input_channels=1024, num_output_channels=512),
            ResidualConv(num_input_channels=512, num_output_channels=256),
        )
        self.residual_block_8 = ResidualConv(
            num_input_channels=512,
            num_output_channels=256,
        )
        self.residual_block_9 = ResidualConv(
            num_input_channels=256,
            num_output_channels=256,
        )
        self.residual_block_10 = nn.Sequential(
            ResidualConv(num_input_channels=256, num_output_channels=128),
            ResidualConv(num_input_channels=128, num_output_channels=128),
        )
        self.residual_block_11 = ResidualConv(
            num_input_channels=128,
            num_output_channels=64,
        )
        self.residual_block_12 = ResidualConv(
            num_input_channels=64,
            num_output_channels=64,
        )

        # -------------Multi-scale Convolution blocks------------
        # Each outputs 4 * num_output_channels channels (four branches).
        self.multiscale_block_1 = MultiscaleConvBlock(
            num_input_channels=128,
            num_output_channels=32,
            kernel_sizes=(3, 3, 5, 5),
            dilation_rates=(1, 3, 3, 6),
        )
        self.multiscale_block_2 = MultiscaleConvBlock(
            num_input_channels=256,
            num_output_channels=64,
            kernel_sizes=(3, 3, 5, 5),
            dilation_rates=(1, 3, 2, 3),
        )
        self.multiscale_block_3 = MultiscaleConvBlock(
            num_input_channels=64,
            num_output_channels=16,
            kernel_sizes=(3, 3, 5, 7),
            dilation_rates=(1, 3, 2, 6),
        )

        # -------------Max Pooling blocks------------
        self.pool_block_1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.pool_block_2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.pool_block_3 = nn.MaxPool2d(kernel_size=(2, 2))
        self.pool_block_4 = nn.MaxPool2d(kernel_size=(2, 2))
        self.pool_block_5 = nn.MaxPool2d(kernel_size=(2, 2))

        # -------------Transposed Convolution blocks------------
        self.conv_transpose_1 = nn.ConvTranspose2d(
            in_channels=1024,
            out_channels=512,
            kernel_size=2,
            stride=(2, 2),
        )
        self.conv_transpose_2 = nn.ConvTranspose2d(
            in_channels=256,
            out_channels=256,
            kernel_size=2,
            stride=(2, 2),
        )
        self.conv_transpose_3 = nn.ConvTranspose2d(
            in_channels=256,
            out_channels=128,
            kernel_size=2,
            stride=(2, 2),
        )
        self.conv_transpose_4 = nn.ConvTranspose2d(
            in_channels=128,
            out_channels=64,
            kernel_size=2,
            stride=(2, 2),
        )
        self.conv_transpose_5 = nn.ConvTranspose2d(
            in_channels=64,
            out_channels=32,
            kernel_size=2,
            stride=(2, 2),
        )

    # pylint: disable=W0221
    def forward(self: NuClick, imgs: torch.Tensor) -> torch.Tensor:
        """Logic for using layers defined in NuClick init.

        This method defines how layers are used in forward operation.

        Args:
            imgs (torch.Tensor):
                Input images, the tensor is of the shape NCHW. H and W should
                be divisible by 32 (five 2x poolings) so the upsampled
                features match the skip connections they are concatenated
                with.

        Returns:
            output (torch.Tensor):
                The inference output (scores before any sigmoid), of shape
                ``(N, num_output_channels, H, W)``.

        """
        # Encoder.
        conv1 = self.conv_block_1(imgs)
        pool1 = self.pool_block_1(conv1)

        conv2 = self.residual_block_1(pool1)
        pool2 = self.pool_block_2(conv2)

        conv3 = self.residual_block_2(pool2)
        conv3 = self.multiscale_block_1(conv3)
        conv3 = self.residual_block_3(conv3)
        pool3 = self.pool_block_3(conv3)

        conv4 = self.residual_block_4(pool3)
        pool4 = self.pool_block_4(conv4)

        conv5 = self.residual_block_5(pool4)
        pool5 = self.pool_block_5(conv5)

        # Bottleneck.
        conv51 = self.residual_block_6(pool5)

        # Decoder: upsample, concatenate the matching encoder feature map.
        up61 = torch.cat([self.conv_transpose_1(conv51), conv5], dim=1)
        conv61 = self.residual_block_7(up61)

        up6 = torch.cat([self.conv_transpose_2(conv61), conv4], dim=1)
        conv6 = self.residual_block_8(up6)
        conv6 = self.multiscale_block_2(conv6)
        conv6 = self.residual_block_9(conv6)

        up7 = torch.cat([self.conv_transpose_3(conv6), conv3], dim=1)
        conv7 = self.residual_block_10(up7)

        up8 = torch.cat([self.conv_transpose_4(conv7), conv2], dim=1)
        conv8 = self.residual_block_11(up8)
        conv8 = self.multiscale_block_3(conv8)
        conv8 = self.residual_block_12(conv8)

        up9 = torch.cat([self.conv_transpose_5(conv8), conv1], dim=1)
        conv9 = self.conv_block_2(up9)

        return self.conv_block_3(conv9)

    @staticmethod
    def postproc(
        preds: np.ndarray,
        thresh: float = 0.33,
        min_size: int = 10,
        min_hole_size: int = 30,
        nuc_points: np.ndarray | None = None,
        *,
        do_reconstruction: bool = False,
    ) -> np.ndarray:
        """Post-processing.

        Each patch is thresholded and cleaned independently. Small objects are
        removed and small holes are filled. Optionally, only the connected
        components touched by the clicked points are kept.

        Args:
            preds (ndarray): list of prediction output of each patch and
                assumed to be in the order of (no.patch, h, w) (match with the
                output of `infer_batch`). A single (h, w) prediction is also
                accepted.
            thresh (float): Threshold value. If a pixel has a predicted value
                larger than the threshold, it will be classified as nuclei.
            min_size (int): The smallest allowable object size.
            min_hole_size (int): The maximum area, in pixels, of a contiguous
                hole that will be filled.
            nuc_points (ndarray): In the order of (no.patch, h, w).
                In each patch, the pixel that has been 'clicked' is set to 1
                and the rest pixels are set to 0. Required when
                ``do_reconstruction`` is True.
            do_reconstruction (bool): Whether to perform a morphological
                reconstruction of each mask from its clicked points.

        Returns:
            masks (ndarray): boolean pixel-wise nuclei instance segmentation
                prediction, shape: (no.patch, h, w), or (h, w) when a single
                2D prediction was given.

        Raises:
            ValueError: If ``do_reconstruction`` is True but ``nuc_points``
                is None.

        """
        if do_reconstruction and nuc_points is None:
            msg = "`nuc_points` must be provided when `do_reconstruction` is True."
            raise ValueError(msg)

        masks = np.asarray(preds) > thresh
        is_single_patch = masks.ndim == 2  # noqa: PLR2004
        if is_single_patch:
            # Treat a lone (h, w) prediction as a batch of one patch.
            masks = masks[np.newaxis]

        # Clean every patch on its own: passing the whole (no.patch, h, w)
        # stack to skimage would treat it as a 3D volume, connecting objects
        # and holes across neighbouring patches.
        for i, mask in enumerate(masks):
            cleaned = remove_small_objects(mask, min_size=min_size)
            masks[i] = remove_small_holes(cleaned, area_threshold=min_hole_size)

        if do_reconstruction:
            markers = np.asarray(nuc_points) > 0
            if markers.ndim == 2:  # noqa: PLR2004
                markers = markers[np.newaxis]
            for i, mask in enumerate(masks):
                marker = markers[i]
                if np.any(mask[marker]):
                    # Reconstruction by dilation requires seed <= mask, so
                    # clip the clicked points to the mask before growing.
                    seed = marker & mask
                    masks[i] = reconstruction(
                        seed,
                        mask,
                        footprint=disk(1),
                    ).astype(bool)
                else:
                    logger.warning(
                        "Nuclei reconstruction was not done for nucleus #%d",
                        i,
                        stacklevel=2,
                    )

        return masks[0] if is_single_patch else masks

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> np.ndarray:
        """Run inference on an input batch.

        This contains logic for forward operation as well as batch i/o
        aggregation.

        Args:
            model (nn.Module): PyTorch defined model.
            batch_data (torch.Tensor): a batch of data generated by
                torch.utils.data.DataLoader.
            on_gpu (bool): Whether to run inference on a GPU.

        Returns:
            Pixel-wise nuclei prediction (sigmoid probabilities) for each
            patch, shape: (no.patch, h, w).

        """
        model.eval()
        device = misc.select_device(on_gpu=on_gpu)

        # Assume batch_data is NCHW
        batch_data = batch_data.to(device).type(torch.float32)

        with torch.inference_mode():
            output = model(batch_data)
            output = torch.sigmoid(output)
            # Drop the channel axis (the model is expected to output one map).
            output = torch.squeeze(output, 1)

        return output.cpu().numpy()