UNetModel#

class UNetModel(num_input_channels=2, num_output_channels=2, encoder='resnet50', encoder_levels=None, decoder_block=None, skip_type='add')[source]#

Generate a family of UNet models.

This class supports different encoders. The decoder, however, is relatively simple: each up-sampling block contains a fixed number of vanilla convolution layers that are not customizable. By default, feature maps from the down-sampling and up-sampling paths are combined by addition rather than concatenation (configurable via skip_type).

Parameters:
  • num_input_channels (int) – Number of channels in input images.

  • num_output_channels (int) – Number of channels in output images.

  • encoder (str) –

    Name of the encoder, currently supports:
    • "resnet50": The well-known ResNet50; this is not the pre-activation variant.

    • "unet": The vanilla UNet encoder, where each down-sampling level contains 2 blocks of Convolution-BatchNorm-ReLU.

  • encoder_levels (list) – A list of integers to configure the "unet" encoder levels. Each integer defines the number of output channels at the corresponding down-sampling level (2 convolutions), and the length of the list defines the number of down-sampling levels in the unet encoder. This is only applicable when encoder="unet".

  • decoder_block (list) – A list of convolution layers for each up-sampling block in the decoder. Each item is an integer denoting the kernel size of one convolution layer.

  • skip_type (str) – Either "add" or "concat", the method used to combine feature maps from the encoder and decoder at skip connections. Default is "add".
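To make the encoder_levels parameter concrete, the sketch below builds a stand-alone down-sampling stack in plain PyTorch in the spirit of the "unet" encoder described above: two Conv-BatchNorm-ReLU blocks per level, with the channel count at each level taken from the list. make_vanilla_encoder is a hypothetical helper for illustration only, not part of this API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_vanilla_encoder(num_input_channels, encoder_levels):
    """Hypothetical sketch of a "unet"-style encoder: each entry in
    encoder_levels yields one down-sampling level made of two
    Conv-BatchNorm-ReLU blocks."""
    stages = []
    in_ch = num_input_channels
    for out_ch in encoder_levels:
        stages.append(
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )
        )
        in_ch = out_ch
    return nn.ModuleList(stages)

# Three integers -> three down-sampling levels with 32, 64, 128 channels.
encoder = make_vanilla_encoder(2, [32, 64, 128])
x = torch.randn(1, 2, 64, 64)
features = []
for stage in encoder:
    x = stage(x)
    features.append(x)        # keep per-level features for skip connections
    x = F.max_pool2d(x, 2)    # down-sample between levels

print([f.shape[1] for f in features])  # [32, 64, 128]
```

The per-level feature maps collected here are what a UNet decoder would later consume through its skip connections.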

Returns:

A PyTorch model.

Return type:

torch.nn.Module

Examples

>>> # instantiate a UNet with a resnet50 encoder and
>>> # only one 3x3 convolution per up-sampling block in the decoder
>>> model = UNetModel(
...     2, 2,
...     encoder="resnet50",
...     decoder_block=(3,)
... )
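The practical difference between the two skip_type options can be seen with plain tensors. The snippet below is an illustration in standalone PyTorch, not this library's code: "add" sums the encoder and decoder feature maps element-wise (so their channel counts must match), while "concat" stacks them along the channel axis (so the channel count doubles and the next convolution must expect more input channels).

```python
import torch

# Matching encoder/decoder feature maps at one skip connection (NCHW).
decoder_feat = torch.randn(1, 64, 32, 32)  # up-sampled decoder feature map
encoder_feat = torch.randn(1, 64, 32, 32)  # encoder feature map at same level

# skip_type="add": element-wise sum; channel count is unchanged (64).
added = decoder_feat + encoder_feat

# skip_type="concat": stack along the channel axis; channels double (64 -> 128).
concatenated = torch.cat([decoder_feat, encoder_feat], dim=1)

print(added.shape)         # torch.Size([1, 64, 32, 32])
print(concatenated.shape)  # torch.Size([1, 128, 32, 32])
```

This is why "add" keeps the decoder light-weight, whereas "concat" preserves both feature maps at the cost of wider convolutions downstream.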

Methods

forward

Logic for using layers defined in __init__.

infer_batch

Run inference on an input batch.

Attributes

forward(imgs, *args, **kwargs)[source]#

Logic for using layers defined in __init__.

This method defines how the layers are used in the forward operation.

Parameters:

imgs (torch.Tensor) – Input images, the tensor is of the shape NCHW.

Returns:

The inference output. The tensor is of the shape NCHW; however, the height and width may not match those of the input images.

Return type:

torch.Tensor

static infer_batch(model, batch_data, on_gpu)[source]#

Run inference on an input batch.

This contains the logic for the forward operation as well as I/O aggregation.

Parameters:
  • model (nn.Module) – PyTorch defined model.

  • batch_data (numpy.ndarray) – A batch of data generated by torch.utils.data.DataLoader.

  • on_gpu (bool) – Whether to run inference on a GPU.

Returns:

List of network output heads; each output is a numpy.ndarray.

Return type:

list
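The behaviour documented for infer_batch can be sketched in plain PyTorch. This is a minimal illustration of the documented contract (move the batch to the chosen device, run the forward pass without gradients, and return a list of numpy arrays, one per output head), not the library's actual implementation; infer_batch_sketch and the stand-in model are hypothetical.

```python
import numpy as np
import torch
import torch.nn as nn

def infer_batch_sketch(model, batch_data, on_gpu):
    """Hypothetical sketch of infer_batch-style logic: forward pass on the
    requested device plus output aggregation into numpy arrays."""
    device = "cuda" if on_gpu else "cpu"
    model = model.to(device).eval()
    imgs = torch.as_tensor(batch_data, dtype=torch.float32, device=device)
    with torch.inference_mode():
        output = model(imgs)
    # Return a list so models with multiple output heads can
    # contribute one numpy.ndarray per head.
    return [output.cpu().numpy()]

# Usage with a stand-in single-layer "model" and a numpy NCHW batch:
model = nn.Conv2d(2, 2, kernel_size=1)
batch = np.zeros((4, 2, 16, 16), dtype=np.float32)
outputs = infer_batch_sketch(model, batch, on_gpu=False)
print(type(outputs).__name__, outputs[0].shape)  # list (4, 2, 16, 16)
```

The static-method design lets inference engines call this logic with any compatible model without instantiating the class first.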