Source code for tiatoolbox.tools.graph

"""Construction and visualisation of graphs for WSI prediction."""

from __future__ import annotations

from collections import defaultdict
from numbers import Number
from typing import TYPE_CHECKING, Callable

import numpy as np
import torch
import umap
from matplotlib import pyplot as plt
from scipy.cluster import hierarchy
from scipy.spatial import Delaunay, cKDTree

if TYPE_CHECKING:  # pragma: no cover
    from matplotlib.axes import Axes
    from numpy.typing import ArrayLike


def delaunay_adjacency(points: ArrayLike, dthresh: float) -> np.ndarray:
    """Create an adjacency matrix via Delaunay triangulation from a list of coordinates.

    Two points are connected when they share a simplex (a triangle in
    2D) of the Delaunay triangulation AND their Euclidean distance is
    less than or equal to `dthresh`. Points which are further apart
    than `dthresh` will not be connected.

    See https://en.wikipedia.org/wiki/Adjacency_matrix.

    Args:
        points (ArrayLike):
            An nxm list (or array) of coordinates.
        dthresh (float):
            Distance threshold for triangulation.

    Returns:
        np.ndarray:
            Float adjacency matrix of shape NxN where 1 indicates
            connected and 0 indicates unconnected. The matrix is
            symmetric and its diagonal is always zero.

    Raises:
        TypeError:
            If `dthresh` is not a number.
        ValueError:
            If fewer than 4 points are given or `points` is not 2-D.

    Example:
        >>> rng = np.random.default_rng()
        >>> points = rng.random((100, 2))
        >>> adjacency = delaunay_adjacency(points, dthresh=0.5)

    """
    # Validate inputs
    if not isinstance(dthresh, Number):
        msg = "dthresh must be a number."
        raise TypeError(msg)
    if len(points) < 4:  # noqa: PLR2004
        msg = "Points must have length >= 4."
        raise ValueError(msg)
    if len(np.shape(points)) != 2:  # noqa: PLR2004
        msg = "Points must have an NxM shape."
        raise ValueError(msg)

    # Convert once so that plain Python lists support fancy indexing below.
    points = np.asarray(points)

    # Apply Delaunay triangulation to the coordinates to get a
    # tessellation of simplices (triangles in 2D).
    simplices = Delaunay(points).simplices

    # Every pair of vertices within a simplex is an edge. Enumerate the
    # vertex-position pairs (a, b) with a < b once, then gather them for
    # all simplices at the same time.
    first, second = np.triu_indices(simplices.shape[1], k=1)
    starts = simplices[:, first].ravel()
    ends = simplices[:, second].ravel()

    # Keep only edges whose length does not exceed the threshold.
    lengths = np.linalg.norm(points[starts] - points[ends], axis=1)
    keep = lengths <= dthresh
    starts, ends = starts[keep], ends[keep]

    # Fill the symmetric adjacency matrix. Edges shared by neighbouring
    # simplices are simply written more than once.
    adjacency = np.zeros((len(points), len(points)))
    adjacency[starts, ends] = 1.0
    adjacency[ends, starts] = 1.0
    return adjacency
def triangle_signed_area(triangle: ArrayLike) -> float:
    """Determine the signed area of a triangle.

    Uses the shoelace formula. With the y-axis pointing up, the result
    is positive for a counter-clockwise winding, negative for a
    clockwise winding, and zero if the triangle's points are collinear.
    In image coordinates (y-axis pointing down) the apparent winding is
    mirrored.

    Args:
        triangle (ArrayLike):
            A 3x2 list of coordinates.

    Returns:
        float:
            The signed area of the triangle.

    Raises:
        ValueError:
            If `triangle` is not a 3x2 array.

    """
    # Validate inputs
    triangle = np.asarray(triangle)
    if triangle.shape != (3, 2):
        msg = "Input triangle must be a 3x2 array."
        raise ValueError(msg)
    # Calculate the area of the triangle (shoelace formula)
    return 0.5 * (
        triangle[0, 0] * (triangle[1, 1] - triangle[2, 1])
        + triangle[1, 0] * (triangle[2, 1] - triangle[0, 1])
        + triangle[2, 0] * (triangle[0, 1] - triangle[1, 1])
    )
def edge_index_to_triangles(edge_index: ArrayLike) -> ArrayLike:
    """Convert an edge index to triangle simplices (triplets of node indices).

    A triangle is any set of three distinct nodes which are pairwise
    connected. Edge direction is ignored and self-loops are skipped.

    Args:
        edge_index (ArrayLike):
            A 2xM array of edges, where column `(a, b)` is an edge
            between nodes `a` and `b`.

    Returns:
        ArrayLike:
            An Nx3 `int32` array of triangles. Each row is sorted in
            ascending order and the rows are sorted lexicographically.
            If there are no triangles the result has shape (0, 3).

    Raises:
        ValueError:
            If `edge_index` is not a 2xM matrix.

    Example:
        >>> rng = np.random.default_rng()
        >>> points = rng.random((100, 2))
        >>> adjacency = delaunay_adjacency(points, dthresh=0.5)
        >>> edge_index = affinity_to_edge_index(adjacency)
        >>> triangles = edge_index_to_triangles(edge_index)

    """
    # Validate inputs (converting first so lists are accepted and the
    # dimensionality is checked before the first axis is inspected).
    edge_index = np.asarray(edge_index)
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:  # noqa: PLR2004
        msg = "Input edge_index must be a 2xM matrix."
        raise ValueError(msg)

    # Find the (undirected) neighbours of each node
    neighbours: defaultdict = defaultdict(set)
    for a, b in edge_index.T.tolist():
        if a == b:
            # A self-loop cannot be part of a triangle of distinct nodes
            # and would otherwise yield degenerate 1/2-node "triangles".
            continue
        neighbours[a].add(b)
        neighbours[b].add(a)

    # A triangle {node, neighbour, common} exists when `common` is a
    # neighbour of both ends of the edge (node, neighbour).
    triangles = set()
    for node, node_neighbours in neighbours.items():
        if len(node_neighbours) < 2:  # noqa: PLR2004
            # Nodes with fewer than two neighbours cannot be in a triangle
            continue
        for neighbour in node_neighbours:
            for common in node_neighbours & neighbours[neighbour]:
                triangles.add(tuple(sorted((node, neighbour, common))))

    # Sort for a deterministic output; reshape keeps an (0, 3) shape when
    # no triangles were found.
    return np.array(sorted(triangles), dtype=np.int32, order="C").reshape(-1, 3)
def affinity_to_edge_index(
    affinity_matrix: torch.Tensor | ArrayLike,
    threshold: float = 0.5,
) -> torch.Tensor | ArrayLike:
    """Convert an affinity matrix (similarity matrix) to an edge index.

    Converts an NxN affinity matrix to a 2xM edge index, where M is the
    number of node pairs with a similarity greater than the threshold
    value (defaults to 0.5). Column `(i, j)` of the result means entry
    `affinity_matrix[i, j]` exceeds the threshold.

    Args:
        affinity_matrix:
            An NxN matrix of affinities between nodes. May be a
            torch.Tensor, a numpy array, or a nested list.
        threshold (Number):
            Threshold above which to be considered connected. Defaults
            to 0.5.

    Returns:
        ArrayLike or torch.Tensor:
            The edge index of shape (2, M). A torch.Tensor is returned
            for tensor input, otherwise a C-contiguous numpy array.

    Raises:
        ValueError:
            If `affinity_matrix` is not square (NxN).

    Example:
        >>> rng = np.random.default_rng()
        >>> points = rng.random((100, 2))
        >>> adjacency = delaunay_adjacency(points, dthresh=0.5)
        >>> edge_index = affinity_to_edge_index(adjacency)

    """
    # Validate inputs
    input_shape = np.shape(affinity_matrix)
    if len(input_shape) != 2 or len(np.unique(input_shape)) != 1:  # noqa: PLR2004
        msg = "Input affinity_matrix must be square (NxN)."
        raise ValueError(msg)
    # Handle cases for pytorch and numpy inputs
    if isinstance(affinity_matrix, torch.Tensor):
        return (affinity_matrix > threshold).nonzero().t().contiguous()
    # Convert so nested Python lists support element-wise comparison.
    rows, cols = np.nonzero(np.asarray(affinity_matrix) > threshold)
    return np.ascontiguousarray(np.stack((rows, cols)))
class SlideGraphConstructor:
    """Construct a graph using the SlideGraph+ (Liu et al. 2021) method.

    This uses a hybrid agglomerative clustering which uses a weighted
    combination of spatial distance (within the WSI) and feature-space
    distance to group patches into nodes. See the `build` function for
    more details on the graph construction method.

    """

    @staticmethod
    def _umap_reducer(graph: dict[str, ArrayLike]) -> ArrayLike:
        """Default reduction which reduces `graph["x"]` to 3D values.

        Reduces graph features to 3D values using UMAP which are suitable
        for plotting as RGB values.

        Args:
            graph (dict):
                A graph with keys "x", "edge_index", and optionally
                "coordinates".

        Returns:
            ArrayLike:
                A UMAP embedding of `graph["x"]` with shape (N, 3) and
                values ranging from 0 to 1. An embedding dimension that
                is constant across all nodes is mapped to 0.

        """
        reducer = umap.UMAP(n_components=3)
        reduced = reducer.fit_transform(graph["x"])
        # Min-max normalise each embedding dimension to [0, 1].
        reduced -= reduced.min(axis=0)
        value_ranges = reduced.max(axis=0)
        # A constant dimension has zero range and 0 / 0 would give NaN.
        # After the shift above it is all zeros, so dividing by 1 is safe.
        value_ranges[value_ranges == 0] = 1
        reduced /= value_ranges
        return reduced

    @staticmethod
    def _condensed_hybrid_distances(
        points: np.ndarray,
        features: np.ndarray,
        lambda_d: float,
        lambda_f: float,
        neighbour_search_radius: float,
    ) -> np.ndarray:
        """Compute the condensed (1 - fd) distance matrix used for clustering.

        For every pair of points closer than `neighbour_search_radius`,
        the distance is `1 - f * d` where
        `f = exp(-lambda_f * ||feature_i - feature_j||)` and
        `d = exp(-lambda_d * ||point_i - point_j||)`. All other pairs get
        the maximal distance 1.

        Args:
            points (np.ndarray):
                Spatial coordinates, one row per point.
            features (np.ndarray):
                Features, one row per point.
            lambda_d (float):
                Spatial distance (d) weighting.
            lambda_f (float):
                Feature distance (f) weighting.
            neighbour_search_radius (float):
                Pairs at or beyond this spatial distance are not compared.

        Returns:
            np.ndarray:
                A 1-D condensed distance matrix (upper triangle, row by
                row) as used by `scipy.cluster.hierarchy.linkage`.

        """
        n_points = len(points)
        # Build a kd-tree and rank neighbours according to the euclidean
        # distance (nearest -> farthest).
        kd_tree = cKDTree(points)
        neighbour_distances_ckd, neighbour_indexes_ckd = kd_tree.query(
            x=points,
            k=n_points,
        )

        # Initialise an empty 1-D condensed distance matrix.
        # For information on condensed distance matrices see:
        # - scipy.spatial.distance.pdist
        # - scipy.cluster.hierarchy.linkage
        condensed_distance_matrix = np.zeros(n_points * (n_points - 1) // 2)

        # Find the similarity between pairs of patches
        index = 0
        for i in range(n_points - 1):
            # Only consider neighbours which are inside the radius
            # (neighbour_search_radius).
            neighbour_distances_single_point = neighbour_distances_ckd[i][
                neighbour_distances_ckd[i] < neighbour_search_radius
            ]
            # kd-tree results are sorted by distance, so the in-radius
            # neighbours are exactly the leading entries.
            neighbour_indexes_single_point = neighbour_indexes_ckd[i][
                : len(neighbour_distances_single_point)
            ]
            # Called f in the paper
            neighbour_feature_similarities = np.exp(
                -lambda_f
                * np.linalg.norm(
                    features[i] - features[neighbour_indexes_single_point],
                    axis=1,
                ),
            )
            # Called d in paper
            neighbour_distance_similarities = np.exp(
                -lambda_d * neighbour_distances_single_point,
            )
            # 1 - product of similarities (1 - fd)
            # (1 = most un-similar 0 = most similar)
            neighbour_similarities = (
                1 - neighbour_feature_similarities * neighbour_distance_similarities
            )
            # Initialise similarity of coordinate i vs all coordinates to 1
            # (most un-similar), then set the computed neighbour values.
            i_vs_all_similarities = np.ones(n_points)
            i_vs_all_similarities[neighbour_indexes_single_point] = (
                neighbour_similarities
            )
            # The condensed matrix only stores pairs (i, j) with j > i.
            i_vs_all_similarities = i_vs_all_similarities[i + 1 :]
            condensed_distance_matrix[index : index + len(i_vs_all_similarities)] = (
                i_vs_all_similarities
            )
            index += len(i_vs_all_similarities)
        return condensed_distance_matrix

    @staticmethod
    def build(
        points: ArrayLike,
        features: ArrayLike,
        lambda_d: float = 3.0e-3,
        lambda_f: float = 1.0e-3,
        lambda_h: float = 0.8,
        connectivity_distance: int = 4000,
        neighbour_search_radius: int = 2000,
        feature_range_thresh: float | None = 1e-4,
    ) -> dict[str, ArrayLike]:
        """Build a graph via hybrid clustering in spatial and feature space.

        The graph is constructed via hybrid hierarchical clustering
        followed by Delaunay triangulation of these cluster centroids.
        This is part of the SlideGraph pipeline but may be used to
        construct a graph in general from point coordinates and
        features.

        The clustering uses a distance kernel, ranging between 0 and 1,
        which is a weighted product of spatial distance (distance
        between coordinates in `points`, e.g. WSI location) and
        feature-space distance (e.g. ResNet features).

        Points which are spatially further apart than
        `neighbour_search_radius` are given a similarity of 1 (most
        dissimilar). This significantly speeds up computation. This
        distance metric is then used to form clusters via
        hierarchical/agglomerative clustering.

        Next, a Delaunay triangulation is applied to the clusters to
        connect the neighbouring clusters. Only clusters which are
        closer than `connectivity_distance` in the spatial domain will
        be connected.

        Args:
            points (ArrayLike):
                A list of (x, y) spatial coordinates, e.g. pixel
                locations within a WSI.
            features (ArrayLike):
                A list of features associated with each coordinate in
                `points`. Must be the same length as `points`.
            lambda_d (Number):
                Spatial distance (d) weighting.
            lambda_f (Number):
                Feature distance (f) weighting.
            lambda_h (Number):
                Clustering distance threshold. Applied to the similarity
                kernel (1-fd). Ranges between 0 and 1. Defaults to 0.8.
                A good value for this parameter will depend on the
                intra-cluster variance.
            connectivity_distance (Number):
                Spatial distance threshold to consider points as
                connected during the Delaunay triangulation step.
            neighbour_search_radius (Number):
                Search radius (L2 norm) threshold for points to be
                considered as similar for clustering. Points with a
                spatial distance above this are not compared and have a
                similarity set to 1 (most dissimilar).
            feature_range_thresh (Number):
                Minimal range for which a feature is considered
                significant. Features which have a range less than this
                are ignored. Defaults to 1e-4. If falsy (None, False, 0,
                etc.), then no features are removed.

        Returns:
            dict:
                A dictionary defining a graph for serialisation (e.g.
                JSON or msgpack) or converting into a torch-geometric
                Data object where each node is the centroid (mean) of
                the features in a cluster. The dictionary has the
                following entries:

                - :class:`numpy.ndarray` - x:
                    Features of each node (mean of features in a
                    cluster). Required for torch-geometric Data.
                - :class:`numpy.ndarray` - edge_index:
                    Edge index matrix (int64) defining connectivity.
                    Required for torch-geometric Data.
                - :class:`numpy.ndarray` - coordinates:
                    Coordinates of each node within the WSI (rounded
                    mean of points in a cluster). Useful for
                    visualisation over the WSI.

        Example:
            >>> rng = np.random.default_rng()
            >>> points = rng.random((99, 2)) * 1000
            >>> features = np.array([
            ...     rng.random(11) * n
            ...     for n, _ in enumerate(points)
            ... ])
            >>> graph_dict = SlideGraphConstructor.build(points, features)

        """
        # Accept plain Python lists too: the code below uses 2-D indexing.
        points = np.asarray(points)
        features = np.asarray(features)

        # Remove features which do not change significantly between patches
        if feature_range_thresh:
            feature_ranges = np.max(features, axis=0) - np.min(features, axis=0)
            where_significant = feature_ranges > feature_range_thresh
            features = features[:, where_significant]

        condensed_distance_matrix = SlideGraphConstructor._condensed_hybrid_distances(
            points=points,
            features=features,
            lambda_d=lambda_d,
            lambda_f=lambda_f,
            neighbour_search_radius=neighbour_search_radius,
        )

        # Perform hierarchical clustering (using similarity as distance)
        linkage_matrix = hierarchy.linkage(condensed_distance_matrix, method="average")
        clusters = hierarchy.fcluster(linkage_matrix, lambda_h, criterion="distance")

        # Finding the xy centroid and average features for each cluster
        unique_clusters = list(set(clusters))
        point_centroids = []
        feature_centroids = []
        for c in unique_clusters:
            (idx,) = np.where(clusters == c)
            # Find the xy and feature space averages of the cluster
            point_centroids.append(np.round(points[idx, :].mean(axis=0)))
            feature_centroids.append(features[idx, :].mean(axis=0))
        point_centroids = np.array(point_centroids)
        feature_centroids = np.array(feature_centroids)

        # Connect nearby cluster centroids
        adjacency_matrix = delaunay_adjacency(
            points=point_centroids,
            dthresh=connectivity_distance,
        )
        edge_index = affinity_to_edge_index(adjacency_matrix)
        return {
            "x": feature_centroids,
            "edge_index": edge_index.astype(np.int64),
            "coordinates": point_centroids,
        }

    @classmethod
    def visualise(
        cls: type[SlideGraphConstructor],
        graph: dict[str, ArrayLike],
        color: ArrayLike | str | Callable | None = None,
        node_size: Number | ArrayLike | Callable = 25,
        edge_color: str | ArrayLike | Callable = (0, 0, 0, 0.33),
        ax: Axes | None = None,
    ) -> Axes:
        """Visualise a graph.

        The visualisation is a scatter plot of the graph nodes and the
        connections between them. By default, nodes are coloured
        according to the features of the graph via a UMAP embedding to
        the sRGB color space. This can be customised by passing a color
        argument which can be a single color, a list of colors, or a
        function which takes the graph and returns a list of colors for
        each node. The edge color(s) can be customised in the same way.

        Args:
            graph (dict):
                The graph to visualise as a dictionary with the
                following entries:

                - :class:`numpy.ndarray` - x:
                    Features of each node (mean of features in a
                    cluster). Required
                - :class:`numpy.ndarray` - edge_index:
                    Edge index matrix defining connectivity. Required
                - :class:`numpy.ndarray` - coordinates:
                    Coordinates of each node within the WSI (mean of
                    point in a cluster). Required
            color (np.array or str or callable):
                Colours of the nodes in the plot. If it is a callable,
                it should take a graph as input and return a numpy array
                of matplotlib colours. If `None` then a default function
                is used (UMAP on `graph["x"]`).
            node_size (int or np.ndarray or callable):
                Size of the nodes in the plot. If it is a function then
                it is called with the graph as an argument.
            edge_color (str or tuple or np.ndarray or callable):
                Colour(s) of edges in the graph plot. If it is a
                function then it is called with the graph as an
                argument.
            ax (:class:`matplotlib.axes.Axes`):
                The axes object to plot the graph on. If `None`, a new
                figure and axes are created.

        Returns:
            matplotlib.axes.Axes:
                The axes which were plotted on.

        Raises:
            ValueError:
                If `graph` is missing any of the required keys.

        Example:
            >>> rng = np.random.default_rng()
            >>> points = rng.random((99, 2)) * 1000
            >>> features = np.array([
            ...     rng.random(11) * n
            ...     for n, _ in enumerate(points)
            ... ])
            >>> graph_dict = SlideGraphConstructor.build(points, features)
            >>> fig, ax = plt.subplots()
            >>> slide_dims = wsi.info.slide_dimensions
            >>> ax.imshow(wsi.get_thumbnail(), extent=(0, *slide_dims, 0))
            >>> SlideGraphConstructor.visualise(graph_dict, ax=ax)
            >>> plt.show()

        """
        from matplotlib import collections as mc

        # Check that the graph is valid
        if "x" not in graph:
            msg = "Graph must contain key `x`."
            raise ValueError(msg)
        if "edge_index" not in graph:
            msg = "Graph must contain key `edge_index`."
            raise ValueError(msg)
        if "coordinates" not in graph:
            msg = "Graph must contain key `coordinates`"
            raise ValueError(msg)

        if ax is None:
            _, ax = plt.subplots()
        if color is None:
            color = cls._umap_reducer

        nodes = np.asarray(graph["coordinates"])
        edges = np.asarray(graph["edge_index"])

        # Plot the edges: each (2,) column of the edge index selects the
        # two node coordinates forming one line segment.
        line_segments = nodes[edges.T]
        edge_collection = mc.LineCollection(
            line_segments,
            colors=edge_color(graph) if callable(edge_color) else edge_color,
            linewidths=1,
        )
        ax.add_collection(edge_collection)

        # Plot the nodes on the requested axes (not the pyplot "current"
        # axes, which may be a different one when `ax` is passed in).
        ax.scatter(
            *nodes.T,
            c=color(graph) if callable(color) else color,
            s=node_size(graph) if callable(node_size) else node_size,
            zorder=2,
        )
        return ax