TimmModel
tiatoolbox.models.architecture.vanilla.TimmModel
- class TimmModel(backbone, num_classes=1, *, pretrained)
Retrieve the tile encoder from timm.
This is a wrapper for pretrained models within timm.
- Parameters:
backbone (str) – Model name. Currently, the tool supports the following model names and their default associated weights from timm:
  - “efficientnet_b{i}” for i in [0, 1, …, 7]
  - “UNI”
  - “prov-gigapath”
  - “UNI2”
  - “Virchow”
  - “Virchow2”
  - “kaiko”
  - “H-optimus-0”
  - “H-optimus-1”
  - “H0-mini”
num_classes (int) – Number of classes output by model.
pretrained (bool, keyword-only) – Whether to load pretrained weights.
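Since pretrained follows the * in the signature, it can only be passed by keyword. A minimal construction sketch (the “efficientnet_b0” backbone and the class count below are arbitrary placeholders, not defaults):
>>> from tiatoolbox.models.architecture.vanilla import TimmModel
>>> # num_classes sets the width of the linear classifier head;
>>> # pretrained is keyword-only, so it must be passed by name.
>>> model = TimmModel(backbone="efficientnet_b0", num_classes=3, pretrained=False)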
- feat_extract
Backbone Timm model.
- Type:
nn.Module
- classifier
Linear classifier module used to map the features to the output.
- Type:
nn.Module
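The two attributes can also be used on their own. A rough sketch, assuming the backbone returns a feature vector that the linear classifier consumes directly (the “efficientnet_b0” backbone, batch size, and image size are placeholders):
>>> import torch
>>> from tiatoolbox.models.architecture.vanilla import TimmModel
>>> model = TimmModel("efficientnet_b0", num_classes=4, pretrained=False)
>>> feats = model.feat_extract(torch.randn(2, 3, 224, 224))  # backbone features
>>> logits = model.classifier(feats)  # linear head mapping features to 4 outputs
>>> print(feats.shape, logits.shape)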
Example
>>> model = TimmModel("UNI", pretrained=True)
>>> output = model(torch.randn(1, 3, 224, 224))
>>> print(output.shape)
Methods

__init__ – Initialize TimmModel.
forward – Pass input data through the model.
infer_batch – Run inference on an input batch.
postproc – Define the post-processing of this class of model.
Attributes
training
- forward(imgs)
Pass input data through the model.
- Parameters:
imgs (torch.Tensor) – Model input.
- Returns:
The output logits after passing through the model.
- Return type:
torch.Tensor
- static infer_batch(model, batch_data, device)
Run inference on an input batch.
Contains logic for the forward operation as well as I/O aggregation.
- Parameters:
model (nn.Module) – PyTorch defined model.
batch_data (torch.Tensor) – A batch of data generated by torch.utils.data.DataLoader.
device (str) – The device on which to run the model; the model is transferred to this device. Default is “cpu”.
- Returns:
The model predictions as a NumPy array.
- Return type:
np.ndarray
Example
>>> output = TimmModel.infer_batch(model, batch_data, "cuda")
>>> print(output)
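For context, a slightly fuller sketch of where batch_data typically comes from; the TensorDataset, batch size, and “efficientnet_b0” backbone are placeholders rather than part of this API:
>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>> from tiatoolbox.models.architecture.vanilla import TimmModel
>>> model = TimmModel("efficientnet_b0", num_classes=2, pretrained=False)
>>> loader = DataLoader(TensorDataset(torch.randn(8, 3, 224, 224)), batch_size=4)
>>> batch_data = next(iter(loader))[0]  # one batch of shape (4, 3, 224, 224)
>>> output = TimmModel.infer_batch(model, batch_data, device="cpu")
>>> print(output)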