UNetModel

class UNetModel(num_input_channels=2, num_output_channels=2, encoder='resnet50', encoder_levels=None, decoder_block=None, skip_type='add')[source]

Generate a family of UNet models.

This class supports different encoders. The decoder, however, is relatively simple: each up-sampling block contains a fixed number of vanilla convolution layers that are not otherwise customizable. Feature maps from the down-sampling and up-sampling paths are combined at skip connections either by addition or by concatenation, as selected by skip_type.

Parameters:
  • num_input_channels (int) – Number of channels in input images.

  • num_output_channels (int) – Number of channels in output images.

  • encoder (str) –

    Name of the encoder, currently supports:
    • "resnet50": The well-known ResNet50; this is not the pre-activation variant.

    • "unet": The vanilla UNet encoder, where each down-sampling level contains 2 blocks of Convolution-BatchNorm-ReLU.

  • encoder_levels (list) – A list of integers configuring the "unet" encoder levels. Each integer defines the number of output channels at the corresponding down-sampling level (2 convolutions), and the number of integers defines the number of down-sampling levels in the UNet encoder. This is only applicable when encoder="unet".

  • decoder_block (list) – A list of convolution layers for each up-sampling block; each item is an integer denoting the kernel size of the corresponding convolution layer.

  • skip_type (str) – Either "add" or "concat", the method used to combine feature maps from the encoder and decoder parts at skip connections. Defaults to "add".
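The practical difference between the two skip_type options is the channel dimension of the combined feature map. A minimal numpy sketch (not the library's implementation) with hypothetical encoder and decoder feature maps:

```python
import numpy as np

# Hypothetical feature maps in NCHW layout, one from the down-sampling
# (encoder) path and one from the up-sampling (decoder) path.
encoder_feat = np.ones((1, 64, 32, 32))
decoder_feat = np.ones((1, 64, 32, 32))

# skip_type="add": element-wise sum; channel count is unchanged,
# so both feature maps must have matching shapes.
added = encoder_feat + decoder_feat
print(added.shape)  # (1, 64, 32, 32)

# skip_type="concat": stack along the channel axis; the channel
# count doubles and the next convolution must expect 128 channels.
concatenated = np.concatenate([encoder_feat, decoder_feat], axis=1)
print(concatenated.shape)  # (1, 128, 32, 32)
```

Addition keeps the decoder lightweight, while concatenation preserves both feature maps at the cost of wider subsequent convolutions.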

Returns:

A PyTorch model.

Return type:

torch.nn.Module

Examples

>>> # instantiate a UNet with a resnet50 encoder and
>>> # only one 3x3 convolution per up-sampling block in the decoder
>>> UNetModel(
...     2, 2,
...     encoder="resnet50",
...     decoder_block=(3,)
... )

Initialize UNetModel.

Methods

forward

Logic for using layers defined in init.

infer_batch

Run inference on an input batch.

Attributes

training

forward(imgs, *args, **kwargs)[source]

Logic for using layers defined in init.

This method defines how layers are used in forward operation.

Parameters:
  • imgs (torch.Tensor) – Input images, the tensor is of the shape NCHW.

  • args (list) – List of input arguments. Not used here. Provided for consistency with the API.

  • kwargs (dict) – Key-word arguments. Not used here. Provided for consistency with the API.

  • self (UNetModel)

Returns:

The inference output. The tensor is of the shape NCHW. However, height and width may not be the same as the input images.

Return type:

torch.Tensor
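The note that output height and width may differ from the input's follows from how unpadded convolutions behave. A short illustrative sketch (a bare Conv2d, not the model itself):

```python
import torch
import torch.nn as nn

# An unpadded 3x3 convolution, as used in the classic UNet design,
# shrinks each spatial dimension by 2 (kernel_size - 1).
conv = nn.Conv2d(2, 2, kernel_size=3)  # no padding
imgs = torch.zeros(1, 2, 64, 64)       # NCHW input
out = conv(imgs)
print(out.shape)  # torch.Size([1, 2, 62, 62])
```

Stacking several such layers in the decoder compounds the shrinkage, which is why the output tensor may need cropping or resizing before being compared to the input.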

static infer_batch(model, batch_data, *, on_gpu)[source]

Run inference on an input batch.

This contains logic for the forward operation as well as I/O aggregation.

Parameters:
  • model (nn.Module) – PyTorch defined model.

  • batch_data (torch.Tensor) – A batch of data generated by torch.utils.data.DataLoader.

  • on_gpu (bool) – Whether to run inference on a GPU.

Returns:

List of network output heads; each output is a numpy.ndarray.

Return type:

list
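The shape of infer_batch (a static method taking the model and batch explicitly, returning host-side numpy arrays) suggests the usual PyTorch inference pattern. A hedged, self-contained sketch of that pattern, assuming torch is available; this is an illustration, not the library's source:

```python
import torch
import torch.nn as nn

def infer_batch_sketch(model: nn.Module, batch_data: torch.Tensor, *, on_gpu: bool) -> list:
    """Illustrative sketch of a typical infer_batch: forward pass + I/O aggregation."""
    device = "cuda" if on_gpu else "cpu"
    model = model.to(device).eval()          # inference mode: freeze batchnorm/dropout
    with torch.inference_mode():             # disable autograd bookkeeping
        output = model(batch_data.to(device))
    # Aggregate I/O: return host-side numpy arrays, one per output head.
    return [output.cpu().numpy()]

# Usage with a trivial stand-in model (hypothetical, for demonstration only):
model = nn.Conv2d(2, 2, kernel_size=3, padding=1)
batch = torch.zeros(4, 2, 16, 16)
result = infer_batch_sketch(model, batch, on_gpu=False)
print(type(result[0]).__name__, result[0].shape)  # ndarray (4, 2, 16, 16)
```

Keeping the method static lets the same inference logic be reused by engine code that manages model loading and data loading separately.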