# Source code for tiatoolbox.models.architecture.idars

# ***** BEGIN GPL LICENSE BLOCK *****
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# The Original Code is Copyright (C) 2021, TIALab, University of Warwick
# All rights reserved.
# ***** END GPL LICENSE BLOCK *****
"""Defines CNNs as used in IDaRS for prediction of molecular pathways and mutations."""

import numpy as np
from torchvision import transforms

from tiatoolbox.models.architecture.vanilla import CNNModel


# IDaRS input pipeline: ``ToTensor`` (numpy HWC image -> CHW tensor), then a
# per-channel normalisation using mean 0.5 and standard deviation 0.1 for
# each of the three channels.
TRANSFORM = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5] * 3, std=[0.1] * 3)]
)


class IDaRS(CNNModel):
    """CNN classifier as used in IDaRS, with IDaRS-specific preprocessing.

    The network itself is built by :class:`CNNModel` from ``backbone`` and
    ``num_classes``; this subclass only overrides :meth:`preproc`.

    Args:
        backbone (str):
            Model name passed through to :class:`CNNModel`.
        num_classes (int):
            Number of classes output by the model. Defaults to 1.

    """

    def __init__(self, backbone, num_classes=1):
        super().__init__(backbone, num_classes=num_classes)

    @staticmethod
    # skipcq: PYL-W0221
    def preproc(img: np.ndarray):
        """Normalise an image the way the IDaRS paper does.

        The image is passed through the module-level ``TRANSFORM``
        (``ToTensor`` followed by per-channel ``Normalize`` with mean 0.5
        and std 0.1). The result is then permuted back to channels-last
        layout.

        Args:
            img (np.ndarray):
                Input image of shape HWC. It is not modified.

        Returns:
            torch.Tensor:
                Normalised image of shape HWC.

        """
        # Hand TRANSFORM a fresh copy so the caller's array is never shared.
        # NOTE(review): presumably this also guarantees a contiguous,
        # positive-stride buffer for the numpy -> torch conversion. Confirm.
        chw_tensor = TRANSFORM(img.copy())
        # ToTensor produces CHW, so reorder the axes back to HWC.
        return chw_tensor.permute(1, 2, 0)