Source code for tiatoolbox.utils.visualization

# ***** BEGIN GPL LICENSE BLOCK *****
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# The Original Code is Copyright (C) 2021, TIA Centre, University of Warwick
# All rights reserved.
# ***** END GPL LICENSE BLOCK *****

"""Visualisation and overlay functions used in tiatoolbox."""
import colorsys
import random
from typing import Tuple, Union

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np


[docs]def random_colors(num_colors, bright=True): """Generate a number of random colors. To get visually distinct colors, generate them in HSV space then convert to RGB. Args: num_colors(int): Number of perceptively different colors to generate. bright(bool): To use bright color or not. Returns: List of (r, g, b) colors. """ brightness = 1.0 if bright else 0.7 hsv = [(i / num_colors, 1, brightness) for i in range(num_colors)] colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) random.shuffle(colors) return colors
[docs]def overlay_prediction_mask( img: np.ndarray, prediction: np.ndarray, alpha: float = 0.35, label_info: dict = None, min_val: float = 0.0, ax=None, return_ax: bool = True, ): """Generate an overlay, given a 2D prediction map. Args: img (ndarray): Input image to overlay the results on top of. prediction (ndarray): 2D prediction map. Multi-class prediction should have values ranging from 0 to N-1, where N is the number of classes. label_info (dict): A dictionary contains the mapping for each integer value within `prediction` to its string and color. [int] : (str, (int, int, int)). By default, integer will be taken as label and color will be random. min_val (float): Only consider predictions greater than or equal to `min_val`. Otherwise, the original WSI in those regions will be displayed. alpha (float): Opacity value used for the overlay. ax (ax): Matplotlib ax object. return_ax (bool): Whether to return the matplotlib ax object. If not, then the overlay array will be returned. Returns: If return_ax is True, return the matplotlib ax object. Else, return the overlay array. """ if img.shape[:2] != prediction.shape[:2]: raise ValueError( f"Mismatch shape " f"`img` {img.shape[:2]} vs `prediction` {prediction.shape[:2]}." ) if np.issubdtype(img.dtype, np.floating): if not (img.max() <= 1.0 and img.min() >= 0): raise ValueError("Not support float `img` outside [0, 1].") img = np.array(img * 255, dtype=np.uint8) # if `min_val` is defined, only display the overlay for areas with pred > min_val if min_val > 0: prediction_sel = prediction >= min_val overlay = img.copy() # generate random colours predicted_classes = sorted(np.unique(prediction).tolist()) if label_info is None: np.random.seed(123) label_info = {} for label_uid in predicted_classes: random_colour = np.random.choice(range(256), size=3) label_info[label_uid] = (str(label_uid), random_colour) else: # may need better error message check_uid_list = predicted_classes.copy() for label_uid, (label_name, label_colour) in label_info.items(): if label_uid in check_uid_list: check_uid_list.remove(label_uid) if not isinstance(label_uid, int): raise ValueError( "Wrong `label_info` format: label_uid " f"{[label_uid, (label_name, label_colour)]}" ) if not isinstance(label_name, str): raise ValueError( "Wrong `label_info` format: label_name " f"{[label_uid, (label_name, label_colour)]}" ) if not isinstance(label_colour, (tuple, list, np.ndarray)): raise ValueError( "Wrong `label_info` format: label_colour " f"{[label_uid, (label_name, label_colour)]}" ) if len(label_colour) != 3: raise ValueError( "Wrong `label_info` format: label_colour " f"{[label_uid, (label_name, label_colour)]}" ) # if len(check_uid_list) != 0: raise ValueError(f"Missing label for: {check_uid_list}.") rgb_prediction = np.zeros( [prediction.shape[0], prediction.shape[1], 3], dtype=np.uint8 ) for label_uid, (_, overlay_rgb) in label_info.items(): sel = prediction == label_uid rgb_prediction[sel] = overlay_rgb # add the overlay cv2.addWeighted(rgb_prediction, alpha, overlay, 1 - alpha, 0, overlay) overlay = overlay.astype(np.uint8) if min_val > 0.0: overlay[~prediction_sel] = img[~prediction_sel] if ax is None and not return_ax: return overlay # create colorbar parameters name_list = [v[0] for v in label_info.values()] color_list = [v[1] for v in label_info.values()] color_list = np.array(color_list) / 255 uid_list = list(label_info.keys()) cmap = mpl.colors.ListedColormap(color_list) colorbar_params = { "mappable": mpl.cm.ScalarMappable(cmap=cmap), "boundaries": uid_list + [uid_list[-1] + 1], "ticks": [b + 0.5 for b in uid_list], "spacing": "proportional", "orientation": "vertical", } # generate another ax, else using the provided if ax is None: _, ax = plt.subplots() ax.imshow(overlay) ax.axis("off") # generate colour bar cbar = plt.colorbar(**colorbar_params) cbar.ax.set_yticklabels(name_list) cbar.ax.tick_params(labelsize=12) return ax
[docs]def overlay_probability_map( img: np.ndarray, prediction: np.ndarray, alpha: float = 0.35, colour_map: str = "jet", min_val: float = 0.0, ax=None, return_ax: bool = True, ): """Generate an overlay, given a 2D prediction map. Args: img (ndarray): Input image to overlay the results on top of. Assumed to be HW. prediction (ndarray): 2D prediction map. Values are expected to be between 0-1. alpha (float): Opacity value used for the overlay. colour_map (string): The colour map to use for the heatmap. `jet` is used as the default. min_val (float): Only consider pixels that are greater than or equal to `min_val`. Otherwise, the original WSI in those regions will be displayed. alpha (float): Opacity value used for the overlay. ax (ax): Matplotlib ax object. return_ax (bool): Whether to return the matplotlib ax object. If not, then the overlay array will be returned. Returns: If return_ax is True, return the matplotlib ax object. Else, return the overlay array. """ if prediction.ndim != 2: raise ValueError("The input prediction must be 2-dimensional of the form HW.") if img.shape[:2] != prediction.shape[:2]: raise ValueError( "Mismatch shape `img` {0} vs `prediction` {1}.".format( img.shape[:2], prediction.shape[:2] ) ) prediction = prediction.astype(np.float32) if prediction.max() > 1.0: raise ValueError("Not support float `prediction` outside [0, 1].") if prediction.min() < 0: raise ValueError("Not support float `prediction` outside [0, 1].") if np.issubdtype(img.dtype, np.floating): if img.max() > 1.0: raise ValueError("Not support float `img` outside [0, 1].") if img.min() < 0: raise ValueError("Not support float `img` outside [0, 1].") img = np.array(img * 255, dtype=np.uint8) # if `min_val` is defined, only display the overlay for areas with prob > min_val if min_val < 0.0: raise ValueError(f"`min_val={min_val}` is not between [0, 1]") if min_val > 1.0: raise ValueError(f"`min_val={min_val}` is not between [0, 1]") prediction_sel = prediction >= min_val overlay = img.copy() cmap = plt.get_cmap(colour_map) prediction = np.squeeze(prediction.astype("float32")) # take RGB from RGBA heat map rgb_prediction = (cmap(prediction)[..., :3] * 255).astype("uint8") # add the overlay # cv2.addWeighted(rgb_prediction, alpha, overlay, 1 - alpha, 0, overlay) overlay = (1 - alpha) * rgb_prediction + alpha * overlay overlay[overlay > 255.0] = 255.0 overlay = overlay.astype(np.uint8) if min_val > 0.0: overlay[~prediction_sel] = img[~prediction_sel] if ax is None and not return_ax: return overlay colorbar_params = { "mappable": mpl.cm.ScalarMappable(cmap="jet"), "spacing": "proportional", "orientation": "vertical", } # generate another ax, else using the provided if ax is None: _, ax = plt.subplots() ax.imshow(overlay) ax.axis("off") # generate colour bar cbar = plt.colorbar(**colorbar_params) cbar.ax.tick_params(labelsize=12) return ax
[docs]def overlay_prediction_contours( canvas: np.ndarray, inst_dict: dict, draw_dot: bool = False, type_colours: dict = None, inst_colours: Union[np.ndarray, Tuple[int]] = (255, 255, 0), line_thickness: int = 2, ): """Overlaying instance contours on image. Internally, colours from `type_colours` are prioritized over `inst_colours`. However, if `inst_colours` is `None` and `type_colours` is not provided, random colour is generated for each instance. Args: canvas (ndarray): Image to draw predictions on. inst_dict (dict): Dictionary of instances. It is expected to be in the following format: {instance_id: {type: int, contour: List[List[int]], centroid:List[float]}. draw_dot (bool): To draw a dot for each centroid or not. type_colours (dict): A dict of {type_id : (type_name, colour)}, `type_id` is from 0-N and `colour` is a tuple of (R, G, B). inst_colours (tuple, np.ndarray): A colour to assign for all instances, or a list of colours to assigned for each instance in `inst_dict`. By default, all instances will have RGB colour `(255, 255, 0). line_thickness: line thickness of contours. Returns: (np.ndarray) The overlaid image. """ overlay = np.copy((canvas)) if inst_colours is None: inst_colours = random_colors(len(inst_dict)) inst_colours = np.array(inst_colours) * 255 inst_colours = inst_colours.astype(np.uint8) elif isinstance(inst_colours, tuple): inst_colours = np.array([inst_colours] * len(inst_dict)) elif not isinstance(inst_colours, np.ndarray): raise ValueError( f"`inst_colours` must be np.ndarray or tuple: {type(inst_colours)}" ) inst_colours = inst_colours.astype(np.uint8) for idx, [_, inst_info] in enumerate(inst_dict.items()): inst_contour = inst_info["contour"] if "type" in inst_info and type_colours is not None: inst_colour = type_colours[inst_info["type"]][1] else: inst_colour = (inst_colours[idx]).tolist() cv2.drawContours( overlay, [np.array(inst_contour)], -1, inst_colour, line_thickness ) if draw_dot: inst_centroid = inst_info["centroid"] inst_centroid = tuple([int(v) for v in inst_centroid]) overlay = cv2.circle(overlay, inst_centroid, 3, (255, 0, 0), -1) return overlay
[docs]def plot_graph( canvas: np.ndarray, nodes: np.ndarray, edges: np.ndarray, node_colors: Union[Tuple[int], np.ndarray] = (255, 0, 0), node_size: int = 5, edge_colors: Union[Tuple[int], np.ndarray] = (0, 0, 0), edge_size: int = 5, ): """Drawing a graph onto a canvas. Drawing a graph onto a canvas. Args: canvas (np.ndarray): Canvas to be drawn upon. nodes (np.ndarray): List of nodes, expected to be Nx2 where N is the number of nodes. Each node is expected to be of `(x, y)` and should be within the height and width of the canvas. edges (np.ndarray): List of egdes, expected to be Mx2 where M is the number of edges. Each edge is defined as `(src, dst)` where each is respectively the index of within `nodes`. node_colors (tuple or np.ndarray): A color or list of node colors. Each color is expected to be `(r, g, b)` and is between 0-255. edge_colors (tuple or np.ndarray): A color or list of node colors. Each color is expected to be `(r, g, b)` and is between 0-255. node_size (int): Radius of each node. edge_size (int): Linewidth of the edge. """ if isinstance(node_colors, tuple): node_colors = [node_colors] * len(nodes) if isinstance(edge_colors, tuple): edge_colors = [edge_colors] * len(edges) # draw the edges def to_int_tuple(x): """Helper to convert to tuple of int.""" return tuple([int(v) for v in x]) for idx, (src, dst) in enumerate(edges): src = to_int_tuple(nodes[src]) dst = to_int_tuple(nodes[dst]) color = to_int_tuple(edge_colors[idx]) cv2.line(canvas, src, dst, color, thickness=edge_size) # draw the nodes for idx, node in enumerate(nodes): node = to_int_tuple(node) color = to_int_tuple(node_colors[idx]) cv2.circle(canvas, node, node_size, color, thickness=-1) return canvas