Using Foundation Models in TIAToolbox¶
About this demo¶
In this example, we demonstrate how to extract features from a pre-trained PyTorch foundation model using the PyTorch Image Models (timm) library. The model itself is defined outside of TIAToolbox, but we will use the WSI inference engines provided by TIAToolbox to run it over a whole slide image.
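If you are running only this section of the notebook, the imports below cover the objects used in the following cells. This is a minimal sketch assuming TIAToolbox >= 1.6; module paths may differ slightly between versions, and in the full notebook these imports appear in an earlier setup cell.
# Sketch: imports for the cells below (assumed module paths for TIAToolbox >= 1.6).
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import umap
from huggingface_hub import notebook_login

from tiatoolbox import logger
from tiatoolbox.data import download_data
from tiatoolbox.models.architecture.vanilla import TimmBackbone
from tiatoolbox.models.engine.semantic_segmentor import (
    DeepFeatureExtractor,
    IOSegmentorConfig,
)
from tiatoolbox.wsicore.wsireader import WSIReader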
GPU or CPU runtime¶
Processes in this notebook can be accelerated by using a GPU. Whether you are running this notebook on your own system or on Colab, you should check and set the appropriate device, e.g., "cuda" if you are using a GPU or "cpu" otherwise. In Colab, make sure the runtime type is set to GPU under "Runtime → Change runtime type → Hardware accelerator". If you are not using a GPU, set the device flag to "cpu"; otherwise, errors will be raised when running the following cells.
device = "cuda" # Choose appropriate device
![ -d tmp ] && ( echo "deleting tmp directory"; rm -rf tmp )
Downloading the required files¶
We download, over the internet, a histology whole slide image of cancerous breast tissue to show how the feature extractor works. The download is needed once in each Colab session.
In Colab, if you click the files icon in the vertical toolbar on the left-hand side, you can see all the files that the code in this notebook can access. The downloaded data will appear there.
global_save_dir = Path("tmp/")
# File name of WSI
wsi_path = global_save_dir / "sample_wsi.svs"
logger.info("Download has started. Please wait...")
# Download a sample whole-slide image
download_data(
    "https://tiatoolbox.dcs.warwick.ac.uk/sample_wsis/TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F.svs",
    wsi_path,
)
logger.info("Download is complete.")
|2024-12-02|14:52:58.760| [INFO] Download has started. Please wait...
|2024-12-02|14:52:58.763| [INFO] Download is complete.
Feature extraction with foundation models¶
In this section, we extract deep features using foundation models. These features could be used to train a downstream model. The models require access to HuggingFace, so please ensure that you have a university-linked account if you are using them for research.
We then sign in to HuggingFace.
notebook_login()
Token has not been saved to git credential helper.
Next, we create the model using pre-trained network architectures. For models available in the timm library, including computational pathology-specific foundation models (e.g., EfficientNet, UNI, Prov-GigaPath, H-optimus-0), use TimmBackbone. For standard CNN architectures available in PyTorch (e.g., AlexNet, ResNet, DenseNet, Inception), use CNNBackbone.
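For example, an ImageNet-pretrained CNN feature extractor could be created as follows (a minimal sketch; the "resnet50" backbone name is an illustrative assumption, and any architecture supported by CNNBackbone can be used):
# Sketch: a standard ImageNet-pretrained CNN feature extractor via CNNBackbone.
# "resnet50" is an illustrative choice of backbone.
from tiatoolbox.models.architecture.vanilla import CNNBackbone

cnn_model = CNNBackbone(backbone="resnet50")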
In the example below, we use the UNI model. However, this can be changed to other computational pathology-specific foundation models by modifying the backbone argument, for example to prov-gigapath or H-optimus-0. When using foundation models, please ensure that you cite the corresponding paper and follow the model-specific access requirements. Certain models, such as UNI and Prov-GigaPath, require users to link their GitHub and HuggingFace accounts and have their model access request accepted, subject to certain conditions. Other models, such as H-optimus-0, have no such requirements.
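Assuming your HuggingFace account has been granted access to the corresponding gated repositories, switching models is just a matter of changing the backbone string; a sketch using the backbone names mentioned above:
# Sketch: swapping the foundation model backbone. Each gated model requires that
# its access terms on HuggingFace have been accepted beforehand.
model_gigapath = TimmBackbone(backbone="prov-gigapath", pretrained=True)
model_hoptimus = TimmBackbone(backbone="H-optimus-0", pretrained=True)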
We also provide an IOSegmentorConfig specifying the input/output patch shape and resolution for processing and for saving the output.
Finally, we use the DeepFeatureExtractor to extract these deep features, per patch, from the WSI. A mask is automatically generated to guide the patch extraction process and ignore the background.
model = TimmBackbone(backbone="UNI", pretrained=True)
wsi_ioconfig = IOSegmentorConfig(
    input_resolutions=[{"units": "mpp", "resolution": 0.5}],
    patch_input_shape=[224, 224],
    output_resolutions=[{"units": "mpp", "resolution": 0.5}],
    patch_output_shape=[224, 224],
    stride_shape=[224, 224],
)
# create the feature extractor and run it on the WSI
extractor = DeepFeatureExtractor(
    model=model,
    auto_generate_mask=True,
    batch_size=32,
    num_loader_workers=4,
    num_postproc_workers=4,
)
out = extractor.predict(
    imgs=[wsi_path],
    mode="wsi",
    ioconfig=wsi_ioconfig,
    save_dir=global_save_dir / "wsi_features",
    device=device,
)
|2024-12-02|14:53:32.136| [INFO] Loading pretrained weights from Hugging Face hub (MahmoodLab/UNI)
|2024-12-02|14:53:32.984| [WARNING] GPU is not compatible with torch.compile. Compatible GPUs include NVIDIA V100, A100, and H100. Speedup numbers may be lower than expected.
|2024-12-02|14:53:33.419| [WARNING] Read: Scale > 1.This means that the desired resolution is higher than the WSI baseline (maximum encoded resolution). Interpolation of read regions may occur.
Process Batch: 100%|##########################| 630/630 [05:18<00:00, 1.98it/s]
|2024-12-02|14:58:52.845| [INFO] Finish: 0
|2024-12-02|14:58:52.846| [INFO] --Input: tmp/sample_wsi.svs
|2024-12-02|14:58:52.847| [INFO] --Output: /newdata/u1973415/TIAToolbox/tiatoolbox/examples/tmp/wsi_features/0
Post-processing and Visualization¶
These deep features could be used to train a downstream model. However, in this section, we will use a UMAP reduction to visualize the features in RGB space to gain some intuition about what the features represent. Points labeled with similar colors should have similar features, allowing us to check whether the features naturally separate into different tissue regions when we overlay the UMAP reduction on the WSI thumbnail.
The method returns a list of paths to its inputs and the processed outputs saved on disk. These paths can be used to load the results for further processing and visualization.
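For instance, the saved arrays can be located from the returned pairs rather than by hard-coding paths. The sketch below assumes the return format of TIAToolbox's segmentation-style engines, where each entry pairs an input path with an output path prefix:
# Sketch: locate the saved outputs from the engine's return value (format assumed:
# each entry pairs the input WSI path with an output path prefix).
wsi_input, output_prefix = out[0]
positions_path = Path(f"{output_prefix}.position.npy")
features_path = Path(f"{output_prefix}.features.0.npy")
print(wsi_input, positions_path, features_path)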
# First we define a function to calculate the UMAP reduction
def umap_reducer(x: np.ndarray, dims: int = 3, nns: int = 10) -> np.ndarray:
    """UMAP reduction of the input data."""
    reducer = umap.UMAP(
        n_neighbors=nns,
        n_components=dims,
        metric="manhattan",
        spread=0.5,
        random_state=2,
    )
    reduced = reducer.fit_transform(x)
    reduced -= reduced.min(axis=0)
    reduced /= reduced.max(axis=0)
    return reduced
# load the features output by our feature extractor
pos = np.load(global_save_dir / "wsi_features" / "0.position.npy")
feats = np.load(global_save_dir / "wsi_features" / "0.features.0.npy")
pos = pos / 8  # patches were extracted at 0.5 mpp; the thumbnail below is at 4 mpp (8x lower resolution)
# reduce the features into 3-dimensional (RGB) space
reduced = umap_reducer(feats)
# The resolution at which we want to load the thumbnail and overlay the patch features
overview_resolution = 4
# The unit of the `resolution` parameter. Can be "power", "level", "mpp", or "baseline"
overview_unit = "mpp"
wsi = WSIReader.open(wsi_path)
wsi_overview = wsi.slide_thumbnail(resolution=overview_resolution, units=overview_unit)
plt.figure(), plt.imshow(wsi_overview)
plt.axis("off")
# plot the feature map reduction
plt.figure()
plt.imshow(wsi_overview)
plt.scatter(pos[:, 0], pos[:, 1], c=reduced, s=1, alpha=0.5)
plt.axis("off")
plt.title("UMAP feature embedding")
Text(0.5, 1.0, 'UMAP feature embedding')


We observe that the feature map from our feature encoder captures information about the tissue types in the WSI, as different tissue types appear in distinct colors. This serves as a good sanity check, confirming that our model is functioning as expected.