CNNModel

class CNNModel(backbone, num_classes=1)

Retrieve the backbone model by name and attach an extra fully connected network (FCN) to perform classification.

Parameters:
  • backbone (str) – Name of the backbone CNN model.

  • num_classes (int) – Number of classes output by the model.

num_classes

Number of classes output by the model.

Type: int

feat_extract

Backbone CNN model used to extract features from the input.

Type: nn.Module

pool

Pooling layer applied to the extracted features.

Type: nn.Module

classifier

Linear classifier module used to map the features to the output.

Type: nn.Module

Initialize CNNModel.
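To make the relationship between feat_extract, pool and classifier concrete, here is a minimal PyTorch sketch. It is not the CNNModel source: the class name SketchCNNModel is hypothetical, a tiny two-layer convolutional stack stands in for the backbone that CNNModel retrieves by name, and global average pooling (nn.AdaptiveAvgPool2d) is assumed for pool. The sketch returns raw logits; whether CNNModel.forward applies an activation on top is not stated on this page.

```python
import torch
from torch import nn


class SketchCNNModel(nn.Module):
    """Hypothetical stand-in for CNNModel: backbone -> pool -> linear head."""

    def __init__(self, num_classes: int = 1) -> None:
        super().__init__()
        self.num_classes = num_classes
        # Assumption: a tiny conv stack replaces the backbone that CNNModel
        # retrieves by name; it emits 16 feature channels.
        self.feat_extract = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Assumption: global average pooling reduces each feature map to 1x1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Linear head maps the 16 pooled features to num_classes outputs.
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        feat = self.feat_extract(imgs)              # (N, 16, H, W)
        pooled = torch.flatten(self.pool(feat), 1)  # (N, 16)
        return self.classifier(pooled)              # (N, num_classes), raw logits


model = SketchCNNModel(num_classes=4)
logits = model(torch.rand(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 4])
```

Because the pooling step collapses the spatial dimensions, the linear head only needs to know the backbone's channel count, so inputs of different height and width map to the same (N, num_classes) output shape.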

Methods

  • forward – Pass input data through the model.

  • infer_batch – Run inference on an input batch.

  • postproc – Define the post-processing of this class of model.

Attributes

  • training

forward(imgs)

Pass input data through the model.

Parameters:
  • imgs (Tensor) – Batch of input images.

Return type: Tensor

static infer_batch(model, batch_data, *, on_gpu)

Run inference on an input batch.

Contains the logic for the forward pass as well as input/output aggregation.

Parameters:
  • model (nn.Module) – PyTorch model to run.

  • batch_data (torch.Tensor) – A batch of data generated by torch.utils.data.DataLoader.

  • on_gpu (bool) – Whether to run inference on a GPU.

Return type: ndarray
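A minimal sketch of what the infer_batch contract (tensor batch in, NumPy array out, optional GPU) might look like. The function name sketch_infer_batch and the toy model are hypothetical, and the channels-last (N, H, W, C) input layout is an assumption about how batches arrive from the DataLoader, not something this page states.

```python
import numpy as np
import torch
from torch import nn


def sketch_infer_batch(model: nn.Module, batch_data: torch.Tensor, *, on_gpu: bool) -> np.ndarray:
    """Hypothetical re-creation of the infer_batch contract."""
    device = torch.device("cuda" if on_gpu else "cpu")
    model = model.to(device)
    # Assumption: batches arrive channels-last (N, H, W, C), so they are
    # permuted to the channels-first layout (N, C, H, W) that nn.Conv2d expects.
    imgs = batch_data.to(device).float().permute(0, 3, 1, 2).contiguous()
    model.eval()  # disable dropout and use running batch-norm statistics
    with torch.no_grad():  # no autograd graph is needed for inference
        output = model(imgs)
    # Aggregate the output back on the host as a NumPy array.
    return output.cpu().numpy()


toy_model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
batch = torch.rand(4, 8, 8, 3)  # four 8x8 RGB patches, channels-last
preds = sketch_infer_batch(toy_model, batch, on_gpu=False)
print(preds.shape)  # (4, 2)
```

Switching to eval mode and disabling gradient tracking are the usual PyTorch inference steps; calling .cpu() before .numpy() is required because NumPy cannot read tensors that live in GPU memory.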

static postproc(image)

Define the post-processing of this class of model.

This simply applies argmax along the last axis of the input.

Parameters:
  • image (ndarray) – Array to post-process; argmax is taken along its last axis.

Return type: ndarray
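Since postproc only applies argmax along the last axis, its behaviour can be reproduced with NumPy alone. The helper name sketch_postproc and the probability values below are illustrative assumptions, not taken from the library.

```python
import numpy as np


def sketch_postproc(image: np.ndarray) -> np.ndarray:
    """Hypothetical equivalent of postproc: argmax over the last axis."""
    return np.argmax(image, axis=-1)


# Two samples, three classes each (illustrative per-class scores).
probs = np.array([[0.1, 0.7, 0.2],
                  [0.6, 0.3, 0.1]])
print(sketch_postproc(probs))  # [1 0]
```

The last axis is consumed, so an (N, num_classes) score array becomes an (N,) array of predicted class indices.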