CNNBackbone
tiatoolbox.models.architecture.vanilla.CNNBackbone
- class CNNBackbone(backbone)
Retrieve a model backbone with its classification layer removed.
This is a wrapper for pretrained models available in PyTorch.
- Parameters:
backbone (str) – Model name. Currently, the following model names are supported, each loaded with its default associated weights from PyTorch:
"alexnet"
"resnet18"
"resnet34"
"resnet50"
"resnet101"
"resnext50_32x4d"
"resnext101_32x8d"
"wide_resnet50_2"
"wide_resnet101_2"
"densenet121"
"densenet161"
"densenet169"
"densenet201"
"inception_v3"
"googlenet"
"mobilenet_v2"
"mobilenet_v3_large"
"mobilenet_v3_small"
- feat_extract
Backbone CNN model.
- Type:
nn.Module
- pool
Pooling layer applied after feature extraction.
- Type:
nn.Module
Examples
>>> import torch
>>> from tiatoolbox.models.architecture.vanilla import CNNBackbone
>>> # Create a resnet50 architecture from PyTorch without the
>>> # classification layer, with its default weights loaded
>>> model = CNNBackbone(backbone="resnet50")
>>> _ = model.eval()  # set to evaluation mode
>>> # dummy batch in NCHW form
>>> samples = torch.rand(4, 3, 512, 512)
>>> features = model(samples)
>>> features.shape  # features after global average pooling
torch.Size([4, 2048])
Initialize CNNBackbone.
Methods
forward – Pass input data through the model.
infer_batch – Run inference on an input batch.
Attributes
training
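`training` is the mode flag that every `torch.nn.Module` carries; `model.eval()` in the example above clears it, and `model.train()` sets it again. The sketch below illustrates this on a small stand-in module rather than on CNNBackbone itself, assuming CNNBackbone inherits this behaviour from `nn.Module`.

```python
import torch.nn as nn

# Any nn.Module starts in training mode.
layer = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
print(layer.training)  # True

layer.eval()  # switches this module and all of its submodules to evaluation mode
print(layer.training, layer[1].training)  # False False

layer.train()  # back to training mode
print(layer.training)  # True
```

The flag matters for layers such as BatchNorm and Dropout, which behave differently during training and inference.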
- forward(imgs)
Pass input data through the model.
- Parameters:
imgs (torch.Tensor) – Model input.
self (CNNBackbone)
- Returns:
The extracted features.
- Return type:
torch.Tensor
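Conceptually, the forward pass applies `feat_extract`, then `pool`, then flattens every dimension except the batch dimension, which is how the example above ends up with a tensor of shape (4, 2048). The following minimal sketch of that composition uses a small hypothetical convolutional stack (`TinyBackbone`, illustrative only) in place of a real pretrained backbone so that it runs without downloading weights.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in mirroring the feat_extract -> pool -> flatten flow."""

    def __init__(self) -> None:
        super().__init__()
        # Toy feature extractor: 3 input channels -> 16 feature channels.
        self.feat_extract = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # Global average pooling collapses H x W to 1 x 1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        feat = self.feat_extract(imgs)   # (N, 16, H', W')
        pooled = self.pool(feat)         # (N, 16, 1, 1)
        return torch.flatten(pooled, 1)  # (N, 16)


model = TinyBackbone().eval()
with torch.no_grad():
    out = model(torch.rand(4, 3, 32, 32))
print(out.shape)  # torch.Size([4, 16])
```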
- static infer_batch(model, batch_data, device)
Run inference on an input batch.
Contains the logic for the forward operation as well as I/O aggregation.
- Parameters:
model (nn.Module) – PyTorch defined model.
batch_data (torch.Tensor) – A batch of data generated by torch.utils.data.DataLoader.
device (str) – Device to which the model is transferred. Default is "cpu".
- Returns:
A list of dictionary values containing NumPy arrays.
- Return type:
list
Example
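>>> # `model` and `batch_data` are assumed to be defined, e.g. a CNNBackbone
>>> # instance and a batch produced by torch.utils.data.DataLoader
>>> output = CNNBackbone.infer_batch(model, batch_data, "cuda")
>>> print(output)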
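The device transfer and I/O aggregation that `infer_batch` performs can be sketched roughly as follows. This is a minimal sketch under stated assumptions, not tiatoolbox's implementation: `infer_batch_sketch` is a hypothetical name, it assumes `batch_data` is already in the NCHW layout the model expects, and it returns a single-element list holding a NumPy array.

```python
import numpy as np
import torch
import torch.nn as nn


def infer_batch_sketch(
    model: nn.Module, batch_data: torch.Tensor, device: str = "cpu"
) -> list[np.ndarray]:
    """Hypothetical sketch: move model and batch to `device`, run inference, return NumPy."""
    model = model.to(device)
    model.eval()  # disable dropout and use running BatchNorm statistics
    batch = batch_data.to(device).type(torch.float32)
    with torch.no_grad():  # no autograd bookkeeping during inference
        output = model(batch)
    # Bring the results back to host memory as NumPy for downstream aggregation.
    return [output.cpu().numpy()]


# Usage with a toy model (hypothetical, for illustration only).
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten())
result = infer_batch_sketch(toy, torch.rand(2, 3, 16, 16), "cpu")
print(type(result[0]), result[0].shape)  # <class 'numpy.ndarray'> (2, 8)
```

Returning NumPy arrays on the CPU lets callers concatenate outputs from many batches without holding GPU memory.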