UNetModel
tiatoolbox.models.architecture.unet.UNetModel
- class UNetModel(num_input_channels=2, num_output_channels=2, encoder='resnet50', decoder_block=(3, 3))
Generate a family of UNet models.
This class supports different encoders. However, the decoder is relatively simple: each up-sampling block contains a number of vanilla convolution layers, and only their kernel sizes can be configured (via decoder_block). Additionally, features from the down-sampling and up-sampling paths are aggregated by addition, not concatenation.
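To illustrate what additive aggregation means for tensor shapes, the following NumPy sketch contrasts it with the channel-wise concatenation used in the original UNet. The arrays stand in for NCHW feature maps; `aggregate_add` and `aggregate_concat` are hypothetical helpers for illustration, not tiatoolbox functions.

```python
import numpy as np


def aggregate_add(skip: np.ndarray, up: np.ndarray) -> np.ndarray:
    """Additive skip connection: shapes must match, channel count is unchanged."""
    if skip.shape != up.shape:
        raise ValueError(f"shape mismatch: {skip.shape} vs {up.shape}")
    return skip + up


def aggregate_concat(skip: np.ndarray, up: np.ndarray) -> np.ndarray:
    """Concatenation skip connection: channels (axis 1 in NCHW) are stacked."""
    return np.concatenate([skip, up], axis=1)


# Two NCHW feature maps with 64 channels at 32x32 resolution.
skip = np.ones((1, 64, 32, 32), dtype=np.float32)
up = np.full((1, 64, 32, 32), 2.0, dtype=np.float32)

print(aggregate_add(skip, up).shape)     # (1, 64, 32, 32)
print(aggregate_concat(skip, up).shape)  # (1, 128, 32, 32)
```

With addition, the convolution that follows the merge sees the same number of channels as each branch, which keeps the decoder narrower; concatenation doubles the channels and lets the following layer learn separate weights for the skip and up-sampled features.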
- Parameters
num_input_channels (int) – Number of channels in input images.
num_output_channels (int) – Number of channels in output images.
encoder (str) – Name of the encoder. Currently supports:
- "resnet50": The well-known ResNet50; this is not the pre-activation model.
- "unet": The vanilla UNet encoder, where each down-sampling level contains 2 blocks of Convolution-BatchNorm-ReLU.
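decoder_block (list) – A list of the convolution layers in each up-sampling block of the decoder. Each item is an integer denoting the layer's kernel size; for example, the default (3, 3) gives two 3x3 convolutions per block.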
classifier (list) – A list of convolution layers before the final 1x1 convolution. Each item is an integer denoting the layer's kernel size. The default is None, in which case the classifier contains only the 1x1 convolution.
- Returns
A PyTorch model.
- Return type
model (torch.nn.Module)
Examples
>>> from tiatoolbox.models.architecture.unet import UNetModel
>>> # instantiate a UNet with a resnet50 encoder and
>>> # only one 3x3 convolution per up-sampling block in the decoder
>>> model = UNetModel(
...     2, 2,
...     encoder="resnet50",
...     decoder_block=(3,),
... )
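The following pure-Python sketch spells out how the decoder_block and classifier parameters, as documented above, translate into kernel sizes. It assumes the same decoder_block is reused at every up-sampling level; `decoder_layer_plan`, `classifier_plan`, and `num_levels` are hypothetical names for illustration and are not part of tiatoolbox.

```python
from typing import List, Optional, Sequence


def decoder_layer_plan(decoder_block: Sequence[int], num_levels: int) -> List[List[int]]:
    """Kernel sizes of the convolutions in each up-sampling block.

    Hypothetical helper: ``decoder_block`` is assumed to be reused at each of
    ``num_levels`` up-sampling levels.
    """
    return [list(decoder_block) for _ in range(num_levels)]


def classifier_plan(classifier: Optional[Sequence[int]]) -> List[int]:
    """Kernel sizes of the classifier head, always ending in the final 1x1 conv."""
    layers = list(classifier) if classifier is not None else []
    return layers + [1]


print(decoder_layer_plan((3,), num_levels=4))    # [[3], [3], [3], [3]]
print(decoder_layer_plan((3, 3), num_levels=2))  # [[3, 3], [3, 3]]
print(classifier_plan(None))                     # [1]
print(classifier_plan([3, 3]))                   # [3, 3, 1]
```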
Methods
forward(imgs, *args, **kwargs) – Logic for using the layers defined in __init__.
infer_batch(model, batch_data, on_gpu) – Run inference on an input batch.
- forward(imgs, *args, **kwargs)
Logic for using the layers defined in __init__.
This method defines how the layers are used in the forward pass.
- Parameters
imgs (torch.Tensor) – Input images as a tensor of shape NCHW.
- Returns
The inference output, a tensor of shape NCHW. Its height and width may differ from those of the input images.
- Return type
output (torch.Tensor)
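Because the output height and width may differ from the input, targets usually have to be aligned with the prediction before computing a loss or metric. A minimal NumPy sketch, assuming the output is smaller than the input and centred on it (an assumption, not something this page guarantees); `center_crop` is a hypothetical helper, not a tiatoolbox function.

```python
import numpy as np


def center_crop(x: np.ndarray, height: int, width: int) -> np.ndarray:
    """Crop the central ``height`` x ``width`` region of an NCHW array."""
    _, _, h, w = x.shape
    if height > h or width > w:
        raise ValueError("crop size must not exceed the input size")
    top = (h - height) // 2
    left = (w - width) // 2
    return x[:, :, top : top + height, left : left + width]


labels = np.arange(1 * 1 * 8 * 8).reshape(1, 1, 8, 8)
# Suppose the network returned a 1x2x4x4 output for this 8x8 input.
output_shape = (1, 2, 4, 4)
cropped = center_crop(labels, *output_shape[2:])
print(cropped.shape)     # (1, 1, 4, 4)
print(cropped[0, 0, 0])  # [18 19 20 21]
```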
- static infer_batch(model, batch_data, on_gpu)
Run inference on an input batch.
This contains the logic for the forward pass as well as input/output aggregation.
- Parameters
model (nn.Module) – The PyTorch model to run.
batch_data (ndarray) – A batch of data generated by torch.utils.data.DataLoader.
on_gpu (bool) – Whether to run inference on a GPU.
- Returns
A list of network output heads; each output is an ndarray.
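A rough NumPy sketch of the kind of input/output aggregation described above, under stated assumptions: the batch arrives as NHWC image patches, the model maps NCHW arrays to NCHW logits, and the logits are turned into per-pixel probabilities with a channel softmax. `infer_batch_sketch`, `softmax`, and `toy_model` are hypothetical stand-ins; the real method works on torch tensors and uses on_gpu for device placement, which this sketch leaves out.

```python
from typing import Callable, List

import numpy as np


def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable softmax along ``axis``."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def infer_batch_sketch(
    model: Callable[[np.ndarray], np.ndarray], batch_data: np.ndarray
) -> List[np.ndarray]:
    """Hypothetical stand-in for infer_batch, without torch or device handling."""
    imgs = np.transpose(batch_data.astype(np.float32), (0, 3, 1, 2))  # NHWC -> NCHW
    logits = model(imgs)
    probs = softmax(logits, axis=1)  # assumed: per-pixel class probabilities
    probs = np.transpose(probs, (0, 2, 3, 1))  # NCHW -> NHWC
    return [probs]  # one entry per output head


def toy_model(x: np.ndarray) -> np.ndarray:
    """Toy 'network': two output channels derived from the input channel mean."""
    m = x.mean(axis=1, keepdims=True)
    return np.concatenate([m, -m], axis=1)


batch = np.random.default_rng(0).random((2, 16, 16, 3))
outputs = infer_batch_sketch(toy_model, batch)
print(len(outputs), outputs[0].shape)  # 1 (2, 16, 16, 2)
```

Returning a list, even for a single head, matches the documented return value and lets callers handle multi-head models the same way.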