"""Define SCCNN architecture.
Sirinukunwattana, Korsuk, et al.
"Locality sensitive deep learning for detection and classification
of nuclei in routine colon cancer histology images."
IEEE transactions on medical imaging 35.5 (2016): 1196-1206.
"""
from __future__ import annotations
from collections import OrderedDict
import numpy as np
import torch
from skimage.feature import peak_local_max
from torch import nn
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils import misc
class SCCNN(ModelABC):
    """Spatially Constrained Convolutional Neural Network (SCCNN) [1].

    The following models have been included in tiatoolbox:

    1. `sccnn-crchisto`:
       This model is trained on `CRCHisto dataset
       <https://warwick.ac.uk/fac/cross_fac/tia/data/crchistolabelednucleihe/>`_
    2. `sccnn-conic`:
       This model is trained on `CoNIC dataset
       <https://conic-challenge.grand-challenge.org/evaluation/challenge/leaderboard//>`_
       Centroids of ground truth masks were used to train this model.
       The results are reported on the whole test data set including
       preliminary and final set.

    The original model was implemented in Matlab. The model has been
    reimplemented in PyTorch for Python compatibility. The original model uses
    HRGB as input, where 'H' represents hematoxylin. The model has been
    modified to rely on RGB image as input.

    The tiatoolbox model should produce the following results on the following
    datasets using 8 pixels as radius for true detection:

    .. list-table:: SCCNN performance
       :widths: 15 15 15 15 15
       :header-rows: 1

       * - Model name
         - Data set
         - Precision
         - Recall
         - F1Score
       * - sccnn-crchisto
         - CRCHisto
         - 0.82
         - 0.80
         - 0.81
       * - sccnn-conic
         - CoNIC
         - 0.79
         - 0.79
         - 0.79

    Args:
        num_input_channels (int):
            Number of channels in input. default=3.
        patch_output_shape (tuple(int)):
            Defines output height and output width. default=(13, 13).
        radius (int):
            Radius for nucleus detection, default = 12.
        min_distance (int):
            The minimal allowed distance separating peaks.
            To find the maximum number of peaks, use `min_distance=1`,
            default=6.
        threshold_abs (float):
            Minimum intensity of peaks, default=0.20.

    References:
        [1] Sirinukunwattana, Korsuk, et al.
        "Locality sensitive deep learning for detection and classification
        of nuclei in routine colon cancer histology images."
        IEEE transactions on medical imaging 35.5 (2016): 1196-1206.

    """

    def __init__(
        self: SCCNN,
        num_input_channels: int = 3,
        patch_output_shape: tuple[int, int] = (13, 13),
        radius: int = 12,
        min_distance: int = 6,
        threshold_abs: float = 0.20,
    ) -> None:
        """Initialize :class:`SCCNN`."""
        super().__init__()
        out_height = patch_output_shape[0]
        out_width = patch_output_shape[1]
        self.in_ch = num_input_channels
        self.out_height = out_height
        self.out_width = out_width
        # Mesh grid of output pixel coordinates (row, column), converted to a
        # 3D tensor with a leading channel dim. Registered as buffers so they
        # follow the model across devices/dtypes but are not trainable.
        x, y = torch.meshgrid(
            torch.arange(start=0, end=out_height),
            torch.arange(start=0, end=out_width),
            indexing="ij",
        )
        self.register_buffer("xv", torch.unsqueeze(x, dim=0).type(torch.float32))
        self.register_buffer("yv", torch.unsqueeze(y, dim=0).type(torch.float32))
        self.radius = radius
        self.min_distance = min_distance
        self.threshold_abs = threshold_abs

        def conv_act_block(
            in_channels: int,
            out_channels: int,
            kernel_size: int,
        ) -> torch.nn.ModuleDict:
            """Convolution and activation branch for SCCNN.

            This module combines the convolution and activation blocks in a
            single function.

            Args:
                in_channels (int):
                    Number of channels in input.
                out_channels (int):
                    Number of required channels in output.
                kernel_size (int):
                    Kernel size of convolution filter.

            Returns:
                torch.nn.ModuleDict:
                    Module dictionary.

            """
            module_dict = OrderedDict()
            module_dict["conv1"] = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=(kernel_size, kernel_size),
                    stride=(1, 1),
                    padding=0,
                    bias=True,
                ),
                nn.ReLU(),
            )
            return nn.ModuleDict(module_dict)

        def spatially_constrained_layer1(
            in_channels: int,
            out_channels: int,
        ) -> torch.nn.ModuleDict:
            """Spatially constrained layer.

            Takes a fully connected layer and returns outputs for creating a
            probability map for the output. The output Tensor is 3-dimensional
            where it defines the row, column of the centre of a nucleus and
            its confidence value.

            Args:
                in_channels (int):
                    Number of channels in input.
                out_channels (int):
                    Number of required channels in output.

            Returns:
                torch.nn.ModuleDict:
                    Module dictionary.

            """
            module_dict = OrderedDict()
            # Sigmoid keeps all three outputs in (0, 1); the position outputs
            # are rescaled to pixel coordinates later in the forward pass.
            module_dict["conv1"] = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    padding=0,
                    bias=True,
                ),
                nn.Sigmoid(),
            )
            return nn.ModuleDict(module_dict)

        # Backbone: conv/ReLU stages with max-pooling, followed by two
        # dropout-regularized 1x1/5x5 stages and the spatially constrained head.
        module_dict = OrderedDict()
        module_dict["l1"] = conv_act_block(num_input_channels, 30, 2)
        module_dict["pool1"] = nn.MaxPool2d(2, padding=0)
        module_dict["l2"] = conv_act_block(30, 60, 2)
        module_dict["pool2"] = nn.MaxPool2d(2, padding=0)
        module_dict["l3"] = conv_act_block(60, 90, 3)
        module_dict["l4"] = conv_act_block(90, 1024, 5)
        module_dict["dropout1"] = nn.Dropout2d(p=0.5)
        module_dict["l5"] = conv_act_block(1024, 512, 1)
        module_dict["dropout2"] = nn.Dropout2d(p=0.5)
        module_dict["sc"] = spatially_constrained_layer1(512, 3)
        self.layer = nn.ModuleDict(module_dict)

    def spatially_constrained_layer2(
        self: SCCNN,
        sc1_0: torch.Tensor,
        sc1_1: torch.Tensor,
        sc1_2: torch.Tensor,
    ) -> torch.Tensor:
        """Spatially constrained layer 2.

        Estimates row, column and height for sc2 layer mapping.

        Args:
            sc1_0 (torch.Tensor):
                Output of spatially_constrained_layer1 estimating
                the x position of the nucleus.
            sc1_1 (torch.Tensor):
                Output of spatially_constrained_layer1 estimating
                the y position of the nucleus.
            sc1_2 (torch.Tensor):
                Output of spatially_constrained_layer1 estimating
                the confidence in nucleus detection.

        Returns:
            :class:`torch.Tensor`:
                Probability map using the estimates from
                spatially_constrained_layer1.

        """
        # Tile the coordinate grids along the batch dimension.
        x = torch.tile(self.xv, dims=[sc1_0.size(0), 1, 1, 1])
        y = torch.tile(self.yv, dims=[sc1_0.size(0), 1, 1, 1])
        # Squared distance from every output pixel to the predicted centre.
        xvr = (x - sc1_0) ** 2
        yvc = (y - sc1_1) ** 2
        out_map = xvr + yvc
        # Mask out responses farther than `radius` from the predicted centre.
        out_map_threshold = torch.lt(out_map, self.radius).type(torch.float32)
        # Confidence decays with distance from the predicted centre.
        denominator = 1 + (out_map / 2)
        sc2 = sc1_2 / denominator
        return sc2 * out_map_threshold

    @staticmethod
    def preproc(image: torch.Tensor) -> torch.Tensor:
        """Transforming network input to desired format.

        This method is model and dataset specific, meaning that it can be
        replaced by user's desired transform function before
        training/inference.

        Args:
            image (torch.Tensor):
                Input images, the tensor is of the shape NCHW.

        Returns:
            output (torch.Tensor):
                The transformed input.

        """
        # Scale uint8-range pixel values to [0, 1].
        return image / 255.0

    def forward(  # skipcq: PYL-W0221
        self: SCCNN,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Logic for using layers defined in init.

        This method defines how layers are used in forward operation.

        Args:
            input_tensor (torch.Tensor):
                Input images, the tensor is in the shape of NCHW.

        Returns:
            torch.Tensor:
                Output map for cell detection. Peak detection should be
                applied to this output for cell detection.

        """

        def spatially_constrained_layer1(
            layer: torch.nn.Module,
            in_tensor: torch.Tensor,
            out_height: int = 13,
            out_width: int = 13,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """Spatially constrained layer 1.

            Estimates row, column and height for
            `spatially_constrained_layer2` layer mapping.

            Args:
                layer (torch.nn.Module):
                    Torch layer as ModuleDict.
                in_tensor (torch.Tensor):
                    Input Tensor.
                out_height (int):
                    Output height.
                out_width (int):
                    Output Width.

            Returns:
                tuple:
                    Parameters for the requested nucleus location:
                    - torch.Tensor - Row location for the centre of the
                      nucleus.
                    - torch.Tensor - Column location for the centre of the
                      nucleus.
                    - torch.Tensor - Peak value for the probability function
                      indicating confidence value for the estimate.

            """
            sigmoid = layer["conv1"](in_tensor)
            # Rescale the (0, 1) sigmoid outputs to pixel coordinates within
            # the output window; channel 2 stays as a raw confidence value.
            sigmoid0 = sigmoid[:, 0:1, :, :] * (out_height - 1)
            sigmoid1 = sigmoid[:, 1:2, :, :] * (out_width - 1)
            sigmoid2 = sigmoid[:, 2:3, :, :]
            return sigmoid0, sigmoid1, sigmoid2

        input_tensor = self.preproc(input_tensor)
        l1 = self.layer["l1"]["conv1"](input_tensor)
        p1 = self.layer["pool1"](l1)
        l2 = self.layer["l2"]["conv1"](p1)
        # Use the dedicated second pooling layer registered in __init__
        # (the original code reused "pool1" here; both are MaxPool2d(2) so the
        # result is identical, but "pool2" is the module built for this stage).
        p2 = self.layer["pool2"](l2)
        l3 = self.layer["l3"]["conv1"](p2)
        l4 = self.layer["l4"]["conv1"](l3)
        drop1 = self.layer["dropout1"](l4)
        l5 = self.layer["l5"]["conv1"](drop1)
        drop2 = self.layer["dropout2"](l5)
        s1_sigmoid0, s1_sigmoid1, s1_sigmoid2 = spatially_constrained_layer1(
            self.layer["sc"],
            drop2,
        )
        return self.spatially_constrained_layer2(
            s1_sigmoid0,
            s1_sigmoid1,
            s1_sigmoid2,
        )

    def postproc(self: SCCNN, prediction_map: np.ndarray) -> np.ndarray:
        """Post-processing script for SCCNN.

        Performs peak detection and extracts coordinates in x, y format.

        Args:
            prediction_map (ndarray):
                Input image of type numpy array.

        Returns:
            :class:`numpy.ndarray`:
                Detected nucleus coordinates in (x, y) format.

        """
        coordinates = peak_local_max(
            np.squeeze(prediction_map[0], axis=2),
            min_distance=self.min_distance,
            threshold_abs=self.threshold_abs,
            exclude_border=False,
        )
        # peak_local_max returns (row, col); flip to (x, y).
        return np.fliplr(coordinates)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: np.ndarray | torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray]:
        """Run inference on an input batch.

        This contains logic for forward operation as well as batch I/O
        aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (:class:`numpy.ndarray` or :class:`torch.Tensor`):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        Returns:
            list of :class:`numpy.ndarray`:
                Output probability map.

        """
        patch_imgs = batch_data
        device = misc.select_device(on_gpu=on_gpu)
        patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)
        # NHWC -> NCHW for the convolutional layers.
        patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
        model.eval()  # infer mode
        # --------------------------------------------------------------
        with torch.inference_mode():
            pred = model(patch_imgs_gpu)
        # Back to NHWC on the CPU for downstream post-processing.
        pred = pred.permute(0, 2, 3, 1).contiguous()
        pred = pred.cpu().numpy()
        return [
            pred,
        ]