UNetModel

class UNetModel(num_input_channels=2, num_output_channels=2, encoder='resnet50', decoder_block=(3, 3))

Generate a family of UNet models.

This supports different encoders. However, the decoder is relatively simple: each up-sampling block contains a number of vanilla convolution layers that are not customizable. Additionally, feature maps from the down-sampling and up-sampling paths are aggregated by addition, not concatenation.
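To make the addition-versus-concatenation distinction concrete, the sketch below merges an up-sampled decoder feature map with the matching encoder (skip) feature map both ways. It uses NumPy arrays in NCHW layout and nearest-neighbour up-sampling via `np.repeat`; the helper names and the up-sampling method are illustrative assumptions, not how `UNetModel` is implemented internally.

```python
import numpy as np


def upsample_nearest(x: np.ndarray, factor: int = 2) -> np.ndarray:
    """Nearest-neighbour up-sampling of an NCHW array along H and W (illustrative)."""
    return np.repeat(np.repeat(x, factor, axis=2), factor, axis=3)


def merge_add(up: np.ndarray, skip: np.ndarray) -> np.ndarray:
    """Aggregate by element-wise addition: the channel count is unchanged."""
    if up.shape != skip.shape:
        raise ValueError(f"shape mismatch: {up.shape} vs {skip.shape}")
    return up + skip


def merge_concat(up: np.ndarray, skip: np.ndarray) -> np.ndarray:
    """Aggregate by concatenation along the channel axis: channel counts add up."""
    return np.concatenate([up, skip], axis=1)


# A coarse decoder map (N=1, C=8, 16x16) and the encoder map one level finer (32x32).
decoder_feat = np.ones((1, 8, 16, 16), dtype=np.float32)
skip_feat = np.ones((1, 8, 32, 32), dtype=np.float32)

up = upsample_nearest(decoder_feat)
print(merge_add(up, skip_feat).shape)     # (1, 8, 32, 32)
print(merge_concat(up, skip_feat).shape)  # (1, 16, 32, 32)
```

With addition the two maps must already have matching channel counts; concatenation lifts that constraint but doubles the number of channels fed to the next convolution.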

Parameters
  • num_input_channels (int) – Number of channels in input images.

  • num_output_channels (int) – Number of channels in output images.

  • encoder (str) –

    Name of the encoder. Currently supported:

    - “resnet50”: The well-known ResNet50; this is not the pre-activation model.
    - “unet”: The vanilla UNet encoder, where each down-sampling level contains 2 blocks of Convolution-BatchNorm-ReLU.

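  • decoder_block (list or tuple) – The convolution layers in each up-sampling block of the decoder. Each item is an integer denoting one layer's kernel size; for example, (3, 3) gives two 3x3 convolutions per block.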

  • classifier (list) – A list of convolution layers before the final 1x1 convolution. Each item is an integer denoting the layer kernel size. Defaults to None, in which case the classifier contains only the 1x1 convolution.
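The sketch below shows one way a sequence of kernel sizes such as decoder_block could translate into stacked layers, and how the same helper describes a vanilla “unet” encoder level (two Convolution-BatchNorm-ReLU blocks). The helper names `conv_bn_relu` and `make_block` are hypothetical, and the "same" padding of `k // 2` is an assumption; the real `UNetModel` may pad or order layers differently.

```python
import torch
from torch import nn


def conv_bn_relu(in_ch: int, out_ch: int, kernel_size: int) -> nn.Sequential:
    """One Convolution-BatchNorm-ReLU unit; padding k // 2 keeps H and W for odd k."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size, padding=kernel_size // 2, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


def make_block(in_ch: int, out_ch: int, kernel_sizes=(3, 3)) -> nn.Sequential:
    """Stack one conv_bn_relu unit per kernel size (hypothetical helper)."""
    layers = []
    for i, k in enumerate(kernel_sizes):
        # Only the first unit changes the channel count.
        layers.append(conv_bn_relu(in_ch if i == 0 else out_ch, out_ch, k))
    return nn.Sequential(*layers)


# decoder_block=(3,): a single 3x3 unit per up-sampling block
decoder_block = make_block(64, 64, kernel_sizes=(3,))
# A vanilla "unet" encoder level: two 3x3 Conv-BN-ReLU units
encoder_level = make_block(3, 64, kernel_sizes=(3, 3))

x = torch.randn(2, 3, 32, 32)
print(encoder_level(x).shape)                  # torch.Size([2, 64, 32, 32])
print(len(decoder_block), len(encoder_level))  # 1 2
```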

Returns

A PyTorch model.

Return type

model (torch.nn.Module)

Examples

>>> # Instantiate a UNet with a ResNet50 encoder and
>>> # only one 3x3 convolution in each up-sampling block of the decoder
>>> model = UNetModel(
...     2, 2,
...     encoder="resnet50",
...     decoder_block=(3,)
... )

Methods

  • forward – Logic for using the layers defined in __init__.

  • infer_batch – Run inference on an input batch.


forward(imgs, *args, **kwargs)

Logic for using the layers defined in __init__.

This method defines how the layers are applied in the forward pass.

Parameters

imgs (torch.Tensor) – Input images as a tensor of shape NCHW.

Returns

The inference output, a tensor of shape NCHW. Its height and width may differ from those of the input images.

Return type

output (torch.Tensor)
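Because the output height and width may differ from the input's, comparing predictions with full-size labels usually needs an alignment step. A common approach, sketched here, is to center-crop the larger tensor to the output size. The helper `center_crop` is hypothetical, and the shrinking shown (two unpadded 3x3 convolutions) is only one illustrative way such a size difference can arise, not necessarily what happens inside `UNetModel`.

```python
import torch
from torch import nn


def center_crop(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Crop the central height x width window from an NCHW tensor (hypothetical helper)."""
    _, _, h, w = x.shape
    if height > h or width > w:
        raise ValueError("crop size larger than input")
    top = (h - height) // 2
    left = (w - width) // 2
    return x[:, :, top:top + height, left:left + width]


# Two 3x3 convolutions without padding each trim one pixel from every border.
shrinking = nn.Sequential(
    nn.Conv2d(2, 2, kernel_size=3, padding=0),
    nn.Conv2d(2, 2, kernel_size=3, padding=0),
)

imgs = torch.randn(1, 2, 64, 64)      # NCHW input
output = shrinking(imgs)              # torch.Size([1, 2, 60, 60])
labels = torch.randn(1, 2, 64, 64)    # full-size target
aligned = center_crop(labels, *output.shape[-2:])
print(output.shape == aligned.shape)  # True
```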

static infer_batch(model, batch_data, on_gpu)

Run inference on an input batch.

This contains the logic for the forward operation as well as input/output aggregation.

Parameters
  • model (nn.Module) – A PyTorch model.

  • batch_data (ndarray) – A batch of data generated by torch.utils.data.DataLoader.

  • on_gpu (bool) – Whether to run inference on a GPU.

Returns

A list of the network's output heads; each output is an ndarray.
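The sketch below shows a typical shape for such a static inference helper: convert the batch to a tensor, move it to the requested device, run the model in evaluation mode without gradient tracking, and return a list of NumPy arrays, one per output head. It assumes the batch arrives in NHWC layout and that a per-pixel softmax over channels is wanted; both are assumptions, and `infer_batch_sketch` is a hypothetical name, not the actual `UNetModel.infer_batch` code.

```python
import numpy as np
import torch
from torch import nn


def infer_batch_sketch(model: nn.Module, batch_data, on_gpu: bool) -> list:
    """Run one forward pass and return a list with one ndarray per output head."""
    device = torch.device("cuda" if on_gpu else "cpu")
    model = model.to(device)
    model.eval()  # dropout off, BatchNorm uses running statistics

    # Assumed input layout: NHWC, converted to the NCHW layout the model expects.
    imgs = torch.as_tensor(np.asarray(batch_data), dtype=torch.float32)
    imgs = imgs.permute(0, 3, 1, 2).contiguous().to(device)

    with torch.no_grad():  # no autograd bookkeeping during inference
        logits = model(imgs)
        probs = torch.softmax(logits, dim=1)  # assumed: per-pixel class probabilities
        probs = probs.permute(0, 2, 3, 1).cpu().numpy()  # NCHW -> NHWC ndarray

    return [probs]  # one head here; a multi-head model would return more entries


# Usage with a stand-in model (a single 1x1 convolution, 3 -> 2 channels).
toy_model = nn.Conv2d(3, 2, kernel_size=1)
batch = np.random.rand(4, 32, 32, 3).astype(np.float32)
(out,) = infer_batch_sketch(toy_model, batch, on_gpu=False)
print(out.shape)  # (4, 32, 32, 2)
```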