TimmBackbone

class TimmBackbone(backbone, *, pretrained)

Retrieve tile encoders from timm.

This is a wrapper for pretrained models within timm.

Parameters:
  • backbone (str) –

    Model name. Supported model names include:
    • "efficientnet_b{i}" for i in [0, 1, …, 7]

    • "UNI"

    • "prov-gigapath"

    • "UNI2"

    • "Virchow"

    • "Virchow2"

    • "kaiko"

    • "H-optimus-0"

    • "H-optimus-1"

    • "H0-mini"

  • pretrained (bool, keyword-only) – Whether to load pretrained weights.
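The naming rule above mixes one pattern ("efficientnet_b{i}" with i from 0 to 7) with a fixed set of foundation-model names. A minimal sketch of checking a name against that documented list is shown below. The helper `is_documented_backbone` is hypothetical and is not part of TIAToolbox; it treats names as case-sensitive, which is an assumption of this sketch rather than a documented behavior of TimmBackbone.

```python
import re

# Hypothetical helper, NOT part of TIAToolbox: it only mirrors the
# supported names listed in the documentation above.
_NAMED_BACKBONES = {
    "UNI", "prov-gigapath", "UNI2", "Virchow", "Virchow2",
    "kaiko", "H-optimus-0", "H-optimus-1", "H0-mini",
}
# "efficientnet_b{i}" where i is a single digit from 0 to 7
_EFFICIENTNET_PATTERN = re.compile(r"efficientnet_b[0-7]")


def is_documented_backbone(name: str) -> bool:
    """Return True if `name` matches one of the documented backbone names."""
    if name in _NAMED_BACKBONES:
        return True
    # fullmatch rejects near-misses such as "efficientnet_b10"
    return _EFFICIENTNET_PATTERN.fullmatch(name) is not None


print(is_documented_backbone("efficientnet_b3"))  # True
print(is_documented_backbone("efficientnet_b8"))  # False
print(is_documented_backbone("Virchow2"))  # True
```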

feat_extract

Backbone timm model.

Type:

nn.Module
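Because feat_extract is an nn.Module stored as an attribute, it is registered as a submodule, so calls such as .eval() and .to(device) on the wrapper propagate to it; the training attribute listed below is the standard nn.Module flag that .eval() sets to False. The sketch below illustrates this wrapper pattern with a tiny stand-in network so it runs without downloading weights. The class `FeatureExtractorWrapper` is hypothetical and is not the TIAToolbox implementation; with timm itself, a pooled-feature model can typically be obtained via timm.create_model(name, pretrained=True, num_classes=0), but whether TimmBackbone builds feat_extract that way is not stated here.

```python
import torch
from torch import nn


class FeatureExtractorWrapper(nn.Module):
    """Hypothetical wrapper mirroring the `feat_extract` attribute pattern."""

    def __init__(self, feat_extract: nn.Module) -> None:
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule,
        # so .eval(), .to(device) and state_dict() all reach it.
        self.feat_extract = feat_extract

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # imgs: (N, C, H, W) -> features: (N, D)
        return self.feat_extract(imgs)


# Tiny stand-in for a timm backbone: conv -> global average pool -> flatten.
stand_in = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
model = FeatureExtractorWrapper(stand_in).eval()  # eval() returns the module
with torch.inference_mode():
    features = model(torch.rand(4, 3, 224, 224))
print(features.shape)  # torch.Size([4, 16])
print(model.training, model.feat_extract.training)  # False False
```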

Examples

>>> import torch
>>> from tiatoolbox.models.architecture.vanilla import TimmBackbone
>>> # Creating UNI tile encoder
>>> model = TimmBackbone(backbone="UNI", pretrained=True)
>>> _ = model.eval()  # set to evaluation mode
>>> # dummy sample in NCHW form
>>> samples = torch.rand(4, 3, 224, 224)
>>> features = model(samples)
>>> features.shape  # feature vector
torch.Size([4, 1024])

Initialize TimmBackbone.

Methods

  • forward – Pass input data through the model.

  • infer_batch – Run inference on an input batch.

Attributes

  • training

forward(imgs)

Pass input data through the model.

Parameters:
  • imgs (torch.Tensor) – Model input, a batch of images of shape (N, C, H, W).

Returns:

The extracted features.

Return type:

torch.Tensor
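forward expects channels-first (NCHW) input, whereas tiles read from disk are often stored channels-last (NHWC) as uint8. A minimal sketch of that conversion is shown below; the random tile data is a stand-in, and the sketch deliberately applies no model-specific mean/std normalization, since whether TimmBackbone normalizes its input internally is not stated in this reference.

```python
import numpy as np
import torch

# Hypothetical input: 4 RGB tiles of 224 x 224 stored channels-last (NHWC)
# as uint8, a common layout for images read from disk.
rng = np.random.default_rng(0)
tiles_nhwc = rng.integers(0, 256, size=(4, 224, 224, 3), dtype=np.uint8)

# Scale to float32 in [0, 1], then reorder axes to channels-first (NCHW).
imgs = torch.from_numpy(tiles_nhwc).float().div(255.0)
imgs = imgs.permute(0, 3, 1, 2).contiguous()

print(imgs.shape)  # torch.Size([4, 3, 224, 224])
```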

static infer_batch(model, batch_data, device)

Run inference on an input batch.

Contains the logic for the forward pass as well as I/O aggregation.

Parameters:
  • model (nn.Module) – PyTorch defined model.

  • batch_data (torch.Tensor) – A batch of data generated by torch.utils.data.DataLoader.

  • device (str) – Device to transfer the model to, for example "cpu" or "cuda".

Returns:

A list of dictionaries whose values are NumPy arrays.

Return type:

list[dict[str, np.ndarray]]

Example

>>> batch_data = torch.rand(4, 3, 224, 224)  # dummy NCHW batch
>>> output = TimmBackbone.infer_batch(model, batch_data, "cuda")
>>> print(output)
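The "I/O aggregation" mentioned above usually means running each DataLoader batch through the model and stacking the per-batch outputs. The sketch below illustrates that loop under stated assumptions: `sketch_infer_batch` is a hypothetical re-implementation, not the TIAToolbox source; the "features" dictionary key is an assumed name chosen only to match the documented list[dict[str, np.ndarray]] return type; and a tiny stand-in model replaces a real backbone so the example runs on CPU without downloads.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def sketch_infer_batch(
    model: nn.Module, batch_data: torch.Tensor, device: str
) -> list[dict[str, np.ndarray]]:
    """Hypothetical re-implementation, NOT the TIAToolbox source."""
    model = model.to(device).eval()  # move the model and disable training mode
    imgs = batch_data.to(device).type(torch.float32)
    with torch.inference_mode():  # no gradient tracking during inference
        output = model(imgs)
    # "features" is an assumed key name chosen for this sketch.
    return [{"features": output.cpu().numpy()}]


# Stand-in model and data so the sketch runs without downloading weights.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten())
loader = DataLoader(TensorDataset(torch.rand(10, 3, 32, 32)), batch_size=4)

# Run every batch (sizes 4, 4, 2) and stack results into one array.
per_batch = [
    sketch_infer_batch(model, imgs, "cpu")[0]["features"] for (imgs,) in loader
]
all_features = np.concatenate(per_batch, axis=0)
print(all_features.shape)  # (10, 8)
```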