PatchPredictor¶
- class PatchPredictor(batch_size=8, num_loader_workers=0, model=None, pretrained_model=None, pretrained_weights=None, *, verbose=True)[source]¶
Patch level predictor.
The models provided by tiatoolbox should give the following results:
PatchPredictor performance on the Kather100K dataset [1]¶

Model name                     F1 score
alexnet-kather100k             0.965
resnet18-kather100k            0.990
resnet34-kather100k            0.991
resnet50-kather100k            0.989
resnet101-kather100k           0.989
resnext50_32x4d-kather100k     0.992
resnext101_32x8d-kather100k    0.991
wide_resnet50_2-kather100k     0.989
wide_resnet101_2-kather100k    0.990
densenet121-kather100k         0.993
densenet161-kather100k         0.992
densenet169-kather100k         0.992
densenet201-kather100k         0.991
mobilenet_v2-kather100k        0.990
mobilenet_v3_large-kather100k  0.991
mobilenet_v3_small-kather100k  0.992
googlenet-kather100k           0.992
PatchPredictor performance on the PCam dataset [2]¶

Model name                     F1 score
alexnet-pcam                   0.840
resnet18-pcam                  0.888
resnet34-pcam                  0.889
resnet50-pcam                  0.892
resnet101-pcam                 0.888
resnext50_32x4d-pcam           0.900
resnext101_32x8d-pcam          0.892
wide_resnet50_2-pcam           0.901
wide_resnet101_2-pcam          0.898
densenet121-pcam               0.897
densenet161-pcam               0.893
densenet169-pcam               0.895
densenet201-pcam               0.891
mobilenet_v2-pcam              0.899
mobilenet_v3_large-pcam        0.895
mobilenet_v3_small-pcam        0.890
googlenet-pcam                 0.867
- Parameters:
model (nn.Module) – Externally defined PyTorch model to use for prediction, with weights already loaded. Default is None. If provided, the pretrained_model argument is ignored.
pretrained_model (str) – Name of an existing model supported by tiatoolbox for processing the data. For a full list of pretrained models, refer to the docs. By default, the corresponding pretrained weights will also be downloaded. However, you can override them with your own set of weights via the pretrained_weights argument. Argument is case-insensitive.
pretrained_weights (str) –
Path to the weights of the corresponding pretrained_model.
>>> predictor = PatchPredictor(
...     pretrained_model="resnet18-kather100k",
...     pretrained_weights="resnet18_local_weight")
batch_size (int) – Number of images fed into the model each time.
num_loader_workers (int) – Number of workers to load the data. Take note that they will also perform preprocessing.
verbose (bool) – Whether to output logging information.
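For example, a minimal construction sketch (the import path tiatoolbox.models and all values shown are illustrative assumptions, not prescribed defaults):
>>> # Sketch: build a predictor from a pretrained model name with a larger
>>> # batch size and several loader workers (which also preprocess patches).
>>> from tiatoolbox.models import PatchPredictor
>>> predictor = PatchPredictor(
...     pretrained_model="resnet18-kather100k",  # case-insensitive name
...     batch_size=32,          # images fed into the model per batch
...     num_loader_workers=4,   # loader workers also perform preprocessing
...     verbose=False,
... )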
- img¶
A HWC image or a path to WSI.
- Type:
str or pathlib.Path or numpy.ndarray
- model¶
Defined PyTorch model.
- Type:
nn.Module
- pretrained_model¶
Name of an existing model supported by tiatoolbox for processing the data. For a full list of pretrained models, refer to the docs. By default, the corresponding pretrained weights will also be downloaded. However, you can override them with your own set of weights via the pretrained_weights argument. Argument is case-insensitive.
- Type:
str
Examples
>>> # list of 2 image patches as input
>>> data = [img1, img2]
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(data, mode='patch')
>>> # array of 2 image patches as input
>>> data = np.array([img1, img2])
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(data, mode='patch')
>>> # list of 2 image patch files as input
>>> data = ['path/img.png', 'path/img.png']
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(data, mode='patch')
>>> # list of 2 image tile files as input
>>> tile_file = ['path/tile1.png', 'path/tile2.png']
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(tile_file, mode='tile')
>>> # list of 2 wsi files as input
>>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs']
>>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(wsi_file, mode='wsi')
References
[1] Kather, Jakob Nikolas, et al. “Predicting survival from colorectal cancer histology slides using deep learning: A retrospective multicenter study.” PLoS medicine 16.1 (2019): e1002730.
[2] Veeling, Bastiaan S., et al. “Rotation equivariant CNNs for digital pathology.” International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2018.
Initialize PatchPredictor.
Methods
merge_predictions – Merge patch level predictions to form a 2-dimensional prediction map.
predict – Make a prediction for a list of input data.
- static merge_predictions(img, output, resolution=None, units=None, postproc_func=None, *, return_raw=False)[source]¶
Merge patch level predictions to form a 2-dimensional prediction map.
The prediction map contains values from 0 to N, where N is the number of classes. A value of 0 marks background regions that were not processed by the model, while values 1 to N correspond to the N classes predicted by the model. A short sketch after the example below shows how to recover the original class indices from a merged map.
- Parameters:
img (str or pathlib.Path or numpy.ndarray) – A HWC image or a path to WSI.
output (dict) – Output generated by the model.
resolution (Resolution) – Resolution of merged predictions.
units (Units) – Units of resolution used when merging predictions. This must be the same units used when processing the data.
postproc_func (callable) – A function to post-process the raw predictions from the model. By default, internal code uses the np.argmax function.
return_raw (bool) – Return raw result without applying the postproc_func on the assembled image.
- Returns:
Merged predictions as a 2D array.
- Return type:
numpy.ndarray
Examples
>>> # pseudo output dict from model with 2 patches
>>> output = {
...     'resolution': 1.0,
...     'units': 'baseline',
...     'probabilities': [[0.45, 0.55], [0.90, 0.10]],
...     'predictions': [1, 0],
...     'coordinates': [[0, 0, 2, 2], [2, 2, 4, 4]],
... }
>>> merged = PatchPredictor.merge_predictions(
...     np.zeros([4, 4]),
...     output,
...     resolution=1.0,
...     units='baseline'
... )
>>> merged
... array([[2, 2, 0, 0],
...        [2, 2, 0, 0],
...        [0, 0, 1, 1],
...        [0, 0, 1, 1]])
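Since merged values are offset by one relative to the model's class indices (0 is reserved for unprocessed background), a small post-hoc sketch such as the following, which is not part of the API and only relies on the offset described above, can recover the class indices:
>>> # Sketch (not part of the API): value 0 marks unprocessed background,
>>> # value k corresponds to model class k - 1.
>>> import numpy as np
>>> class_map = np.where(merged > 0, merged - 1, -1)  # -1 = background
>>> class_map
... array([[ 1,  1, -1, -1],
...        [ 1,  1, -1, -1],
...        [-1, -1,  0,  0],
...        [-1, -1,  0,  0]])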
- predict(imgs, masks=None, labels=None, mode='patch', ioconfig=None, patch_input_shape=None, stride_shape=None, resolution=None, units=None, device='cpu', *, return_probabilities=False, return_labels=False, merge_predictions=False, save_dir=None, save_output=False)[source]¶
Make a prediction for a list of input data.
- Parameters:
imgs (list, ndarray) – List of inputs to process. When using patch mode, the input must be either a list of images, a list of image file paths, or a numpy array of images. When using tile or wsi mode, the input must be a list of file paths.
masks (list) – List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if they are within a masked area. If not provided, then a tissue mask will be automatically generated for whole-slide images or the entire image is processed for image tiles.
labels (list | None) – List of labels. If using tile or wsi mode, then only a single label per image tile or whole-slide image is supported.
mode (str) – Type of input to process. Choose from either patch, tile or wsi.
return_probabilities (bool) – Whether to return per-class probabilities.
return_labels (bool) – Whether to return the labels with the predictions.
device (str) – Select the device to run the model on, as in torch.device. Please see https://pytorch.org/docs/stable/tensor_attributes.html#torch.device for more details. Default value is "cpu".
ioconfig (IOPatchPredictorConfig) – Patch Predictor IO configuration.
patch_input_shape (tuple) – Size of patches input to the model. Patches are at the requested read resolution, not with respect to level 0, and must be positive.
stride_shape (tuple) – Stride used during tile and WSI processing. Stride is at the requested read resolution, not with respect to level 0, and must be positive. If not provided, stride_shape=patch_input_shape.
resolution (Resolution) – Resolution used for reading the image. Please see WSIReader for details.
units (Units) – Units of resolution used for reading the image. Choose from either level, power or mpp. Please see WSIReader for details.
merge_predictions (bool) – Whether to merge the predictions to form a 2-dimensional map. This is only applicable for mode='wsi' or mode='tile'.
save_dir (str or pathlib.Path) – Output directory when processing multiple tiles and whole-slide images. By default, outputs are saved to a folder named output in the directory from which the running script is invoked.
save_output (bool) – Whether to save output for a single file. Default is False.
self (PatchPredictor)
- Returns:
Model predictions of the input dataset. If multiple image tiles or whole-slide images are provided as input, or save_output is True, then results are saved to save_dir and a dictionary indicating save location for each input is returned.
The dict has the following format:
img_path: path of the input image.
raw: path to the save location of the raw prediction, saved in .json format.
merged: path to a .npy file containing merged predictions if merge_predictions is True.
- Return type:
(numpy.ndarray or list or dict)
Examples
>>> wsis = ['wsi1.svs', 'wsi2.svs']
>>> predictor = PatchPredictor(
...     pretrained_model="resnet18-kather100k")
>>> output = predictor.predict(wsis, mode="wsi")
>>> output.keys()
... ['wsi1.svs', 'wsi2.svs']
>>> output['wsi1.svs']
... {'raw': '0.raw.json', 'merged': '0.merged.npy'}
>>> output['wsi2.svs']
... {'raw': '1.raw.json', 'merged': '1.merged.npy'}
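As a further illustration, the sketch below (file names, the save directory, and parameter values are illustrative assumptions; saved paths may be relative to save_dir) passes explicit patch geometry and read resolution, merges the predictions, and then reads back the saved files reported in the returned dictionary:
>>> # Illustrative sketch: explicit patch geometry and read resolution,
>>> # merged 2D map, results saved under save_dir.
>>> import json
>>> import numpy as np
>>> output = predictor.predict(
...     ["wsi1.svs", "wsi2.svs"],
...     mode="wsi",
...     patch_input_shape=(224, 224),  # at the requested read resolution
...     stride_shape=(224, 224),       # defaults to patch_input_shape
...     resolution=0.5,
...     units="mpp",
...     merge_predictions=True,        # only for mode='tile' or 'wsi'
...     save_dir="wsi_output",
... )
>>> paths = output["wsi1.svs"]             # per-input save locations
>>> with open(paths["raw"]) as f:
...     raw = json.load(f)                 # raw per-patch predictions (.json)
>>> merged_map = np.load(paths["merged"])  # merged 2D prediction map (.npy)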