CNNBackbone

class CNNBackbone(backbone)

Retrieve the model backbone and strip the classification layer.

This class wraps the pretrained models available in PyTorch.

Parameters:

backbone (str) –

Model name. The following model names are currently supported,
each with its default pretrained weights from PyTorch:
  • "alexnet"

  • "resnet18"

  • "resnet34"

  • "resnet50"

  • "resnet101"

  • "resnext50_32x4d"

  • "resnext101_32x8d"

  • "wide_resnet50_2"

  • "wide_resnet101_2"

  • "densenet121"

  • "densenet161"

  • "densenet169"

  • "densenet201"

  • "inception_v3"

  • "googlenet"

  • "mobilenet_v2"

  • "mobilenet_v3_large"

  • "mobilenet_v3_small"

Examples

>>> import torch
>>> # Create a resnet50 backbone with its default PyTorch weights
>>> # loaded and the classification layer removed
>>> model = CNNBackbone(backbone="resnet50")
>>> model.eval()  # set to evaluation mode
>>> # dummy sample in NCHW form
>>> samples = torch.rand(4, 3, 512, 512)
>>> features = model(samples)
>>> features.shape  # features after global average pooling
torch.Size([4, 2048])

Initialize CNNBackbone.

Methods

  • forward – Pass input data through the model.

  • infer_batch – Run inference on an input batch.

Attributes

training

forward(imgs)

Pass input data through the model.

Parameters:

imgs (torch.Tensor) – Batch of input images.

Return type:

Tensor

static infer_batch(model, batch_data, *, on_gpu)

Run inference on an input batch.

Contains the logic for the forward pass as well as input/output aggregation.

Parameters:
  • model (nn.Module) – PyTorch-defined model.

  • batch_data (torch.Tensor) – A batch of data generated by torch.utils.data.DataLoader.

  • on_gpu (bool) – Whether to run inference on a GPU.

Return type:

list[ndarray, …]
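The following is a rough sketch of what a function with this signature might do, written only from the description above; it is not the library's code. The name infer_batch_sketch, the fallback to CPU when CUDA is unavailable, and the single-element return list are all assumptions made for illustration.

```python
from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn


def infer_batch_sketch(
    model: nn.Module, batch_data: torch.Tensor, *, on_gpu: bool
) -> list[np.ndarray]:
    """Hypothetical re-implementation for illustration only."""
    # Assumption: fall back to the CPU if no CUDA device is present.
    use_cuda = on_gpu and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = model.to(device).eval()
    batch = batch_data.to(device).type(torch.float32)

    # Forward pass without gradient tracking.
    with torch.inference_mode():
        output = model(batch)

    # Aggregate the output as NumPy arrays on the host.
    return [output.cpu().numpy()]


toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
outputs = infer_batch_sketch(toy_model, torch.rand(4, 3, 8, 8), on_gpu=False)
print(outputs[0].shape)  # (4, 5)
```

Because on_gpu is keyword-only, callers have to write on_gpu=True or on_gpu=False explicitly, which keeps the device choice visible at the call site.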