HoVerNetPlus
- class HoVerNetPlus(num_input_channels=3, num_types=None, num_layers=None)
Initialise HoVer-Net+.
HoVer-Net+ takes an RGB input image and can simultaneously segment and classify the nuclei present, as well as semantically segment different regions or layers in the image. Note that the HoVer-Net+ architecture assumes an image resolution of 0.5 mpp (microns per pixel), in contrast to HoVer-Net, which assumes 0.25 mpp.
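Because HoVer-Net+ expects 0.5 mpp input while HoVer-Net expects 0.25 mpp, a tile scanned at 0.25 mpp covers the same tissue with twice as many pixels along each axis and would need to be downsampled by a factor of 2 first. The sketch below only illustrates that arithmetic with plain NumPy block averaging. The helpers scale_factor and block_downsample are hypothetical (not part of TIAToolbox), and the sketch assumes the scale factor is a whole number that divides the tile size.

```python
import numpy as np


def scale_factor(source_mpp: float, target_mpp: float = 0.5) -> float:
    """Hypothetical helper: source pixels per target pixel along each axis."""
    return target_mpp / source_mpp


def block_downsample(image: np.ndarray, factor: int) -> np.ndarray:
    """Hypothetical helper: average non-overlapping factor x factor blocks.

    Assumes an H x W x C image whose height and width divide by ``factor``.
    """
    h, w, c = image.shape
    if h % factor or w % factor:
        raise ValueError("image height and width must be divisible by factor")
    blocks = image.reshape(h // factor, factor, w // factor, factor, c)
    return blocks.mean(axis=(1, 3)).astype(image.dtype)


# A 512x512 RGB tile at 0.25 mpp becomes a 256x256 tile at 0.5 mpp.
rng = np.random.default_rng(0)
tile_025 = rng.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
factor = int(scale_factor(source_mpp=0.25, target_mpp=0.5))  # 2
tile_05 = block_downsample(tile_025, factor)
print(tile_05.shape)  # (256, 256, 3)
```

Block averaging is used here only because it needs nothing beyond NumPy; any proper resampling method would serve the same purpose.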
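- Parameters
num_input_channels (int) – number of input channels; defaults to 3 for RGB input.
num_types (int) – number of nuclear types present in the images.
num_layers (int) – number of layer or region types to segment.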
Methods
infer_batch(model, batch_data, on_gpu) – Run inference on an input batch.
postproc(raw_maps) – Post-processing script for image tiles.
- static infer_batch(model, batch_data, on_gpu)
Run inference on an input batch.
This contains the logic for the forward pass as well as batch I/O aggregation.
- Parameters
model (nn.Module) – PyTorch-defined model.
batch_data (torch.Tensor) – a batch of data generated by torch.utils.data.DataLoader.
on_gpu (bool) – whether to run inference on a GPU.
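The batch I/O aggregation step can be pictured with the NumPy-only sketch below. It is not TIAToolbox's implementation. It assumes the network returns one NCHW logit array per head, keyed np, hv, tp and ls, and that aggregation means moving channels last, keeping the foreground probability of the np head, and turning the tp and ls heads into per-pixel class indices. The function aggregate_heads and the channel counts are hypothetical.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along ``axis``."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def aggregate_heads(raw: dict) -> list:
    """Hypothetical aggregation of per-head NCHW outputs into [np, hv, tp, ls].

    Each returned array is NHWC, with the batch dimension kept first.
    """
    # Move channels last: N x C x H x W -> N x H x W x C.
    nhwc = {k: np.transpose(v, (0, 2, 3, 1)) for k, v in raw.items()}
    # Nuclear pixel head: keep only the foreground probability channel.
    np_map = softmax(nhwc["np"])[..., 1:]
    # Horizontal/vertical map head: passed through unchanged.
    hv_map = nhwc["hv"]
    # Type and layer heads: per-pixel argmax kept as a trailing channel.
    # Softmax does not change the argmax, so it is skipped here.
    tp_map = np.argmax(nhwc["tp"], axis=-1)[..., None].astype(np.float32)
    ls_map = np.argmax(nhwc["ls"], axis=-1)[..., None].astype(np.float32)
    return [np_map, hv_map, tp_map, ls_map]


rng = np.random.default_rng(0)
n, h, w = 2, 8, 8
raw_outputs = {
    "np": rng.normal(size=(n, 2, h, w)),
    "hv": rng.normal(size=(n, 2, h, w)),
    "tp": rng.normal(size=(n, 3, h, w)),  # assumed: one channel per type
    "ls": rng.normal(size=(n, 5, h, w)),  # assumed: one channel per layer
}
outputs = aggregate_heads(raw_outputs)
print([o.shape for o in outputs])
# [(2, 8, 8, 1), (2, 8, 8, 2), (2, 8, 8, 1), (2, 8, 8, 1)]
```

The list order mirrors the [np, hv, tp, ls] order that postproc expects for raw_maps.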
- static postproc(raw_maps)
Post-processing script for image tiles.
- Parameters
raw_maps (list(ndarray)) – list of prediction outputs, one per head, assumed to be in the order [np, hv, tp, ls] (matching the output of infer_batch).
- Returns
inst_map (ndarray) – pixel-wise nuclear instance segmentation prediction.
inst_dict (dict) – a dictionary mapping each instance within inst_map to its instance information. It has the following form:
    inst_info = {
        box: number[],
        centroids: number[],
        contour: number[][],
        type: number,
        prob: number,
    }
    inst_dict = {[inst_uid: number] : inst_info}
where inst_uid is an integer corresponding to the instance with the same pixel value in inst_map.
layer_map (ndarray) – pixel-wise layer segmentation prediction.
layer_dict (dict) – a dictionary mapping each segmented layer within layer_map to its layer information. It has the following form:
    layer_info = {
        contour: number[][],
        type: number,
    }
    layer_dict = {[layer_uid: number] : layer_info}
- Return type
tuple(ndarray, dict, ndarray, dict)
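To make the shape of inst_dict and layer_map concrete, the sketch below derives box and centroids entries from a labelled instance map, and per-type pixel counts from a layer map, using NumPy. It is illustrative only: summarise_instances and layer_pixel_counts are hypothetical helpers, the contour, type and prob fields are omitted, and the box layout [x_min, y_min, x_max, y_max] (exclusive max) and the (x, y) centroid order are assumptions rather than documented behaviour.

```python
import numpy as np


def summarise_instances(inst_map: np.ndarray) -> dict:
    """Hypothetical helper: map each instance uid to a partial inst_info dict."""
    inst_dict = {}
    for uid in np.unique(inst_map):
        if uid == 0:  # 0 is treated as background
            continue
        rows, cols = np.nonzero(inst_map == uid)
        inst_dict[int(uid)] = {
            # Assumed layout: [x_min, y_min, x_max, y_max], max exclusive.
            "box": [
                int(cols.min()),
                int(rows.min()),
                int(cols.max()) + 1,
                int(rows.max()) + 1,
            ],
            # Assumed order: (x, y), i.e. mean column then mean row.
            "centroids": [float(cols.mean()), float(rows.mean())],
        }
    return inst_dict


def layer_pixel_counts(layer_map: np.ndarray) -> dict:
    """Hypothetical helper: number of pixels assigned to each layer type."""
    types, counts = np.unique(layer_map, return_counts=True)
    return {int(t): int(c) for t, c in zip(types, counts)}


inst_map = np.zeros((6, 6), dtype=np.int32)
inst_map[1:3, 1:4] = 1  # instance 1: rows 1-2, columns 1-3
inst_map[4:6, 4:6] = 2  # instance 2: rows 4-5, columns 4-5
layer_map = np.zeros((6, 6), dtype=np.int32)
layer_map[:, 3:] = 1    # right half assigned to layer type 1

print(summarise_instances(inst_map))
print(layer_pixel_counts(layer_map))  # {0: 18, 1: 18}
```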
Examples
>>> from tiatoolbox.models.architecture.hovernet_plus import HoVerNetPlus
>>> import torch
>>> import numpy as np
>>> # image_patch is a 256x256x3 numpy array
>>> batch = torch.from_numpy(image_patch)[None]
>>> weights_path = "A/weights.pth"
>>> pretrained = torch.load(weights_path, map_location="cpu")
>>> model = HoVerNetPlus(num_types=3, num_layers=5)
>>> model.load_state_dict(pretrained)
>>> output = model.infer_batch(model, batch, on_gpu=False)
>>> output = [v[0] for v in output]
>>> output = model.postproc(output)
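The example processes a single patch, which is why it adds a batch axis with [None] and then keeps v[0] from each head. With several patches the same pattern extends naturally: stack the patches along a new first axis, then pick sample i from every head before post-processing that sample. The sketch below shows only this indexing, with NumPy arrays standing in for the model outputs; split_batch_outputs is a hypothetical helper, and the 8x8 output shapes are placeholders, not the real HoVer-Net+ output size.

```python
import numpy as np


def split_batch_outputs(outputs: list) -> list:
    """Hypothetical helper: turn per-head batched arrays into per-sample lists.

    ``outputs`` is ordered [np, hv, tp, ls], each with shape N x H x W x C.
    Returns N lists, each ordered [np, hv, tp, ls] for a single sample.
    """
    n = outputs[0].shape[0]
    return [[head[i] for head in outputs] for i in range(n)]


# Stack three 256x256x3 patches into an N x H x W x C batch (channels last,
# like torch.from_numpy(image_patch)[None] for a single patch).
patches = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(3)]
batch = np.stack(patches)  # shape (3, 256, 256, 3)

# Placeholder head outputs standing in for the result of infer_batch.
fake_outputs = [
    np.zeros((3, 8, 8, 1)),  # np
    np.zeros((3, 8, 8, 2)),  # hv
    np.zeros((3, 8, 8, 1)),  # tp
    np.zeros((3, 8, 8, 1)),  # ls
]
per_sample = split_batch_outputs(fake_outputs)
# Each per_sample[i] has the same form as the single-patch `output` list
# that the example above passes to postproc.
print(len(per_sample), [a.shape for a in per_sample[0]])
# 3 [(8, 8, 1), (8, 8, 2), (8, 8, 1), (8, 8, 1)]
```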