TimmBackbone
tiatoolbox.models.architecture.vanilla.TimmBackbone
- class TimmBackbone(backbone, *, pretrained)
Retrieve tile encoders from timm.
This class wraps pretrained models from the timm library so they can be used as tile feature extractors.
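- Parameters:
backbone (str) – Name of the tile encoder to load, e.g. "UNI".
pretrained (bool) – Whether to load pretrained weights. Keyword-only.
- Attributes:
feat_extract (nn.Module) – Backbone timm model.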
Examples
>>> import torch
>>> from tiatoolbox.models.architecture.vanilla import TimmBackbone
>>> # Create the UNI tile encoder
>>> model = TimmBackbone(backbone="UNI", pretrained=True)
>>> _ = model.eval()  # set to evaluation mode
>>> # dummy sample in NCHW form
>>> samples = torch.rand(4, 3, 224, 224)
>>> features = model(samples)
>>> features.shape  # one feature vector per sample
torch.Size([4, 1024])
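The wrapper pattern described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not tiatoolbox's implementation: `ToyTileEncoder` is a hypothetical stand-in that keeps a tiny convolutional network in a `feat_extract` attribute (where TimmBackbone holds the backbone timm model) and returns one feature vector per input tile, so it runs without downloading any weights.

```python
import torch
from torch import nn


class ToyTileEncoder(nn.Module):
    """Hypothetical stand-in for a timm tile-encoder wrapper."""

    def __init__(self, feature_dim: int = 16) -> None:
        super().__init__()
        # TimmBackbone stores a timm model here; this sketch uses a tiny CNN
        # that maps an NCHW image batch to one vector per image.
        self.feat_extract = nn.Sequential(
            nn.Conv2d(3, feature_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # -> (N, feature_dim, 1, 1)
            nn.Flatten(),  # -> (N, feature_dim)
        )

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # Pass the batch straight through the wrapped backbone.
        return self.feat_extract(imgs)


model = ToyTileEncoder(feature_dim=16)
model.eval()  # evaluation mode, as in the example above
samples = torch.rand(4, 3, 224, 224)  # NCHW: 4 RGB tiles of 224x224
with torch.inference_mode():  # no gradients needed for feature extraction
    features = model(samples)
print(features.shape)  # torch.Size([4, 16])
```

In timm itself, `timm.create_model(name, pretrained=True, num_classes=0)` returns a model without its classification head, which is one common way to obtain such a feature extractor; whether TimmBackbone builds its backbone exactly this way is not stated here.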
Initialize TimmBackbone.
Methods
forward(imgs) – Pass input data through the model.
infer_batch(model, batch_data, device) – Run inference on an input batch.
Attributes
training
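`training` is the standard `nn.Module` flag that `eval()` and `train()` toggle; it changes the behavior of layers such as dropout and batch normalization. A minimal sketch with a hypothetical toy module (any `nn.Module`, including a TimmBackbone instance, follows the same rules):

```python
from torch import nn

# Hypothetical toy model; dropout behaves differently in train and eval mode.
model = nn.Sequential(nn.Linear(8, 4), nn.Dropout(p=0.5))

print(model.training)  # True: modules start in training mode
model.eval()  # recursively sets training=False on all submodules
print(model.training, model[1].training)  # False False
model.train()  # switch back to training mode
print(model.training)  # True
```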
- forward(imgs)
Pass input data through the model.
- Parameters:
imgs (torch.Tensor) – Batch of input images in NCHW layout, e.g. of shape (4, 3, 224, 224).
- Returns:
The extracted features, one feature vector per input image.
- Return type:
torch.Tensor
- static infer_batch(model, batch_data, device)
Run inference on an input batch.
Contains the logic for the forward pass as well as input/output aggregation.
- Parameters:
model (nn.Module) – PyTorch model to run.
batch_data (torch.Tensor) – A batch of data generated by torch.utils.data.DataLoader.
device (str) – Device to transfer the model to, e.g. "cpu" or "cuda".
- Returns:
A list of NumPy arrays containing the model output.
- Return type:
list[numpy.ndarray]
Example
>>> output = TimmBackbone.infer_batch(model, batch_data, "cuda")
>>> print(output)
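The forward-plus-aggregation logic that infer_batch describes can be sketched as follows. This is an illustration under assumptions, not the tiatoolbox source: `infer_batch_sketch` is a hypothetical function that moves the model and batch to the device, runs the model in evaluation mode without gradient tracking, and collects the output as NumPy arrays in a list. The toy model and tensor shapes are also assumptions.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import nn


def infer_batch_sketch(
    model: nn.Module, batch_data: torch.Tensor, device: str
) -> list[np.ndarray]:
    """Hypothetical sketch of a single batch inference step."""
    # Put the model and the inputs on the same device; use float32 inputs.
    model = model.to(device)
    inputs = batch_data.to(device=device, dtype=torch.float32)
    model.eval()  # disable dropout, use running batch-norm statistics
    with torch.inference_mode():  # skip autograd bookkeeping
        output = model(inputs)
    # Aggregate: copy back to host memory and convert to NumPy.
    return [output.cpu().numpy()]


toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
batch = torch.rand(2, 3, 8, 8)
device = "cuda" if torch.cuda.is_available() else "cpu"
result = infer_batch_sketch(toy_model, batch, device)
print(len(result), result[0].shape)  # 1 (2, 5)
```

Because the documented method is static, it is called on the class rather than on an instance, as in `TimmBackbone.infer_batch(model, batch_data, "cuda")`; the sketch mirrors that by taking the model as an explicit argument.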