TimmModel

class TimmModel(backbone, num_classes=1, *, pretrained)[source]

Retrieve the tile encoder from timm.

This is a wrapper for pretrained models within timm.

Parameters:
  • backbone (str) –

    Model name. The tool currently supports the following model names, each with its default associated weights from timm:

    - “efficientnet_b{i}” for i in [0, 1, …, 7]
    - “UNI”
    - “prov-gigapath”
    - “UNI2”
    - “Virchow”
    - “Virchow2”
    - “kaiko”
    - “H-optimus-0”
    - “H-optimus-1”
    - “H0-mini”

  • num_classes (int) – Number of classes output by the model.

  • pretrained (bool, keyword-only) – Whether to load pretrained weights.

num_classes

Number of classes output by the model.

Type:

int

pretrained

Whether to load pretrained weights.

Type:

bool

feat_extract

Backbone Timm model.

Type:

nn.Module

classifier

Linear classifier module used to map the features to the output.

Type:

nn.Module
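The two attributes above suggest a feature extractor followed by a linear head. The following is a minimal sketch of that pattern, not the library's implementation: `TinyBackbone` and `WrapperSketch` are hypothetical names, and a small convolutional stand-in replaces the timm backbone so the example runs without downloading weights.

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a timm feature extractor."""

    def __init__(self, feat_dim: int = 16) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, feat_dim, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # (N, 3, H, W) -> (N, feat_dim)
        return torch.flatten(self.pool(self.conv(imgs)), 1)


class WrapperSketch(nn.Module):
    """Illustrative backbone plus linear classifier (assumed layout)."""

    def __init__(self, num_classes: int = 1, feat_dim: int = 16) -> None:
        super().__init__()
        self.feat_extract = TinyBackbone(feat_dim)
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        feat = self.feat_extract(imgs)  # extract features
        return self.classifier(feat)    # map features to logits


model = WrapperSketch(num_classes=2).eval()
with torch.no_grad():
    logits = model(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 2])
```

In this sketch the attribute names mirror `feat_extract` and `classifier` only to make the correspondence with the documented attributes easy to see.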

Example

>>> import torch
>>> model = TimmModel("UNI", pretrained=True)
>>> output = model(torch.randn(1, 3, 224, 224))
>>> print(output.shape)

Initialize TimmModel.

Methods

forward

Pass input data through the model.

infer_batch

Run inference on an input batch.

postproc

Define the post-processing of this class of model.

Attributes

training

forward(imgs)[source]

Pass input data through the model.

Parameters:

imgs (torch.Tensor) – Input batch of images.

Returns:

The output logits after passing through the model.

Return type:

torch.Tensor

static infer_batch(model, batch_data, device)[source]

Run inference on an input batch.

Contains logic for forward operation as well as i/o aggregation.

Parameters:
  • model (nn.Module) – PyTorch defined model.

  • batch_data (torch.Tensor) – A batch of data generated by torch.utils.data.DataLoader.

  • device (str) – Device to which the model is transferred for inference, for example “cpu” or “cuda”.

Returns:

The model predictions as a dictionary of NumPy arrays.

Return type:

dict[str, np.ndarray]
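As a rough illustration of "forward operation plus i/o aggregation", the sketch below moves a model and a batch to a device, runs a forward pass without gradients, and returns the result as NumPy arrays in a dictionary. The function name `infer_batch_sketch` and the dictionary key `"output"` are assumptions made for this example; the keys used by the library are not stated here.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import nn


def infer_batch_sketch(
    model: nn.Module,
    batch_data: torch.Tensor,
    device: str,
) -> dict[str, np.ndarray]:
    """Run one forward pass and return NumPy outputs (illustrative only)."""
    model = model.to(device)
    model.eval()  # disable dropout / batch-norm updates
    with torch.no_grad():
        logits = model(batch_data.to(device).float())
    # "output" is a hypothetical key chosen for this sketch.
    return {"output": logits.cpu().numpy()}


# Toy model and batch so the example runs on CPU without any weights.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2))
result = infer_batch_sketch(toy_model, torch.randn(5, 3, 8, 8), "cpu")
print(result["output"].shape)  # (5, 2)
```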

Example

>>> output = TimmModel.infer_batch(model, batch_data, "cuda")
>>> print(output)

static postproc(image)[source]

Define the post-processing of this class of model.

This simply applies argmax along the last axis of the input.

Parameters:

image (np.ndarray) – The input image array.

Returns:

The post-processed image array.

Return type:

np.ndarray
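The argmax post-processing described above can be sketched with plain NumPy. The function name `postproc_sketch` and the sample values are illustrative assumptions; the point is only that the last axis (one score per class) collapses to a class index.

```python
import numpy as np


def postproc_sketch(image: np.ndarray) -> np.ndarray:
    """Return the index of the largest value along the last axis."""
    return np.argmax(image, axis=-1)


# Two samples with three class scores each.
scores = np.array([[0.1, 0.7, 0.2],
                   [0.6, 0.3, 0.1]])
labels = postproc_sketch(scores)
print(labels)  # [1 0]
```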