CNNModel
tiatoolbox.models.architecture.vanilla.CNNModel
- class CNNModel(backbone, num_classes=1)
Retrieve the model backbone and attach an extra fully connected (linear) layer to perform classification.
- feat_extract
Backbone CNN model.
- Type:
nn.Module
- pool
Type of pooling applied after feature extraction.
- Type:
nn.Module
- classifier
Linear classifier module used to map the features to the output.
- Type:
nn.Module
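As a rough illustration of how the three attributes fit together, the NumPy sketch below chains a stand-in backbone (`feat_extract`), a spatial pooling step (`pool`), and a linear layer (`classifier`) whose output width is `num_classes`. It is not tiatoolbox code: the class name `ToyCNNModel`, the parameters `in_channels`, `num_features` and `seed`, the 1x1-convolution backbone, and the choice of global average pooling are all assumptions made for illustration. The real `CNNModel` wraps a PyTorch backbone selected by `backbone`, and this page does not say which pooling it applies.

```python
import numpy as np


class ToyCNNModel:
    """Hypothetical NumPy stand-in: backbone -> pooling -> linear classifier.

    Mirrors the roles of `feat_extract`, `pool` and `classifier`; it is NOT
    the tiatoolbox implementation.
    """

    def __init__(self, in_channels: int, num_features: int,
                 num_classes: int = 1, seed: int = 0) -> None:
        rng = np.random.default_rng(seed)
        # Hypothetical backbone: a 1x1 "convolution" (per-pixel channel mixing)
        # followed by ReLU, producing `num_features` channels.
        self.backbone_weight = rng.standard_normal((num_features, in_channels))
        # Linear classifier mapping pooled features to `num_classes` outputs.
        self.classifier_weight = rng.standard_normal((num_classes, num_features))
        self.classifier_bias = np.zeros(num_classes)

    def feat_extract(self, imgs: np.ndarray) -> np.ndarray:
        # (N, C, H, W) -> (N, F, H, W)
        feats = np.einsum("fc,nchw->nfhw", self.backbone_weight, imgs)
        return np.maximum(feats, 0.0)

    def pool(self, feats: np.ndarray) -> np.ndarray:
        # Assumed global average pooling over the spatial axes: (N, F, H, W) -> (N, F)
        return feats.mean(axis=(2, 3))

    def classifier(self, pooled: np.ndarray) -> np.ndarray:
        # (N, F) @ (F, K) + (K,) -> (N, K)
        return pooled @ self.classifier_weight.T + self.classifier_bias

    def forward(self, imgs: np.ndarray) -> np.ndarray:
        # Chain the three stages in order.
        return self.classifier(self.pool(self.feat_extract(imgs)))


model = ToyCNNModel(in_channels=3, num_features=8, num_classes=4)
logits = model.forward(np.random.rand(2, 3, 16, 16))
print(logits.shape)  # (2, 4)
```

Replacing the NumPy pieces with PyTorch modules (for example `nn.AdaptiveAvgPool2d(1)`, then `torch.flatten`, then `nn.Linear`) gives the common PyTorch form of such a classification head; the key constraint is that the classifier's input width must equal the backbone's output channel count.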
Methods
- forward(imgs)
Pass input data through the model.
- infer_batch(model, batch_data, on_gpu)
Run inference on an input batch.
Define the post-processing of this class of model.
- forward(imgs)
Pass input data through the model.
- Parameters:
imgs (torch.Tensor) – Model input.
- static infer_batch(model, batch_data, on_gpu)
Run inference on an input batch.
Contains the logic for the forward pass as well as input/output aggregation.
- Parameters:
model (nn.Module) – PyTorch defined model.
batch_data (ndarray) – A batch of data generated by torch.utils.data.DataLoader.
on_gpu (bool) – Whether to run inference on a GPU.
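The "input/output aggregation" around the forward pass can be pictured with the framework-free sketch below. It assumes, without confirmation from this page, that batches arrive channels-last `(N, H, W, C)` (as image patches read from disk usually are) while the model expects channels-first `(N, C, H, W)`. The name `infer_batch_sketch` and its `forward` callable argument are hypothetical; device placement (`on_gpu`) is omitted because NumPy has no GPU support.

```python
from typing import Callable

import numpy as np


def infer_batch_sketch(forward: Callable[[np.ndarray], np.ndarray],
                       batch_data: np.ndarray) -> np.ndarray:
    """Hypothetical sketch of the i/o handling around a forward pass.

    Assumes `batch_data` is channels-last (N, H, W, C) and that `forward`
    expects channels-first (N, C, H, W). Not the tiatoolbox implementation.
    """
    batch = np.asarray(batch_data, dtype=np.float32)
    if batch.ndim != 4:
        raise ValueError(f"expected a 4-D batch, got shape {batch.shape}")
    # Reorder (N, H, W, C) -> (N, C, H, W) for the model.
    batch = np.ascontiguousarray(batch.transpose(0, 3, 1, 2))
    # Forward pass. In PyTorch this would typically run with the model in
    # eval mode, under torch.inference_mode(), on the device chosen by `on_gpu`.
    output = forward(batch)
    # Return a plain NumPy array so results from many batches can be concatenated.
    return np.asarray(output)


batch = np.zeros((2, 4, 4, 3), dtype=np.uint8)
batch[..., 2] = 10
out = infer_batch_sketch(lambda x: x.mean(axis=(2, 3)), batch)
print(out.shape)  # (2, 3)
```

Returning NumPy arrays rather than tensors keeps the per-batch outputs independent of the framework and device, which makes aggregating them across a whole `DataLoader` straightforward.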