CNNModel

class CNNModel(backbone, num_classes=1)

Retrieve the model backbone and attach an extra fully connected layer to perform classification.

Parameters
  • backbone (str) – Name of the backbone CNN model.

  • num_classes (int) – Number of classes output by the model.

num_classes

Number of classes output by the model.

Type

int

feat_extract

Backbone CNN model.

Type

nn.Module

pool

Pooling layer applied after feature extraction.

Type

nn.Module

classifier

Linear classifier module used to map the features to the output.

Type

nn.Module

Methods

forward

Pass input data through the model.

infer_batch

Run inference on an input batch.

postproc

Define the post-processing of this class of model.


forward(imgs)

Pass input data through the model.

Parameters

imgs (torch.Tensor) – Model input.
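How the documented attributes fit together in a forward pass can be sketched in PyTorch as below. This is a minimal illustration under stated assumptions, not the library's implementation: `CNNModelSketch` is a hypothetical name, a small two-layer convolutional stack stands in for the named backbone, global average pooling (`nn.AdaptiveAvgPool2d`) is assumed for `pool`, and the sketch returns raw logits because this reference does not say whether an output activation is applied.

```python
import torch
from torch import nn


class CNNModelSketch(nn.Module):
    """Hypothetical stand-in mirroring the attributes documented for CNNModel."""

    def __init__(self, num_classes: int = 1) -> None:
        super().__init__()
        self.num_classes = num_classes
        # Placeholder backbone (assumption); the real class resolves a named backbone.
        self.feat_extract = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        # Assumed global average pooling: collapses each feature map to 1x1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Linear head mapping the 32 pooled channels to one score per class.
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        feat = self.feat_extract(imgs)               # (N, 32, H', W')
        pooled = torch.flatten(self.pool(feat), 1)   # (N, 32)
        return self.classifier(pooled)               # (N, num_classes) raw logits


model = CNNModelSketch(num_classes=4)
logits = model(torch.rand(2, 3, 64, 64))
print(logits.shape)  # torch.Size([2, 4])
```

Pooling to a fixed 1x1 spatial size means the classifier's input width depends only on the number of backbone channels, so the same head works for any input resolution.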

static infer_batch(model, batch_data, on_gpu)

Run inference on an input batch.

Contains the logic for the forward pass as well as input/output aggregation.

Parameters
  • model (nn.Module) – PyTorch defined model.

  • batch_data (ndarray) – A batch of data generated by torch.utils.data.DataLoader.

  • on_gpu (bool) – Whether to run inference on a GPU.
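A minimal sketch of the steps `infer_batch` describes, under assumptions not stated in this reference: the batch arrives channels-last `(N, H, W, C)` and is permuted to the `(N, C, H, W)` layout that `nn.Conv2d` expects, and outputs are aggregated by copying them back to the CPU as a NumPy array. The standalone function and the toy model are illustrative only.

```python
import numpy as np
import torch
from torch import nn


def infer_batch(model: nn.Module, batch_data, on_gpu: bool) -> np.ndarray:
    """Hypothetical sketch of CNNModel.infer_batch under stated assumptions."""
    device = torch.device("cuda" if on_gpu else "cpu")
    # Accept either a NumPy array or a tensor yielded by a DataLoader.
    imgs = torch.as_tensor(batch_data).to(device).float()
    # Assumption: channels-last input (N, H, W, C) -> channels-first (N, C, H, W).
    imgs = imgs.permute(0, 3, 1, 2).contiguous()
    model = model.to(device)
    model.eval()  # disable dropout, use running batch-norm statistics
    with torch.no_grad():  # no autograd graph is needed for inference
        output = model(imgs)
    # Aggregate results back in host memory as a NumPy array.
    return output.cpu().numpy()


toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
batch = np.random.rand(4, 8, 8, 3).astype(np.float32)
out = infer_batch(toy, batch, on_gpu=False)
print(out.shape)  # (4, 5)
```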

static postproc(image)

Define the post-processing of this class of model.

This simply applies argmax along the last axis of the input.
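Because `postproc` only applies an argmax over the last axis, it reduces per-class scores to predicted class indices. An equivalent NumPy sketch, written as a standalone function for illustration:

```python
import numpy as np


def postproc(image: np.ndarray) -> np.ndarray:
    """Return the index of the highest score along the last (class) axis."""
    return np.argmax(image, axis=-1)


scores = np.array([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
print(postproc(scores))  # [1 0]
```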