"""Define utility layers and operators for models in tiatoolbox."""
from __future__ import annotations
import sys
from typing import cast
import numpy as np
import torch
from torch import nn
from tiatoolbox import logger
def is_torch_compile_compatible() -> bool:
    """Check if the current GPU is compatible with torch-compile.

    Returns:
        True if current GPU is compatible with torch-compile, False otherwise.

    Raises:
        Warning if GPU is not compatible with `torch.compile`.

    """
    if not torch.cuda.is_available():
        # Without a CUDA device there is nothing for torch.compile to
        # accelerate on GPU; warn and report incompatibility.
        logger.warning(
            "No GPU detected or cuda not installed, "
            "torch.compile is only supported on selected NVIDIA GPUs. "
            "Speedup numbers may be lower than expected.",
            stacklevel=2,
        )
        return False

    # Compute-capability (major, minor) pairs for the officially supported
    # cards: V100 -> (7, 0), A100 -> (8, 0), H100 -> (9, 0).
    supported_capabilities = ((7, 0), (8, 0), (9, 0))
    capability = torch.cuda.get_device_capability()  # pragma: no cover
    if capability in supported_capabilities:  # pragma: no cover
        return True

    logger.warning(  # pragma: no cover
        "GPU is not compatible with torch.compile. "
        "Compatible GPUs include NVIDIA V100, A100, and H100. "
        "Speedup numbers may be lower than expected.",
        stacklevel=2,
    )
    return False  # pragma: no cover
def compile_model(
    model: nn.Module,
    *,
    mode: str = "default",
) -> nn.Module:
    """A decorator to compile a model using torch-compile.

    Args:
        model (torch.nn.Module):
            Model to be compiled.
        mode (str):
            Mode to be used for torch-compile. Available modes are:

            - `disable` disables torch-compile
            - `default` balances performance and overhead
            - `reduce-overhead` reduces overhead of CUDA graphs (useful for small
              batches)
            - `max-autotune` leverages Triton/template based matrix multiplications
              on GPUs
            - `max-autotune-no-cudagraphs` similar to `max-autotune` but without
              CUDA graphs

    Returns:
        torch.nn.Module:
            Compiled model, or the original model unchanged when compilation
            is disabled or unsupported on the current platform/GPU.

    """
    if mode == "disable":
        return model

    # Check if GPU is compatible with torch.compile; fall back to the
    # uncompiled model rather than failing.
    if not is_torch_compile_compatible():
        return model

    if sys.platform == "win32":  # pragma: no cover
        msg = (
            "`torch.compile` is not supported on Windows. Please see "
            "https://github.com/pytorch/pytorch/issues/122094."
        )
        logger.warning(msg=msg)
        return model

    if isinstance(  # pragma: no cover
        model,
        torch._dynamo.eval_frame.OptimizedModule,  # skipcq: PYL-W0212 # noqa: SLF001
    ):
        # Fix: the message was previously wrapped in a one-element tuple,
        # so the tuple repr (not the text) was logged.
        logger.info("The model is already compiled.")
        return model

    return cast("nn.Module", torch.compile(model, mode=mode))  # pragma: no cover
def centre_crop(
    img: np.ndarray | torch.Tensor,
    crop_shape: np.ndarray | torch.Tensor | tuple,
    data_format: str = "NCHW",
) -> np.ndarray | torch.Tensor:
    """A function to center crop image with given crop shape.

    Args:
        img (:class:`numpy.ndarray`, torch.Tensor):
            Input image, should be of 3 channels.
        crop_shape (:class:`numpy.ndarray`, torch.Tensor):
            The subtracted amount in the form of `[subtracted height,
            subtracted width]`.
        data_format (str):
            Either `"NCHW"` or `"NHWC"`.

    Returns:
        (:class:`numpy.ndarray`, torch.Tensor):
            Cropped image.

    Raises:
        ValueError:
            If `data_format` is neither `"NCHW"` nor `"NHWC"`.

    """
    if data_format not in ["NCHW", "NHWC"]:
        msg = f"Unknown input format `{data_format}`."
        raise ValueError(msg)

    # Split each subtracted amount between the two sides; the "extra" pixel
    # of an odd amount is removed from the bottom/right.
    crop_t = crop_shape[0] // 2
    crop_b = crop_shape[0] - crop_t
    crop_l = crop_shape[1] // 2
    crop_r = crop_shape[1] - crop_l

    # Fix: a crop amount of 0 must map to `None`, not `-0` — `x[0:-0]` is an
    # empty slice and previously discarded the whole axis.
    bottom = -crop_b if crop_b else None
    right = -crop_r if crop_r else None

    if data_format == "NCHW":
        return img[:, :, crop_t:bottom, crop_l:right]
    return img[:, crop_t:bottom, crop_l:right, :]
def centre_crop_to_shape(
    x: np.ndarray | torch.Tensor,
    y: np.ndarray | torch.Tensor,
    data_format: str = "NCHW",
) -> np.ndarray | torch.Tensor:
    """A function to center crop image to shape.

    Centre crop `x` so that `x` has shape of `y` and `y` height and
    width must be smaller than `x` height width.

    Args:
        x (:class:`numpy.ndarray`, torch.Tensor):
            Image to be cropped.
        y (:class:`numpy.ndarray`, torch.Tensor):
            Reference image for getting cropping shape, should be of 3
            channels.
        data_format:
            Either `"NCHW"` or `"NHWC"`.

    Returns:
        (:class:`numpy.ndarray`, torch.Tensor):
            Cropped image.

    Raises:
        ValueError:
            If `data_format` is unknown, or if the height/width of `x` is
            not strictly larger than that of `y`.

    """
    if data_format not in ["NCHW", "NHWC"]:
        msg = f"Unknown input format `{data_format}`."
        raise ValueError(msg)

    if data_format == "NCHW":
        _, _, h1, w1 = x.shape
        _, _, h2, w2 = y.shape
    else:
        _, h1, w1, _ = x.shape
        _, h2, w2, _ = y.shape

    if h1 <= h2 or w1 <= w2:
        # Fix: the message was previously a tuple of two strings passed to
        # ValueError, so the tuple repr (not the text) surfaced to the user.
        msg = (
            "Height or width of `x` is smaller than `y` "
            f"{[h1, w1]} vs {[h2, w2]}"
        )
        raise ValueError(msg)

    x_shape = x.shape
    y_shape = y.shape
    if data_format == "NCHW":
        crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3])
    else:
        crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2])

    return centre_crop(x, crop_shape, data_format)
class UpSample2x(nn.Module):
    """A layer to scale input by a factor of 2.

    This layer uses Kronecker product underneath rather than the default
    pytorch interpolation.
    """

    def __init__(self: UpSample2x) -> None:
        """Initialize :class:`UpSample2x`."""
        super().__init__()
        # Registering as a buffer (rather than a plain attribute) is the
        # correct way to create a constant within a module: it moves with
        # the module across devices and dtypes.
        self.unpool_mat: torch.Tensor
        self.register_buffer(
            "unpool_mat",
            torch.from_numpy(np.ones((2, 2), dtype="float32")),
        )
        # Fix: a stray `self.unpool_mat.unsqueeze(0)` statement was removed
        # here — `unsqueeze` is not in-place and its result was discarded,
        # so the call had no effect. `forward` performs the unsqueeze itself.

    def forward(self: UpSample2x, x: torch.Tensor) -> torch.Tensor:
        """Logic for using layers defined in init.

        Args:
            x (torch.Tensor):
                Input images, the tensor is in the shape of NCHW.

        Returns:
            torch.Tensor:
                Input images upsampled by a factor of 2 via nearest
                neighbour interpolation. The tensor is the shape as
                NCHW.

        """
        input_shape = list(x.shape)
        # un-squeeze is the same as expand_dims
        # permute is the same as transpose
        # view is the same as reshape
        x = x.unsqueeze(-1)  # bchwx1
        mat = self.unpool_mat.unsqueeze(0)  # 1xshxsw
        # Tensordot with the all-ones 2x2 matrix replicates every pixel into
        # a 2x2 patch (Kronecker product), i.e. nearest-neighbour upsampling.
        ret = torch.tensordot(x, mat, dims=1)  # bxcxhxwxshxsw
        ret = ret.permute(0, 1, 2, 4, 3, 5)
        return ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2))