Source code for tiatoolbox.models.architecture.vanilla

"""Define vanilla CNNs with torch backbones, mainly for patch classification."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import torch
import torchvision.models as torch_models
from torch import nn

from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils.misc import select_device

if TYPE_CHECKING:  # pragma: no cover
    from torchvision.models import WeightsEnum


def _get_architecture(
    arch_name: str,
    weights: str | WeightsEnum = "DEFAULT",
    **kwargs: dict,
) -> nn.Sequential:
    """Get a model.

    Model architectures are either already defined within torchvision or
    they can be custom-made within tiatoolbox.

    Args:
        arch_name (str):
            Architecture name.
        weights (str or WeightsEnum):
            torchvision model weights (get_model_weights).
        kwargs (dict):
            Key-word arguments.

    Returns:
        nn.Sequential:
            The backbone layers wrapped with `nn.Sequential`, with the final
            global average pooling and classification layers removed.
            https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html

    """
    backbone_dict = {
        "alexnet": torch_models.alexnet,
        "resnet18": torch_models.resnet18,
        "resnet34": torch_models.resnet34,
        "resnet50": torch_models.resnet50,
        "resnet101": torch_models.resnet101,
        "resnext50_32x4d": torch_models.resnext50_32x4d,
        "resnext101_32x8d": torch_models.resnext101_32x8d,
        "wide_resnet50_2": torch_models.wide_resnet50_2,
        "wide_resnet101_2": torch_models.wide_resnet101_2,
        "densenet121": torch_models.densenet121,
        "densenet161": torch_models.densenet161,
        "densenet169": torch_models.densenet169,
        "densenet201": torch_models.densenet201,
        "inception_v3": torch_models.inception_v3,
        "googlenet": torch_models.googlenet,
        "mobilenet_v2": torch_models.mobilenet_v2,
        "mobilenet_v3_large": torch_models.mobilenet_v3_large,
        "mobilenet_v3_small": torch_models.mobilenet_v3_small,
    }
    if arch_name not in backbone_dict:
        msg = f"Backbone `{arch_name}` is not supported."
        raise ValueError(msg)

    creator = backbone_dict[arch_name]
    model = creator(weights=weights, **kwargs)

    # Unroll all the definition and strip off the final GAP and FCN
    if "resnet" in arch_name or "resnext" in arch_name:
        return nn.Sequential(*list(model.children())[:-2])
    if "densenet" in arch_name:
        return model.features
    if "alexnet" in arch_name:
        return model.features
    if "inception_v3" in arch_name or "googlenet" in arch_name:
        return nn.Sequential(*list(model.children())[:-3])

    return model.features


class CNNModel(ModelABC):
    """Retrieve the model backbone and attach an extra FCN to perform classification.

    Args:
        backbone (str):
            Model name.
        num_classes (int):
            Number of classes output by model.

    Attributes:
        num_classes (int):
            Number of classes output by the model.
        feat_extract (nn.Module):
            Backbone CNN model.
        pool (nn.Module):
            Type of pooling applied after feature extraction.
        classifier (nn.Module):
            Linear classifier module used to map the features to the output.

    """

    def __init__(self: CNNModel, backbone: str, num_classes: int = 1) -> None:
        """Initialize :class:`CNNModel`."""
        super().__init__()
        self.num_classes = num_classes

        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Retrieve the backbone output channel count dynamically
        # by running a small forward pass.
        prev_num_ch = self.feat_extract(torch.rand([2, 3, 96, 96])).shape[1]
        self.classifier = nn.Linear(prev_num_ch, num_classes)

    # pylint: disable=W0221
    # because abc is generic, this is actual definition
    def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        """
        feat = self.feat_extract(imgs)
        gap_feat = self.pool(feat)
        gap_feat = torch.flatten(gap_feat, 1)
        logit = self.classifier(gap_feat)
        return torch.softmax(logit, -1)
    @staticmethod
    def postproc(image: np.ndarray) -> np.ndarray:
        """Define the post-processing of this class of model.

        This simply applies argmax along last axis of the input.

        """
        return np.argmax(image, axis=-1)
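
    # Illustrative sketch: ``postproc`` reduces per-class probabilities to a
    # predicted label per sample by taking the argmax over the last axis.
    #
    #     >>> probs = np.array([[0.1, 0.9], [0.8, 0.2]])
    #     >>> CNNModel.postproc(probs)
    #     array([1, 0])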
    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> np.ndarray:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type(
            torch.float32,
        )
        # to NCHW
        img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()

        # Inference mode
        model.eval()
        # Do not compute the gradient (not training)
        with torch.inference_mode():
            output = model(img_patches_device)
        # Output should be a single tensor or scalar
        return output.cpu().numpy()
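
# A hedged end-to-end sketch of using ``CNNModel`` with ``infer_batch``
# (illustrative only). The NHWC uint8 batch layout is an assumption inferred
# from the NHWC-to-NCHW permute inside ``infer_batch``, and pretrained
# torchvision weights are assumed to be downloadable:
#
#     >>> model = CNNModel(backbone="resnet18", num_classes=2)
#     >>> batch = torch.randint(0, 255, (4, 96, 96, 3))  # NHWC patches
#     >>> probs = CNNModel.infer_batch(model, batch, on_gpu=False)
#     >>> probs.shape  # softmax probabilities per patch
#     (4, 2)
#     >>> CNNModel.postproc(probs).shape  # predicted label per patch
#     (4,)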

class CNNBackbone(ModelABC):
    """Retrieve the model backbone and strip the classification layer.

    This is a wrapper for pretrained models within pytorch.

    Args:
        backbone (str):
            Model name. Currently, the tool supports the following
            model names and their default associated weights from pytorch.
                - "alexnet"
                - "resnet18"
                - "resnet34"
                - "resnet50"
                - "resnet101"
                - "resnext50_32x4d"
                - "resnext101_32x8d"
                - "wide_resnet50_2"
                - "wide_resnet101_2"
                - "densenet121"
                - "densenet161"
                - "densenet169"
                - "densenet201"
                - "inception_v3"
                - "googlenet"
                - "mobilenet_v2"
                - "mobilenet_v3_large"
                - "mobilenet_v3_small"

    Examples:
        >>> # Creating resnet50 architecture from default pytorch
        >>> # without the classification layer with its associated
        >>> # weights loaded
        >>> model = CNNBackbone(backbone="resnet50")
        >>> model.eval()  # set to evaluation mode
        >>> # dummy sample in NCHW form
        >>> samples = torch.rand(4, 3, 512, 512)
        >>> features = model(samples)
        >>> features.shape  # features after global average pooling
        torch.Size([4, 2048])

    """

    def __init__(self: CNNBackbone, backbone: str) -> None:
        """Initialize :class:`CNNBackbone`."""
        super().__init__()
        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    # pylint: disable=W0221
    # because abc is generic, this is actual definition
    def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        """
        feat = self.feat_extract(imgs)
        gap_feat = self.pool(feat)
        return torch.flatten(gap_feat, 1)
    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type(
            torch.float32,
        )
        # to NCHW
        img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()

        # Inference mode
        model.eval()
        # Do not compute the gradient (not training)
        with torch.inference_mode():
            output = model(img_patches_device)
        # Output should be a single tensor or scalar
        return [output.cpu().numpy()]
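
# Sketch of deep feature extraction with ``CNNBackbone.infer_batch``
# (illustrative only); like the classifier variant above, it expects NHWC
# batches and returns the pooled features wrapped in a single-element list.
#
#     >>> model = CNNBackbone(backbone="resnet50")
#     >>> batch = torch.rand(4, 224, 224, 3)  # NHWC patches
#     >>> (features,) = CNNBackbone.infer_batch(model, batch, on_gpu=False)
#     >>> features.shape  # 2048-d embedding per patch for resnet50
#     (4, 2048)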