Multi-Task Segmentation Models

About this demo

In image processing it may be desirable to perform multiple tasks simultaneously with the same model. This may not only have the advantage of decreasing processing time (in comparison to doing the tasks sequentially or separately), but it may also provide complementary information to improve the accuracy of each task. Thus, some multi-task models may actually provide better results than separate models performing the tasks independently. For example, within histology, certain nucleus types may be more prominent in specific tissue regions - e.g. epithelial cells within the epithelium vs. inflammatory cells within the connective tissue. Thus, a model that learns to semantically segment tissue regions at the same time as segmenting/classifying nuclear instances may provide superior results.

In this notebook, we demonstrate how to use HoVer-Net+, a subclass of HoVer-Net, for the semantic segmentation of intra-epithelial layers, whilst simultaneously segmenting/classifying nuclear instances (epithelial, inflammatory etc.). This model has been trained on a private cohort of oral epithelial dysplasia cases (not publicly available). For more information on the individual tasks (i.e. semantic segmentation and nuclear instance segmentation), please see their respective example notebooks. We will first show how this pretrained model, incorporated in TIAToolbox, can be used for multi-task inference on large image tiles, before showing how to use the same pretrained model in the TIAToolbox model inference pipeline to do prediction on WSIs.

Downloading the required files

We download, over the internet, the image files used in this notebook. In particular, we download a histology image tile, and a small whole slide image (WSI) of breast tissue, to show how the multi-task segmentation model works on tiles and on WSIs, respectively.

In Colab, if you click the files icon in the vertical toolbar on the left-hand side, you can see a list of the files that have been downloaded and are thus directly accessible from this notebook.

# File names under which the downloaded samples are saved
img_file_name = global_save_dir / "sample_tile.png"
wsi_file_name = global_save_dir / "sample_wsi.svs"

logger.info("Download has started. Please wait...")

# Downloading sample image tile
download_data(
    "https://tiatoolbox.dcs.warwick.ac.uk/sample_imgs/tcga_hnscc.png",
    img_file_name,
)

# Downloading sample whole-slide image
download_data(
    "https://tiatoolbox.dcs.warwick.ac.uk/testdata/models/samples/wsi2_4k_4k.svs",
    wsi_file_name,
)

logger.info("Download is complete.")
|2023-12-15|16:41:43.744| [INFO] Download has started. Please wait...
|2023-12-15|16:41:50.378| [INFO] Download is complete.

Multi-Task Segmentation using TIAToolbox pretrained models

In this section, we investigate the use of multi-task models that have already been trained on specific tasks and incorporated in the TIAToolbox. We will particularly focus on HoVer-Net+. In the future we plan to incorporate more models. HoVer-Net+ has an encoder-decoder framework, consisting of multiple decoder branches that allow it to perform multiple tasks simultaneously. It is therefore assumed that the representation of the input image learned by the encoder is useful for both downstream tasks. This model performs two tasks:

  1. It segments out nuclear instances from the given input, whilst assigning them to one of two classes: epithelial nuclei or other nuclei (connective/immune cell etc.).

  2. The model semantically segments tissue regions, classifying each individual pixel of an image tile or WSI as one of five tissue types:

  • (Superficial) keratin layer

  • Epithelial layer

  • Basal epithelial layer

  • Other (connective tissue etc.)

  • Background

Note, the first three tissue classes can be considered as the three layers of the epithelium.
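As a small illustration of how these class indices can be used, the sketch below builds a toy label map and derives a combined epithelium mask from the three epithelial layers. The index-to-name mapping follows the colour dictionary used later in this notebook; the toy array and the variable names are made up for illustration and are not part of TIAToolbox.

```python
import numpy as np

# Index-to-name mapping, matching the colour dictionary used later on.
LAYER_NAMES = {0: "Background", 1: "Other", 2: "Basal", 3: "Epithelial", 4: "Keratin"}
EPITHELIUM_LABELS = [2, 3, 4]  # basal, epithelial and keratin layers

# Toy 4x4 label map standing in for a real prediction (assumption).
toy_layer_map = np.array(
    [
        [0, 0, 1, 1],
        [4, 4, 1, 1],
        [3, 3, 3, 1],
        [2, 2, 2, 2],
    ],
    dtype=np.uint8,
)

# Pixels belonging to any of the three epithelial layers.
epithelium_mask = np.isin(toy_layer_map, EPITHELIUM_LABELS)

# Fraction of the image covered by each class.
class_fractions = {
    name: float(np.mean(toy_layer_map == idx)) for idx, name in LAYER_NAMES.items()
}
```

The same two operations can be applied unchanged to a real layer map once it has been loaded as a NumPy array.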

More information on the model and the dataset used for training can be found here (Shephard et al., “Simultaneous Nuclear Instance and Layer Segmentation in Oral Epithelial Dysplasia”) and the data is available for download using this link.

Inference on tiles

Similarly to the semantic segmentation functionality of the TIAToolbox, the multi-task segmentation module works on both image tiles and structured WSIs. First, we need to create an instance of the MultiTaskSegmentor class, which controls the whole process of multi-task segmentation, and then use it to do prediction on the input image(s):

# Tile prediction
multi_segmentor = MultiTaskSegmentor(
    pretrained_model="hovernetplus-oed",
    num_loader_workers=0,
    num_postproc_workers=0,
    batch_size=4,
)

tile_output = multi_segmentor.predict(
    [img_file_name],
    save_dir=global_save_dir / "sample_tile_results",
    mode="tile",
    on_gpu=ON_GPU,
    crash_on_exception=True,
)
|2023-12-15|16:42:22.484| [WARNING] WSIPatchDataset only reads image tile at `units="baseline"`. Resolutions will be converted to baseline value.
|2023-12-15|16:42:23.316| [WARNING] WSIPatchDataset only reads image tile at `units="baseline"`. Resolutions will be converted to baseline value.
|2023-12-15|16:42:23.478| [WARNING] Raw data is None.
|2023-12-15|16:42:23.480| [WARNING] Unknown scale (no objective_power or mpp)
Process Batch:   0%|                                     | 0/43 [00:00<?, ?it/s]|2023-12-15|16:42:23.644| [WARNING] Raw data is None.
Process Batch: 100%|############################| 43/43 [00:28<00:00,  1.51it/s]
Process Batch: 100%|##############################| 7/7 [00:04<00:00,  1.64it/s]
Process Batch: 100%|##############################| 7/7 [00:04<00:00,  1.67it/s]
Process Batch: 100%|##############################| 1/1 [00:00<00:00,  1.56it/s]
Process Batch: 100%|##############################| 7/7 [00:04<00:00,  1.66it/s]
Process Batch: 100%|##############################| 1/1 [00:00<00:00,  1.56it/s]
Process Batch: 100%|##############################| 7/7 [00:04<00:00,  1.66it/s]
Process Batch: 100%|##############################| 1/1 [00:00<00:00,  1.54it/s]
Process Batch: 100%|##############################| 3/3 [00:01<00:00,  2.11it/s]
|2023-12-15|16:43:24.906| [INFO] Finish: 0
|2023-12-15|16:43:24.910| [INFO] --Input: tmp/sample_tile.png
|2023-12-15|16:43:24.911| [INFO] --Output: /content/tmp/sample_tile_results/0

There we go! With only two lines of code, thousands of images can be processed automatically. There are many parameters associated with MultiTaskSegmentor. Please see the MultiTaskSegmentor notebook and documentation for more information on these parameters. Here we explain only the ones mentioned above:

  • pretrained_model: specifies the name of the pretrained model included in the TIAToolbox (case sensitive). We are expanding our library of models pretrained on various segmentation tasks. You can find a complete list of available pretrained models here. In this example, we use the "hovernetplus-oed" pretrained model, which is the HoVer-Net+ model trained on a private cohort of oral epithelial dysplasia cases.

  • num_loader_workers: as the name suggests, this parameter controls the number of CPU cores (workers) that are responsible for the “loading of network input” process, which consists of patch extraction, preprocessing, etc.

  • batch_size: controls the batch size, or the number of input instances to the network in each iteration. If you use a GPU, be careful not to set the batch_size larger than the GPU memory limit would allow.
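To make the roles of patch loading and batch_size more concrete, here is a minimal sketch of how an image can be cut into a grid of patches that are then grouped into batches. It is not the TIAToolbox implementation: the helper names, the patch size and the stride are hypothetical values chosen for illustration, and edge padding is ignored.

```python
import math


def patch_origins(image_size, patch_size, stride):
    """Top-left (x, y) corners of patches laid out on a regular grid."""
    width, height = image_size
    xs = range(0, max(width - patch_size, 0) + 1, stride)
    ys = range(0, max(height - patch_size, 0) + 1, stride)
    return [(x, y) for y in ys for x in xs]


def batches(items, batch_size):
    """Split a list into consecutive chunks of at most `batch_size` items."""
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


# Hypothetical image, patch and stride sizes (not the model's real settings).
origins = patch_origins(image_size=(1000, 1000), patch_size=256, stride=164)
patch_batches = batches(origins, batch_size=4)

# The number of batches is the patch count divided by batch_size, rounded up.
expected_batches = math.ceil(len(origins) / 4)
```

A larger batch_size means fewer iterations, but each iteration needs more memory, which is why it has to be chosen with the available GPU memory in mind.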

After the multi_segmentor has been instantiated as a multi-task segmentation engine with our desired pretrained model, one can call the predict method to do inference on a list of input images (or WSIs). The predict function automatically processes all the images in the input list and saves the results to disk. The process usually comprises patch extraction (because the whole tile or WSI won’t fit into limited GPU memory), preprocessing, model inference, post-processing and prediction assembly. Here are some important parameters that should be set to use the predict method properly:

  • imgs: List of inputs to be processed. Note that items in the list should be paths to the inputs stored on the disk.

  • save_dir: Path to the main folder in which prediction results for each input will be stored separately.

  • mode: the mode of inference, which can be set to either 'tile' or 'wsi', for plain histology images or structured whole slide images, respectively.

  • on_gpu: can be either True or False to dictate running the computations on GPU or CPU.

  • crash_on_exception: If set to True, the running loop will crash if there is an error during processing a WSI. Otherwise, the loop will move on to the next image (wsi) for processing. We suggest that you first make sure that prediction is working as expected by testing it on a couple of inputs and then set this flag to False to process large cohorts of inputs.

In the output, the prediction method returns a list of the paths to its inputs and to the processed outputs saved on the disk. This can be used for loading the results for processing and visualisation.
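In this notebook, the results are loaded by appending the suffixes ".0.dat" (nuclear instances) and ".1.npy" (layer map) to the output path of each entry. The sketch below wraps that convention in a small helper; the function name result_files and the example entry are hypothetical and only serve to illustrate the pattern.

```python
from pathlib import Path


def result_files(output_entry):
    """Return the instance and layer-map file paths for one (input, output) entry."""
    _input_path, output_stem = output_entry
    inst_path = Path(f"{output_stem}.0.dat")  # nuclear instance dictionary
    layer_path = Path(f"{output_stem}.1.npy")  # semantic layer map
    return inst_path, layer_path


# Example entry shaped like one element of the returned list (illustrative values).
example_entry = ("tmp/sample_tile.png", "tmp/sample_tile_results/0")
inst_path, layer_path = result_files(example_entry)
```

Looping over the whole returned list with such a helper is one way to collect the result files of a larger cohort.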

Now that the prediction has finished, let’s use the paths in tile_output to load and examine the predictions.

inst_dict = joblib.load(f"{tile_output[0][1]}.0.dat")
layer_map = np.load(f"{tile_output[0][1]}.1.npy")
logger.info("Number of detected nuclei: %d.", len(inst_dict))
logger.info(
    "Processed prediction dimensions: (%d, %d)",
    layer_map.shape[0],
    layer_map.shape[1],
)
logger.info(
    "Prediction method output is: %s, %s,",
    tile_output[0][0],
    tile_output[0][1],
)

# showing the predicted semantic segmentation
tile = imread(img_file_name)
logger.info(
    "Input image dimensions: (%d, %d, %d)",
    tile.shape[0],
    tile.shape[1],
    tile.shape[2],
)

semantic_color_dict = {
    0: ("Background", (0, 0, 0)),
    1: ("Other", (255, 165, 0)),
    2: ("Basal", (255, 0, 0)),
    3: ("Epithelial", (0, 255, 0)),
    4: ("Keratin", (0, 0, 255)),
}

inst_color_dict = {
    0: ("Background", (0, 0, 0)),
    1: ("Other", (255, 165, 0)),
    2: ("Epithelium", (255, 0, 0)),
}

# Create the overlay image
overlaid_predictions_inst = overlay_prediction_contours(
    canvas=tile,
    inst_dict=inst_dict,
    draw_dot=False,
    type_colours=inst_color_dict,
    line_thickness=2,
)

# Create the semantic segmentation image (in colour)
semantic_map = np.zeros((layer_map.shape[0], layer_map.shape[1], 3)).astype("uint8")
for idx, (_label, color) in semantic_color_dict.items():
    semantic_map[layer_map == idx] = color

# showing processed results alongside the original images
fig2 = plt.figure()
ax1 = plt.subplot(1, 3, 1), plt.imshow(tile), plt.axis("off"), plt.title("Tile")
ax2 = (
    plt.subplot(1, 3, 2),
    plt.imshow(semantic_map),
    plt.axis("off"),
    plt.title("Semantic segm."),
)
ax3 = (
    plt.subplot(1, 3, 3),
    plt.imshow(overlaid_predictions_inst),
    plt.axis("off"),
    plt.title("Instance segm."),
)
|2023-12-15|16:43:25.698| [INFO] Number of detected nuclei: 3351.
|2023-12-15|16:43:25.702| [INFO] Processed prediction dimensions: (2000, 2000)
|2023-12-15|16:43:25.705| [INFO] Prediction method output is: tmp/sample_tile.png, /content/tmp/sample_tile_results/0,
|2023-12-15|16:43:25.825| [INFO] Input image dimensions: (2000, 2000, 3)
[Figure: three panels titled "Tile", "Semantic segm." and "Instance segm."]

Above, we display the raw image tile, along with the semantic segmentation and nuclear instance segmentation predictions. In the instance prediction map, the contours of the nuclear instances are drawn. Red contours represent epithelial nuclei, whilst orange contours represent other nuclei. In the semantic segmentation prediction map, blue represents the keratin layer, green the epithelial layer, red the basal epithelial layer, orange is other tissue, and finally black is background.
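Beyond visualisation, the instance dictionary can be summarised numerically, for example by counting nuclei per class. The sketch below uses a toy dictionary and assumes that each entry carries a "type" field holding the class index (an assumption about the stored layout, to be checked against the loaded data); the class names follow inst_color_dict above.

```python
from collections import Counter

# Class names following `inst_color_dict` above.
TYPE_NAMES = {0: "Background", 1: "Other", 2: "Epithelium"}

# Toy stand-in for the loaded instance dictionary (assumed layout:
# one entry per nucleus, each holding at least a "type" field).
toy_inst_dict = {
    1: {"centroid": (10.0, 12.0), "type": 2},
    2: {"centroid": (40.5, 18.0), "type": 1},
    3: {"centroid": (22.0, 30.5), "type": 2},
}


def count_nuclei_by_type(inst_dict, type_names):
    """Count nuclei per class name, labelling unexpected indices as 'Unknown'."""
    counts = Counter(
        type_names.get(inst["type"], "Unknown") for inst in inst_dict.values()
    )
    return dict(counts)


nuclei_counts = count_nuclei_by_type(toy_inst_dict, TYPE_NAMES)
```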

Inference on WSIs

The next step is to use TIAToolbox’s embedded model for region segmentation in a whole slide image. The process is quite similar to what we have done for tiles. Here we introduce some important parameters that should be considered when configuring the segmentor for WSI inference. For this example we run HoVer-Net+ on a small breast tissue WSI.

multi_segmentor = MultiTaskSegmentor(
    pretrained_model="hovernetplus-oed",
    num_loader_workers=0,
    num_postproc_workers=0,
    batch_size=4,
    auto_generate_mask=False,
)

# WSI prediction
wsi_output = multi_segmentor.predict(
    imgs=[wsi_file_name],
    masks=None,
    save_dir=global_save_dir / "sample_wsi_results/",
    mode="wsi",
    on_gpu=ON_GPU,
    crash_on_exception=True,
)
Process Batch: 100%|############################| 43/43 [00:30<00:00,  1.39it/s]
Process Batch: 100%|##############################| 7/7 [00:04<00:00,  1.49it/s]
Process Batch: 100%|##############################| 7/7 [00:04<00:00,  1.53it/s]
Process Batch: 100%|##############################| 1/1 [00:00<00:00,  1.46it/s]
Process Batch: 100%|##############################| 7/7 [00:04<00:00,  1.51it/s]
Process Batch: 100%|##############################| 1/1 [00:00<00:00,  1.44it/s]
Process Batch: 100%|##############################| 7/7 [00:04<00:00,  1.54it/s]
Process Batch: 100%|##############################| 1/1 [00:00<00:00,  1.45it/s]
Process Batch: 100%|##############################| 3/3 [00:01<00:00,  1.95it/s]
|2023-12-15|16:45:26.940| [INFO] Finish: 0
|2023-12-15|16:45:26.941| [INFO] --Input: tmp/sample_wsi.svs
|2023-12-15|16:45:26.943| [INFO] --Output: /content/tmp/sample_wsi_results/0

Note the only differences made here are:

  1. Adding auto_generate_mask=False to the MultiTaskSegmentor. If True and if no masks input is provided to the predict function, the toolbox automatically extracts tissue masks from WSIs.

  2. Setting mode='wsi' in the predict function indicates that we are predicting region segmentations for inputs in the form of WSIs.

  3. masks=None in the predict function: the masks argument is a list of paths to the desired image masks. Patches from imgs are only processed if they are within a masked area of their corresponding masks. If not provided (masks=None), then either a tissue mask is automatically generated for whole-slide images or the entire image is processed as a collection of image tiles.
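The effect of a mask on patch selection can be pictured with the following sketch, which keeps only those patches that overlap a set region of a binary mask. This is a simplified illustration, not the toolbox's implementation: the helper names are hypothetical, and the "any overlap" rule is an assumption made for the example.

```python
def patch_overlaps_mask(mask, x, y, size):
    """True if any mask pixel inside the patch at (x, y) is set."""
    rows = mask[y : y + size]
    return any(any(row[x : x + size]) for row in rows)


def select_patches(mask, origins, size):
    """Keep only the patch origins whose patch overlaps the mask."""
    return [(x, y) for x, y in origins if patch_overlaps_mask(mask, x, y, size)]


# Toy 4x4 binary mask (1 = tissue, 0 = background), indexed as mask[y][x].
toy_mask = [
    [1, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 1],
]

# Four non-overlapping 2x2 patches covering the mask.
origins = [(0, 0), (2, 0), (0, 2), (2, 2)]
kept = select_patches(toy_mask, origins, size=2)
```

Skipping background patches in this way is what makes masked WSI inference faster than processing the entire slide.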

The above cell might take a while to process, especially if you have set ON_GPU=False. The processing time depends on the size of the input WSI and the selected resolution. Here, we have not specified any values and we use the assumed input resolution (20x) of HoVer-Net+.

logger.info(
    "Prediction method output is: %s, %s,",
    wsi_output[0][0],
    wsi_output[0][1],
)
inst_dict = joblib.load(f"{wsi_output[0][1]}.0.dat")
layer_map = np.load(f"{wsi_output[0][1]}.1.npy")
logger.info(
    "Processed prediction dimensions: (%d, %d)",
    layer_map.shape[0],
    layer_map.shape[1],
)


# [WSI overview extraction]
# Now reading the WSI to extract its overview
wsi = WSIReader.open(wsi_file_name)
logger.info(
    "WSI original dimensions: (%d, %d)",
    wsi.info.slide_dimensions[0],
    wsi.info.slide_dimensions[1],
)

# Reading the whole slide at 0.5 mpp (the resolution of the prediction) as a plain image
wsi_overview = wsi.slide_thumbnail(resolution=0.5, units="mpp")
logger.info(
    "WSI overview dimensions: (%d, %d, %d)",
    wsi_overview.shape[0],
    wsi_overview.shape[1],
    wsi_overview.shape[2],
)

semantic_color_dict = {
    0: ("Background", (0, 0, 0)),
    1: ("Other", (255, 165, 0)),
    2: ("Basal", (255, 0, 0)),
    3: ("Epithelial", (0, 255, 0)),
    4: ("Keratin", (0, 0, 255)),
}

inst_color_dict = {
    0: ("Background", (0, 0, 0)),
    1: ("Other", (255, 165, 0)),
    2: ("Epithelium", (255, 0, 0)),
}

# Create the instance segmentation overlay map
# using the `overlay_prediction_contours` helper function
overlaid_inst_pred = overlay_prediction_contours(
    canvas=wsi_overview,
    inst_dict=inst_dict,
    draw_dot=False,
    type_colours=inst_color_dict,
    line_thickness=4,
)


fig = (
    plt.figure(),
    plt.imshow(wsi_overview),
    plt.axis("off"),
    plt.title("Large Visual Field"),
)
fig = (
    plt.figure(),
    plt.imshow(overlaid_inst_pred),
    plt.axis("off"),
    plt.title("Instance Segmentation Overlaid"),
)
# Create semantic segmentation overlay map
# using the `overlay_prediction_mask` helper function
fig = plt.figure()
overlaid_semantic_pred = overlay_prediction_mask(
    wsi_overview,
    layer_map,
    alpha=0.5,
    label_info=semantic_color_dict,
    return_ax=True,
)
plt.title("Semantic Segmentation Overlaid")
|2023-12-15|16:48:15.100| [INFO] Prediction method output is: tmp/sample_wsi.svs, /content/tmp/sample_wsi_results/0,
|2023-12-15|16:48:15.761| [INFO] Processed prediction dimensions: (2016, 2016)
|2023-12-15|16:48:15.780| [INFO] WSI original dimensions: (4000, 4000)
|2023-12-15|16:48:17.773| [INFO] WSI overview dimensions: (2016, 2016, 3)
[Figures: "Large Visual Field", "Instance Segmentation Overlaid" and "Semantic Segmentation Overlaid"]

As you can see above, our method first creates the semantic and instance segmentation maps (and corresponding dictionaries). Then, in order to visualise the segmentation prediction on the tissue image, we read the processed WSI and extract its overview. Please note that HoVer-Net+ assumes a base resolution of 0.50 mpp, whilst the baseline resolution of the input WSI is 0.252 mpp. Thus, the output of HoVer-Net+ has shape 2016 x 2016, compared to the 4000 x 4000 input, and the overview image is extracted at the same resolution (0.50 mpp). No such rescaling was needed for the tile example, because tile processing assumes that the input is already at the correct resolution at baseline.
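The output shape follows directly from the ratio of the two resolutions. The short sketch below reproduces that arithmetic using the values quoted above; the helper name rescaled_size is hypothetical.

```python
def rescaled_size(size_px, source_mpp, target_mpp):
    """Pixel size of an image after resampling from source_mpp to target_mpp."""
    scale = source_mpp / target_mpp
    return tuple(round(dim * scale) for dim in size_px)


baseline_size = (4000, 4000)  # WSI dimensions at baseline
baseline_mpp = 0.252  # baseline resolution of the WSI
model_mpp = 0.50  # resolution assumed by HoVer-Net+

# 4000 * 0.252 / 0.50 = 2016 pixels along each side.
output_size = rescaled_size(baseline_size, baseline_mpp, model_mpp)
```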

We used the overlay_prediction_mask helper function of the TIAToolbox to overlay the predicted semantic segmentation map on the overview image and depict it with a colour legend. We also used the overlay_prediction_contours helper function of the TIAToolbox to overlay the predicted instance segmentation map on the overview image.
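Conceptually, overlaying a semantic map with a given alpha amounts to blending each image pixel with the colour of its predicted class. The sketch below shows this for a single pixel; it is a generic alpha-blending illustration with a made-up pixel value, not the code used inside overlay_prediction_mask.

```python
def blend_pixel(image_rgb, label_rgb, alpha=0.5):
    """Alpha-blend a label colour over an image pixel (both RGB, 0-255)."""
    return tuple(
        round((1 - alpha) * img + alpha * lab)
        for img, lab in zip(image_rgb, label_rgb)
    )


tissue_pixel = (200, 150, 181)  # made-up stain colour
keratin_colour = (0, 0, 255)  # colour assigned to the keratin class above

blended = blend_pixel(tissue_pixel, keratin_colour, alpha=0.5)
```

With alpha=0 the original pixel is returned unchanged, and with alpha=1 the pixel is replaced entirely by the class colour.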

In the instance prediction map, the contours of the nuclear instances are drawn. Red contours represent epithelial nuclei, whilst orange contours represent other nuclei. In the semantic segmentation prediction map, blue represents the keratin layer, green the epithelial layer, red the basal epithelial layer, orange is other tissue, and finally black is background. Please note that the semantic segmentation output seen above is noisy in places. This is because the model was trained on head and neck tissue only (specifically oral epithelial dysplasia), whereas the WSI used here is of breast tissue, to which the model does not generalise well. We have chosen to include this WSI, despite the results not being perfect, for demonstration purposes.

In summary, it is very easy to use pretrained models in the TIAToolbox to do predefined tasks. In fact, you don’t even need to set any parameters related to a model’s input/output when you decide to work with one of TIAToolbox’s pretrained models (they will be set automatically, based on their optimal values).

Feel free to play around with the parameters and models, and to experiment with new images. Just remember to run the first cell of this notebook again, so that the folders created for the current examples are removed; alternatively, change the save_dir parameter in new calls of the predict function. Currently, we are extending our collection of pretrained models. To keep track of them, make sure to follow our releases. You can also check here. Furthermore, if you want to use your own pretrained model for semantic segmentation (or any other pixel-wise prediction model) in the TIAToolbox framework, you can follow the instructions in our example notebook on advanced model techniques to gain some insights and guidance. We welcome any trained model in computational pathology (in any task) for addition to TIAToolbox. If you have such a model (in PyTorch) and want to contribute, please contact us or simply create a PR on our GitHub page.