# ***** BEGIN GPL LICENSE BLOCK *****
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
# All rights reserved.
# ***** END GPL LICENSE BLOCK *****
"""Defines utility layers and operators for models in tiatoolbox."""
from typing import Union
import numpy as np
import torch
import torch.nn as nn
def centre_crop(
    img: Union[np.ndarray, torch.Tensor],
    crop_shape: Union[np.ndarray, torch.Tensor],
    data_format: str = "NCHW",
):
    """Centre crop a batch of images by removing a given amount of border.

    The amount removed along each spatial axis is split between the two
    sides; when it is odd, the extra pixel is taken from the bottom/right.

    Args:
        img (ndarray, torch.Tensor): Input images, indexed as a
            4-dimensional array laid out according to `data_format`.
        crop_shape (ndarray, torch.Tensor): The subtracted amount in the
            form of `[subtracted height, subtracted width]`. A value of 0
            leaves the corresponding axis untouched.
        data_format (str): Either `NCHW` or `NHWC`.

    Returns:
        (ndarray, torch.Tensor): Cropped images.

    Raises:
        ValueError: If `data_format` is neither `NCHW` nor `NHWC`.

    """
    if data_format not in ["NCHW", "NHWC"]:
        raise ValueError(f"Unknown input format `{data_format}`")

    crop_t = crop_shape[0] // 2
    crop_b = crop_shape[0] - crop_t
    crop_l = crop_shape[1] // 2
    crop_r = crop_shape[1] - crop_l

    if data_format == "NCHW":
        height, width = img.shape[2], img.shape[3]
    else:
        height, width = img.shape[1], img.shape[2]

    # Explicit end indices are used instead of negative ones because a
    # zero crop would otherwise produce the slice `0:-0`, i.e. `0:0`,
    # which selects nothing.
    end_h = height - crop_b
    end_w = width - crop_r

    if data_format == "NCHW":
        return img[:, :, crop_t:end_h, crop_l:end_w]
    return img[:, crop_t:end_h, crop_l:end_w, :]
def centre_crop_to_shape(
    x: Union[np.ndarray, torch.Tensor],
    y: Union[np.ndarray, torch.Tensor],
    data_format: str = "NCHW",
):
    """Centre crop `x` so that its height and width match those of `y`.

    The height and width of `y` must both be strictly smaller than those
    of `x`; equal sizes are rejected as well.

    Args:
        x (ndarray, torch.Tensor): Images to be cropped, with a
            4-dimensional shape laid out according to `data_format`.
        y (ndarray, torch.Tensor): Reference images providing the target
            height and width, in the same layout as `x`.
        data_format (str): Either `NCHW` or `NHWC`.

    Returns:
        (ndarray, torch.Tensor): Cropped images.

    Raises:
        ValueError: If `data_format` is neither `NCHW` nor `NHWC`, or if
            the height or width of `x` is not larger than that of `y`.

    """
    if data_format not in ["NCHW", "NHWC"]:
        raise ValueError(f"Unknown input format `{data_format}`")

    if data_format == "NCHW":
        _, _, h1, w1 = x.shape
        _, _, h2, w2 = y.shape
    else:
        _, h1, w1, _ = x.shape
        _, h2, w2, _ = y.shape

    if h1 <= h2 or w1 <= w2:
        # Adjacent string literals are concatenated into a single message;
        # passing a tuple here would make the error print as a tuple repr.
        raise ValueError(
            "Height or width of `x` is smaller than `y` "
            f"{[h1, w1]} vs {[h2, w2]}"
        )

    # The amount to remove is the difference of the spatial sizes.
    crop_shape = (h1 - h2, w1 - w2)
    return centre_crop(x, crop_shape, data_format)
class UpSample2x(nn.Module):
    """A layer to scale input by a factor of 2.

    Each pixel is replicated into a 2x2 block by taking a tensor product
    with a 2x2 matrix of ones, rather than using the default pytorch
    interpolation.

    """

    def __init__(self):
        super().__init__()
        # Registered as a buffer so the constant is stored on the module
        # without being a trainable parameter.
        self.register_buffer(
            "unpool_mat", torch.ones((2, 2), dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor):
        """Upsample the input by a factor of 2 along height and width.

        Args:
            x (torch.Tensor): Input images, the tensor is in the shape of NCHW.

        Returns:
            ret (torch.Tensor): Input images upsampled by a factor of 2
                via nearest neighbour interpolation. The tensor is in the
                shape of NCHW.

        """
        input_shape = list(x.shape)
        x = x.unsqueeze(-1)  # N x C x H x W x 1
        mat = self.unpool_mat.unsqueeze(0)  # 1 x 2 x 2
        # Contract the trailing singleton axis of `x` with the leading
        # singleton axis of `mat`, copying every pixel into a 2x2 block.
        ret = torch.tensordot(x, mat, dims=1)  # N x C x H x W x 2 x 2
        # Interleave the block axes with the spatial axes before merging.
        ret = ret.permute(0, 1, 2, 4, 3, 5)  # N x C x H x 2 x W x 2
        ret = ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2))
        return ret