Source code for tiatoolbox.models.architecture.nuclick

"""Defines original NuClick architecture

Koohbanani, N. A., Jahanifar, M., Tajadin, N. Z., & Rajpoot, N. (2020).
NuClick: a deep learning framework for interactive segmentation of microscopic images.
Medical Image Analysis, 65, 101771.

"""
import warnings
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from skimage.morphology import (
    disk,
    reconstruction,
    remove_small_holes,
    remove_small_objects,
)

from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils import misc

# Channel axis of the NCHW tensors used throughout this module;
# MultiscaleConvBlock.forward concatenates its branch outputs along it.
bn_axis = 1


class ConvBnRelu(nn.Module):
    """Perform convolution, optional batch normalization and optional ReLU.

    Args:
        num_input_channels (int): Number of channels in input.
        num_output_channels (int): Number of channels in output.
        kernel_size (int or tuple(int, int)): Size of the convolution kernel.
            An int is expanded to a square kernel.
        strides (int or tuple(int, int)): Stride of the convolution. An int
            is expanded to the same stride along both axes. The convolution
            uses ``padding="same"``, which PyTorch only supports for a
            stride of 1.
        use_bias (bool): Whether to use bias in the convolution layer.
        dilation_rate (int or tuple(int, int)): Dilation rate of the
            convolution. An int is expanded to both axes.
        activation (str or None): Name of the activation function. Only
            ``"relu"`` adds an activation layer; any other value (e.g.
            ``None``) adds none.
        do_batchnorm (bool): Whether to do batch normalization after the
            convolution layer.

    Returns:
        model (torch.nn.Module): a pytorch model.

    """

    def __init__(
        self,
        num_input_channels: int,
        num_output_channels: int,
        kernel_size: Union[int, Tuple[int, int], np.ndarray] = (3, 3),
        strides: Union[int, Tuple[int, int], np.ndarray] = (1, 1),
        use_bias: bool = False,
        dilation_rate: Union[int, Tuple[int, int], np.ndarray] = (1, 1),
        activation: Union[str, None] = "relu",
        do_batchnorm: bool = True,
    ):
        super().__init__()
        # Normalise scalar arguments to (height, width) pairs.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(strides, int):
            strides = (strides, strides)
        if isinstance(dilation_rate, int):
            dilation_rate = (dilation_rate, dilation_rate)

        self.conv_bn_relu = self.get_block(
            num_input_channels,
            num_output_channels,
            kernel_size,
            strides,
            use_bias,
            dilation_rate,
            activation,
            do_batchnorm,
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Logic for using layers defined in init.

        This method defines how layers are used in forward operation.

        Args:
            input_tensor (torch.Tensor): Input, the tensor is of the shape NCHW.

        Returns:
            output (torch.Tensor): The inference output.

        """
        return self.conv_bn_relu(input_tensor)

    @staticmethod
    def get_block(
        in_channels,
        out_channels,
        kernel_size,
        strides,
        use_bias,
        dilation_rate,
        activation,
        do_batchnorm,
    ):
        """Function to acquire a convolutional block.

        Args:
            in_channels (int): Number of channels in input.
            out_channels (int): Number of channels in output.
            kernel_size (int or tuple(int, int)): Size of the kernel in the
                acquired convolution block.
            strides (int or tuple(int, int)): Size of stride in the
                convolution layer. Must be 1 because ``padding="same"`` is
                used.
            use_bias (bool): Whether to use bias in the convolution layer.
            dilation_rate (int or tuple(int, int)): Dilation rate of the
                convolution layer.
            activation (str or None): Name of the activation function to use;
                only ``"relu"`` adds a layer.
            do_batchnorm (bool): Whether to do batch normalization after the
                convolution layer.

        Returns:
            torch.nn.Sequential: a pytorch layer

        """
        conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=strides,
            dilation=dilation_rate,
            bias=use_bias,
            padding="same",
            padding_mode="zeros",
        )
        torch.nn.init.xavier_uniform_(conv1.weight)

        layers = [conv1]
        if do_batchnorm:
            layers.append(nn.BatchNorm2d(num_features=out_channels, eps=1.001e-5))
        if activation == "relu":
            layers.append(nn.ReLU())

        return nn.Sequential(*layers)
class MultiscaleConvBlock(nn.Module):
    """Defines Multiscale convolution block.

    Four parallel ConvBnRelu branches, each with its own kernel size and
    dilation rate, are applied to the same input. Their outputs are
    concatenated along the channel axis, so the block produces
    ``4 * num_output_channels`` channels.

    Args:
        num_input_channels (int): Number of channels in input.
        kernel_sizes (sequence of int): Kernel size of each of the four
            branches. Only the first four entries are used.
        dilation_rates (sequence of int): Dilation rate of each of the four
            branches. Only the first four entries are used.
        num_output_channels (int): Number of channels produced by EACH
            branch (the block output has four times as many).
        strides (int or tuple(int, int)): Size of stride in the convolution
            layers.
        activation (str): Name of the activation function to use.
        use_bias (bool): Whether to use bias in the convolution layers.

    Returns:
        torch.nn.Module: A PyTorch model.

    """

    def __init__(
        self,
        num_input_channels: int,
        kernel_sizes: Union[Tuple[int, ...], list, np.ndarray],
        dilation_rates: Union[Tuple[int, ...], list, np.ndarray],
        num_output_channels: int = 32,
        strides: Union[int, Tuple[int, int], np.ndarray] = (1, 1),
        activation: str = "relu",
        use_bias: bool = False,
    ):
        super().__init__()

        def _branch(index: int) -> nn.Module:
            """Build the branch using entry ``index`` of the per-branch settings."""
            rate = dilation_rates[index]
            return ConvBnRelu(
                num_input_channels=num_input_channels,
                num_output_channels=num_output_channels,
                kernel_size=kernel_sizes[index],
                strides=strides,
                activation=activation,
                use_bias=use_bias,
                dilation_rate=(rate, rate),
            )

        # Keep the explicit attribute names: they determine the state_dict keys.
        self.conv_block_1 = _branch(0)
        self.conv_block_2 = _branch(1)
        self.conv_block_3 = _branch(2)
        self.conv_block_4 = _branch(3)

    def forward(self, input_map):
        """Logic for using layers defined in MultiscaleConvBlock init.

        This method defines how layers are used in forward operation.

        Args:
            input_map (torch.Tensor): Input, the tensor is of the shape NCHW.

        Returns:
            output (torch.Tensor): The four branch outputs concatenated along
            the channel axis.

        """
        branches = (
            self.conv_block_1,
            self.conv_block_2,
            self.conv_block_3,
            self.conv_block_4,
        )
        return torch.cat([branch(input_map) for branch in branches], dim=bn_axis)
class ResidualConv(nn.Module):
    """Residual Convolution block.

    Two convolution + batch-normalization layers are stacked without their
    own activations. The first maps ``num_input_channels`` to
    ``num_output_channels`` and the second keeps ``num_output_channels``.
    The outputs of the two layers are summed (the skip connection starts at
    the first layer's output, not at the block input), and a single ReLU is
    applied to the sum.

    Args:
        num_input_channels (int): Number of channels in input.
        num_output_channels (int): Number of channels in output.
        kernel_size (int or tuple(int, int)): Size of the kernel in both
            convolution layers.
        strides (int or tuple(int, int)): Size of the stride in both
            convolution layers.
        use_bias (bool): Whether to use bias in the convolution layers.
        dilation_rate (int or tuple(int, int)): Dilation rate in both
            convolution layers.

    Returns:
        model (torch.nn.Module): A pytorch model.

    """

    def __init__(
        self,
        num_input_channels: int,
        num_output_channels: int = 32,
        kernel_size: Union[Tuple[int, int], np.ndarray] = (3, 3),
        strides: Union[Tuple[int, int], np.ndarray] = (1, 1),
        use_bias: bool = False,
        dilation_rate: Union[Tuple[int, int], np.ndarray] = (1, 1),
    ):
        super().__init__()
        # Settings shared by both layers; "None" disables the per-layer ReLU.
        layer_kwargs = {
            "kernel_size": kernel_size,
            "strides": strides,
            "activation": "None",
            "use_bias": use_bias,
            "dilation_rate": dilation_rate,
            "do_batchnorm": True,
        }
        self.conv_block_1 = ConvBnRelu(
            num_input_channels, num_output_channels, **layer_kwargs
        )
        self.conv_block_2 = ConvBnRelu(
            num_output_channels, num_output_channels, **layer_kwargs
        )
        self.activation = nn.ReLU()

    def forward(self, input_tensor):
        """Logic for using layers defined in ResidualConv init.

        This method defines how layers are used in forward operation.

        Args:
            input_tensor (torch.Tensor): Input, the tensor is of the shape NCHW.

        Returns:
            output (torch.Tensor): The inference output.

        """
        first = self.conv_block_1(input_tensor)
        second = self.conv_block_2(first)
        return self.activation(first + second)
class NuClick(ModelABC):
    """NuClick Architecture.

    NuClick is used for interactive nuclei segmentation.
    NuClick takes an RGB image patch along with an inclusion and an exclusion
    map.

    The encoder halves the resolution five times with 2x2 max pooling. The
    decoder restores it with stride-2 transposed convolutions whose outputs
    are concatenated with the matching encoder features, so the input height
    and width must be divisible by 32.

    Args:
        num_input_channels (int): Number of channels in input.
        num_output_channels (int): Number of channels in output.

    Returns:
        model (torch.nn.Module): a pytorch model.

    Examples:
        >>> # instantiate a NuClick model for interactive nucleus segmentation.
        >>> NuClick(num_input_channels = 5, num_output_channels = 1)

    """

    def __init__(self, num_input_channels: int, num_output_channels: int):
        super().__init__()
        self.net_name = "NuClick"

        self.n_channels = num_input_channels
        self.n_classes = num_output_channels

        # -------------Convolution + Batch Normalization + ReLu blocks------------
        self.conv_block_1 = nn.Sequential(
            ConvBnRelu(
                num_input_channels=self.n_channels,
                num_output_channels=64,
                kernel_size=7,
            ),
            ConvBnRelu(num_input_channels=64, num_output_channels=32, kernel_size=5),
            ConvBnRelu(num_input_channels=32, num_output_channels=32, kernel_size=3),
        )

        self.conv_block_2 = nn.Sequential(
            ConvBnRelu(num_input_channels=64, num_output_channels=64),
            ConvBnRelu(num_input_channels=64, num_output_channels=32),
            ConvBnRelu(num_input_channels=32, num_output_channels=32),
        )

        # 1x1 output head: biased, no batch norm, no activation (logits).
        self.conv_block_3 = ConvBnRelu(
            num_input_channels=32,
            num_output_channels=self.n_classes,
            kernel_size=(1, 1),
            strides=1,
            activation=None,
            use_bias=True,
            do_batchnorm=False,
        )

        # -------------Residual Convolution blocks------------
        self.residual_block_1 = nn.Sequential(
            ResidualConv(num_input_channels=32, num_output_channels=64),
            ResidualConv(num_input_channels=64, num_output_channels=64),
        )

        self.residual_block_2 = ResidualConv(
            num_input_channels=64, num_output_channels=128
        )

        self.residual_block_3 = ResidualConv(
            num_input_channels=128, num_output_channels=128
        )

        self.residual_block_4 = nn.Sequential(
            ResidualConv(num_input_channels=128, num_output_channels=256),
            ResidualConv(num_input_channels=256, num_output_channels=256),
            ResidualConv(num_input_channels=256, num_output_channels=256),
        )

        self.residual_block_5 = nn.Sequential(
            ResidualConv(num_input_channels=256, num_output_channels=512),
            ResidualConv(num_input_channels=512, num_output_channels=512),
            ResidualConv(num_input_channels=512, num_output_channels=512),
        )

        self.residual_block_6 = nn.Sequential(
            ResidualConv(num_input_channels=512, num_output_channels=1024),
            ResidualConv(num_input_channels=1024, num_output_channels=1024),
        )

        self.residual_block_7 = nn.Sequential(
            ResidualConv(num_input_channels=1024, num_output_channels=512),
            ResidualConv(num_input_channels=512, num_output_channels=256),
        )

        self.residual_block_8 = ResidualConv(
            num_input_channels=512, num_output_channels=256
        )

        self.residual_block_9 = ResidualConv(
            num_input_channels=256, num_output_channels=256
        )

        self.residual_block_10 = nn.Sequential(
            ResidualConv(num_input_channels=256, num_output_channels=128),
            ResidualConv(num_input_channels=128, num_output_channels=128),
        )

        self.residual_block_11 = ResidualConv(
            num_input_channels=128, num_output_channels=64
        )

        self.residual_block_12 = ResidualConv(
            num_input_channels=64, num_output_channels=64
        )

        # -------------Multi-scale Convolution blocks------------
        # Each block outputs 4 * num_output_channels channels (four branches).
        self.multiscale_block_1 = MultiscaleConvBlock(
            num_input_channels=128,
            num_output_channels=32,
            kernel_sizes=[3, 3, 5, 5],
            dilation_rates=[1, 3, 3, 6],
        )

        self.multiscale_block_2 = MultiscaleConvBlock(
            num_input_channels=256,
            num_output_channels=64,
            kernel_sizes=[3, 3, 5, 5],
            dilation_rates=[1, 3, 2, 3],
        )

        self.multiscale_block_3 = MultiscaleConvBlock(
            num_input_channels=64,
            num_output_channels=16,
            kernel_sizes=[3, 3, 5, 7],
            dilation_rates=[1, 3, 2, 6],
        )

        # -------------Max Pooling blocks------------
        self.pool_block_1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.pool_block_2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.pool_block_3 = nn.MaxPool2d(kernel_size=(2, 2))
        self.pool_block_4 = nn.MaxPool2d(kernel_size=(2, 2))
        self.pool_block_5 = nn.MaxPool2d(kernel_size=(2, 2))

        # -------------Transposed Convolution blocks------------
        self.conv_transpose_1 = nn.ConvTranspose2d(
            in_channels=1024,
            out_channels=512,
            kernel_size=2,
            stride=(2, 2),
        )

        self.conv_transpose_2 = nn.ConvTranspose2d(
            in_channels=256,
            out_channels=256,
            kernel_size=2,
            stride=(2, 2),
        )

        self.conv_transpose_3 = nn.ConvTranspose2d(
            in_channels=256,
            out_channels=128,
            kernel_size=2,
            stride=(2, 2),
        )

        self.conv_transpose_4 = nn.ConvTranspose2d(
            in_channels=128,
            out_channels=64,
            kernel_size=2,
            stride=(2, 2),
        )

        self.conv_transpose_5 = nn.ConvTranspose2d(
            in_channels=64,
            out_channels=32,
            kernel_size=2,
            stride=(2, 2),
        )

    # pylint: disable=W0221
    def forward(self, imgs: torch.Tensor):
        """Logic for using layers defined in NuClick init.

        This method defines how layers are used in forward operation.

        Args:
            imgs (torch.Tensor): Input images, the tensor is of the shape NCHW.
                H and W must be divisible by 32 so that the upsampled decoder
                features match the encoder skip connections.

        Returns:
            output (torch.Tensor): The inference output (logits, no sigmoid).

        """
        # Encoder.
        conv1 = self.conv_block_1(imgs)
        pool1 = self.pool_block_1(conv1)

        conv2 = self.residual_block_1(pool1)
        pool2 = self.pool_block_2(conv2)

        conv3 = self.residual_block_2(pool2)
        conv3 = self.multiscale_block_1(conv3)
        conv3 = self.residual_block_3(conv3)
        pool3 = self.pool_block_3(conv3)

        conv4 = self.residual_block_4(pool3)
        pool4 = self.pool_block_4(conv4)

        conv5 = self.residual_block_5(pool4)
        pool5 = self.pool_block_5(conv5)

        conv51 = self.residual_block_6(pool5)

        # Decoder with skip connections.
        up61 = torch.cat([self.conv_transpose_1(conv51), conv5], dim=1)
        conv61 = self.residual_block_7(up61)

        up6 = torch.cat([self.conv_transpose_2(conv61), conv4], dim=1)
        conv6 = self.residual_block_8(up6)
        conv6 = self.multiscale_block_2(conv6)
        conv6 = self.residual_block_9(conv6)

        up7 = torch.cat([self.conv_transpose_3(conv6), conv3], dim=1)
        conv7 = self.residual_block_10(up7)

        up8 = torch.cat([self.conv_transpose_4(conv7), conv2], dim=1)
        conv8 = self.residual_block_11(up8)
        conv8 = self.multiscale_block_3(conv8)
        conv8 = self.residual_block_12(conv8)

        up9 = torch.cat([self.conv_transpose_5(conv8), conv1], dim=1)
        conv9 = self.conv_block_2(up9)

        return self.conv_block_3(conv9)

    @staticmethod
    def postproc(
        preds,
        thresh=0.33,
        min_size=10,
        min_hole_size=30,
        do_reconstruction=False,
        nuc_points=None,
    ):
        """Post processing.

        Each patch is processed independently. Handing the whole
        ``(no.patch, h, w)`` stack to the n-dimensional skimage morphology
        functions would treat the patch axis as a third spatial dimension
        and connect objects and holes across different patches.

        Args:
            preds (ndarray): list of prediction output of each patch and
                assumed to be in the order of (no.patch, h, w) (match with
                the output of `infer_batch`). A single (h, w) prediction is
                also accepted.
            thresh (float): Threshold value. If a pixel has a predicted value
                larger than the threshold, it will be classified as nuclei.
            min_size (int): The smallest allowable object size.
            min_hole_size (int): The maximum area, in pixels, of a contiguous
                hole that will be filled.
            do_reconstruction (bool): Whether to keep, in each patch, only
                the mask components that contain a clicked point
                (morphological reconstruction by dilation).
            nuc_points (ndarray): In the order of (no.patch, h, w). In each
                patch, the pixel that has been 'clicked' is set to 1 and the
                rest of the pixels are set to 0. Required when
                `do_reconstruction` is True.

        Returns:
            masks (ndarray): pixel-wise nuclei instance segmentation
                prediction, shape: (no.patch, h, w), or (h, w) if a single
                (h, w) prediction was given.

        Raises:
            ValueError: If `do_reconstruction` is True but `nuc_points` is
                None.

        """
        masks = np.asarray(preds) > thresh

        # Treat a single (h, w) prediction as a batch of one.
        single_patch = masks.ndim == 2
        if single_patch:
            masks = masks[np.newaxis]

        markers = None
        if do_reconstruction:
            if nuc_points is None:
                raise ValueError(
                    "`nuc_points` must be provided when `do_reconstruction` is True."
                )
            markers = np.asarray(nuc_points) > 0
            if markers.ndim == 2:
                markers = markers[np.newaxis]

        for i, mask in enumerate(masks):
            mask = remove_small_objects(mask, min_size=min_size)
            mask = remove_small_holes(mask, area_threshold=min_hole_size)

            if markers is not None:
                marker = markers[i]
                if np.any(mask[marker]):
                    # Reconstruction by dilation requires seed <= mask, so
                    # drop clicked pixels that fall outside the mask.
                    seed = marker & mask
                    mask = reconstruction(seed, mask, footprint=disk(1)) > 0
                else:
                    warnings.warn(
                        f"Nuclei reconstruction was not done for nucleus #{i}"
                    )

            masks[i] = mask

        return masks[0] if single_patch else masks

    @staticmethod
    def infer_batch(model, batch_data, on_gpu):
        """Run inference on an input batch.

        This contains logic for forward operation as well as batch i/o
        aggregation.

        Args:
            model (nn.Module): PyTorch defined model.
            batch_data (torch.Tensor or ndarray): a batch of data generated by
                torch.utils.data.DataLoader, in NCHW order.
            on_gpu (bool): Whether to run inference on a GPU.

        Returns:
            Pixel-wise nuclei prediction (sigmoid probabilities) for each
            patch, shape: (no.patch, h, w).

        """
        model.eval()
        device = misc.select_device(on_gpu)

        # Assume batch_data is NCHW; as_tensor is a no-op for tensors.
        batch_data = torch.as_tensor(batch_data).to(device).type(torch.float32)

        with torch.inference_mode():
            output = model(batch_data)
            output = torch.sigmoid(output)
            # Drop the single output-channel axis: (N, 1, H, W) -> (N, H, W).
            output = torch.squeeze(output, 1)

        return output.cpu().numpy()