PatchPredictor

class PatchPredictor(batch_size=8, num_loader_workers=0, model=None, pretrained_model=None, pretrained_weights=None, verbose=True)[source]

Patch-level predictor.

Parameters
  • model (nn.Module) – Externally defined PyTorch model to use for prediction, with weights already loaded. Default is None. If provided, the pretrained_model argument is ignored.

  • pretrained_model (str) – Name of an existing model supported by tiatoolbox for processing the data. For a full list of pretrained models, refer to the docs. By default, the corresponding pretrained weights are also downloaded. However, you can override them with your own set of weights via the pretrained_weights argument. The argument is case-insensitive.

  • pretrained_weights (str) –

    Path to the weights of the corresponding pretrained_model.

    >>> predictor = PatchPredictor(
    ...    pretrained_model="resnet18-kather100k",
    ...    pretrained_weights="resnet18_local_weight")
    

  • batch_size (int) – Number of images fed into the model each time.

  • num_loader_workers (int) – Number of workers used to load the data. Note that these workers also perform preprocessing.

  • verbose (bool) – Whether to output logging information.
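The precedence between model and pretrained_model described above can be summarised in a few lines. This is a sketch only: resolve_model_choice is a hypothetical helper, not part of tiatoolbox, and raising an error when neither argument is given is an assumption.

```python
from typing import Optional


def resolve_model_choice(model: Optional[object], pretrained_model: Optional[str]) -> str:
    """Hypothetical helper illustrating the documented argument precedence."""
    if model is not None:
        # An externally defined model wins; pretrained_model is ignored.
        return "external"
    if pretrained_model is None:
        # Assumption: at least one of the two arguments must be supplied.
        raise ValueError("Provide either `model` or `pretrained_model`.")
    # Pretrained model names are matched case-insensitively.
    return pretrained_model.lower()
```

For example, resolve_model_choice(None, "ResNet18-Kather100K") yields "resnet18-kather100k", the same name used in the examples below.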

img

A HWC image or a path to WSI.

Type

str or pathlib.Path or numpy.ndarray

mode

Type of input to process. Choose from either patch, tile or wsi.

Type

str
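A minimal sketch of the input rules for each mode, following the description of the imgs argument of predict: patch mode accepts images, image paths or a NumPy array, while tile and wsi modes require a list of file paths. check_predict_inputs and VALID_MODES are hypothetical names, not tiatoolbox API.

```python
from pathlib import Path

# Hypothetical constant listing the documented modes.
VALID_MODES = ("patch", "tile", "wsi")


def check_predict_inputs(imgs, mode: str) -> None:
    """Raise ValueError if `imgs` does not suit `mode` (illustrative only)."""
    if mode not in VALID_MODES:
        raise ValueError(f"Unknown mode {mode!r}; choose from {VALID_MODES}.")
    if mode == "patch":
        # Patch mode accepts images, image paths or a NumPy array of images.
        return
    # Tile and WSI modes expect a list of file paths.
    is_path_list = isinstance(imgs, list) and all(
        isinstance(item, (str, Path)) for item in imgs
    )
    if not is_path_list:
        raise ValueError(f"mode={mode!r} expects a list of file paths.")
```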

model

The PyTorch model used for prediction.

Type

nn.Module

pretrained_model

Name of an existing model supported by tiatoolbox for processing the data. For a full list of pretrained models, refer to the docs. By default, the corresponding pretrained weights are also downloaded. However, you can override them with your own set of weights via the pretrained_weights argument. The argument is case-insensitive.

Type

str

batch_size

Number of images fed into the model each time.

Type

int

num_loader_workers

Number of workers used in torch.utils.data.DataLoader.

Type

int

verbose

Whether to output logging information.

Type

bool

Examples

>>> import numpy as np
>>> from tiatoolbox.models.engine.patch_predictor import PatchPredictor
>>> # list of 2 image patches (img1 and img2 are HWC numpy arrays) as input
>>> data = [img1, img2]
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(data, mode='patch')
>>> # numpy array of 2 image patches as input
>>> data = np.array([img1, img2])
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(data, mode='patch')
>>> # list of 2 image patch files as input
>>> data = ['path/img1.png', 'path/img2.png']
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(data, mode='patch')
>>> # list of 2 image tile files as input
>>> tile_file = ['path/tile1.png', 'path/tile2.png']
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(tile_file, mode='tile')
>>> # list of 2 wsi files as input
>>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs']
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(wsi_file, mode='wsi')

Methods

merge_predictions

Merge patch-level predictions to form a 2-dimensional prediction map.

predict

Make a prediction for a list of input data.

static merge_predictions(img, output, resolution=None, units=None, postproc_func=None, return_raw=False)[source]

Merge patch-level predictions to form a 2-dimensional prediction map.

The prediction map contains integer values from 0 to N, where N is the number of classes. A value of 0 marks background that was not processed by the model; every processed pixel holds the predicted class index plus one, so values 1 to N correspond to the classes predicted by the model.
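The label encoding described above can be illustrated with a small sketch. encode_label_map is a hypothetical helper, not part of tiatoolbox; it assumes you already have a per-pixel class index map and a boolean mask of processed pixels.

```python
import numpy as np


def encode_label_map(class_indices, processed_mask) -> np.ndarray:
    """Hypothetical helper: shift class indices by one so 0 stays background."""
    class_indices = np.asarray(class_indices)
    processed_mask = np.asarray(processed_mask, dtype=bool)
    # Processed pixels get (class index + 1); unprocessed pixels stay 0.
    return np.where(processed_mask, class_indices + 1, 0)
```

For instance, class indices [[1, 0], [0, 0]] with only the top row processed encode to [[2, 1], [0, 0]].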

Parameters
  • img (str or pathlib.Path or numpy.ndarray) – A HWC image or a path to WSI.

  • output (dict) – Output generated by the model.

  • resolution (float) – Resolution of merged predictions.

  • units (str) – Units of resolution used when merging predictions. This must be the same units used when processing the data.

  • postproc_func (callable) – A function to post-process the raw predictions from the model. By default, np.argmax is used.

  • return_raw (bool) – Return raw result without applying the postproc_func on the assembled image.
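One way to picture how the parameters above interact is a pure NumPy sketch of merging per-patch probabilities. It is not the tiatoolbox implementation: merge_patch_probabilities is a hypothetical name, coordinates are assumed to be [x_start, y_start, x_end, y_end] at the same resolution as the output shape, and averaging probabilities where patches overlap is an assumption.

```python
from typing import Callable, Optional, Sequence, Tuple

import numpy as np


def _argmax_last_axis(prob_map: np.ndarray) -> np.ndarray:
    """Default post-processing: pick the most probable class per pixel."""
    return np.argmax(prob_map, axis=-1)


def merge_patch_probabilities(
    shape: Tuple[int, int],
    probabilities: Sequence[Sequence[float]],
    coordinates: Sequence[Sequence[int]],
    postproc_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    return_raw: bool = False,
) -> np.ndarray:
    """Hypothetical sketch of merging patch outputs into a 2D map."""
    height, width = shape
    num_classes = len(probabilities[0])
    prob_sum = np.zeros((height, width, num_classes), dtype=np.float64)
    hit_count = np.zeros((height, width, 1), dtype=np.float64)
    for probs, (x0, y0, x1, y1) in zip(probabilities, coordinates):
        # Paint each patch's class probabilities over its bounding box.
        prob_sum[y0:y1, x0:x1] += np.asarray(probs, dtype=np.float64)
        hit_count[y0:y1, x0:x1] += 1
    # Average where patches overlap; untouched pixels stay all-zero.
    prob_map = np.divide(
        prob_sum, hit_count, out=np.zeros_like(prob_sum), where=hit_count > 0
    )
    if return_raw:
        # Skip post-processing and return the assembled probability map.
        return prob_map
    postproc = postproc_func if postproc_func is not None else _argmax_last_axis
    labels = postproc(prob_map)
    # Shift by one so that 0 is reserved for unprocessed background.
    return np.where(hit_count[..., 0] > 0, labels + 1, 0)
```

Fed the pseudo output dict from the example in this section (shape (4, 4), two 2x2 patches), this sketch produces the same merged map shown there.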

Returns

Merged predictions as a 2D array.

Return type

prediction_map (ndarray)

Examples

>>> import numpy as np
>>> from tiatoolbox.models.engine.patch_predictor import PatchPredictor
>>> # pseudo output dict from model with 2 patches
>>> output = {
...     'resolution': 1.0,
...     'units': 'baseline',
...     'probabilities': [[0.45, 0.55], [0.90, 0.10]],
...     'predictions': [1, 0],
...     'coordinates': [[0, 0, 2, 2], [2, 2, 4, 4]],
... }
>>> merged = PatchPredictor.merge_predictions(
...         np.zeros([4, 4]),
...         output,
...         resolution=1.0,
...         units='baseline'
... )
>>> merged
array([[2, 2, 0, 0],
       [2, 2, 0, 0],
       [0, 0, 1, 1],
       [0, 0, 1, 1]])

predict(imgs, masks=None, labels=None, mode='patch', return_probabilities=False, return_labels=False, on_gpu=True, ioconfig=None, patch_input_shape=None, stride_shape=None, resolution=None, units=None, merge_predictions=False, save_dir=None, save_output=False)[source]

Make a prediction for a list of input data.

Parameters
  • imgs (list, ndarray) – List of inputs to process. In patch mode, the input must be a list of images, a list of image file paths, or a NumPy array of images. In tile or wsi mode, the input must be a list of file paths.

  • masks (list) – List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if they are within a masked area. If not provided, then a tissue mask will be automatically generated for whole-slide images or the entire image is processed for image tiles.

  • labels – List of labels. If using tile or wsi mode, then only a single label per image tile or whole-slide image is supported.

  • mode (str) – Type of input to process. Choose from either patch, tile or wsi.

  • return_probabilities (bool) – Whether to return per-class probabilities.

  • return_labels (bool) – Whether to return the labels with the predictions.

  • on_gpu (bool) – Whether to run the model on the GPU.

  • patch_input_shape (tuple) – Size of patches input to the model. Patches are at requested read resolution, not with respect to level 0, and must be positive.

  • stride_shape (tuple) – Stride used during tile and WSI processing. Stride is at the requested read resolution, not with respect to level 0, and must be positive. If not provided, stride_shape=patch_input_shape.

  • resolution (float) – Resolution used for reading the image. Please see WSIReader for details.

  • units (str) – Units of resolution used for reading the image. Choose from either level, power or mpp. Please see WSIReader for details.

  • merge_predictions (bool) – Whether to merge the predictions to form a 2-dimensional map. This is only applicable when mode='wsi' or mode='tile'.

  • save_dir (str or pathlib.Path) – Output directory when processing multiple tiles and whole-slide images. By default, this is a folder named output in the directory where the running script is invoked.

  • save_output (bool) – Whether to save output for a single file. Default is False.

  • ioconfig (Optional[tiatoolbox.models.engine.patch_predictor.IOPatchPredictorConfig]) –
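The relationship between patch_input_shape and stride_shape, including the default stride_shape=patch_input_shape, can be sketched as a simple sliding-window grid. patch_grid is a hypothetical helper: shapes are assumed to be (height, width), bounds are returned as [x_start, y_start, x_end, y_end], and only patches that fit fully inside the image are kept, a simplification that ignores any edge padding the library may apply.

```python
def patch_grid(image_shape, patch_input_shape, stride_shape=None):
    """Hypothetical sliding-window grid of [x0, y0, x1, y1] patch bounds."""
    if stride_shape is None:
        # Documented default: non-overlapping patches.
        stride_shape = patch_input_shape
    if min(*patch_input_shape, *stride_shape) <= 0:
        raise ValueError("patch_input_shape and stride_shape must be positive.")
    height, width = image_shape
    patch_h, patch_w = patch_input_shape
    stride_h, stride_w = stride_shape
    bounds = []
    # Walk row by row, keeping only patches fully inside the image.
    for y0 in range(0, height - patch_h + 1, stride_h):
        for x0 in range(0, width - patch_w + 1, stride_w):
            bounds.append([x0, y0, x0 + patch_w, y0 + patch_h])
    return bounds
```

With a 4x4 image and 2x2 patches, the default stride yields four non-overlapping patches, while a stride of (1, 1) yields nine overlapping ones.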

Returns

Model predictions of the input dataset. If multiple image tiles or whole-slide images are provided as input, or save_output is True, then results are saved to save_dir and a dictionary indicating the save location for each input is returned. The dict has the following format:

  • img_path: path of the input image.
    • raw: path to the raw predictions, saved as .json.

    • merged: path to the .npy file containing merged predictions, if merge_predictions is True.
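The save-location dictionary described above can be mimicked with a short sketch. build_save_index is a hypothetical helper, not tiatoolbox API; the index-based file names follow the pattern shown in the example at the end of this section, and placing them directly under save_dir is an assumption.

```python
from pathlib import Path


def build_save_index(img_paths, save_dir, merge_predictions=False):
    """Hypothetical sketch of the per-input save-location dictionary."""
    save_dir = Path(save_dir)
    index = {}
    for idx, img_path in enumerate(img_paths):
        # Raw predictions are always listed, named by input position.
        entry = {"raw": str(save_dir / f"{idx}.raw.json")}
        if merge_predictions:
            # Merged maps are only listed when merging was requested.
            entry["merged"] = str(save_dir / f"{idx}.merged.npy")
        index[str(img_path)] = entry
    return index
```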

Return type

output (ndarray, dict)

Examples

>>> wsis = ['wsi1.svs', 'wsi2.svs']
>>> predictor = PatchPredictor(
...                 pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(wsis, mode="wsi", merge_predictions=True)
>>> list(output.keys())
['wsi1.svs', 'wsi2.svs']
>>> output['wsi1.svs']
{'raw': '0.raw.json', 'merged': '0.merged.npy'}
>>> output['wsi2.svs']
{'raw': '1.raw.json', 'merged': '1.merged.npy'}