UNetModel
tiatoolbox.models.architecture.unet.UNetModel
- class UNetModel(num_input_channels=2, num_output_channels=2, encoder='resnet50', encoder_levels=None, decoder_block=None, skip_type='add')
Generate a family of UNet models.
Different encoders are supported. The decoder, however, is relatively simple: each up-sampling block contains a number of vanilla convolution layers that are not customizable beyond the kernel sizes given in decoder_block. By default, feature maps from the down-sampling and up-sampling paths are aggregated by addition rather than concatenation; see skip_type.
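The two aggregation modes selected by skip_type differ in the channel arithmetic at each skip connection. The sketch below is illustrative only and is not taken from the tiatoolbox source: merged_channels is a hypothetical helper. It assumes that element-wise addition requires matching channel counts, while concatenation stacks the two feature maps along the channel axis.

```python
def merged_channels(encoder_ch: int, decoder_ch: int, skip_type: str = "add") -> int:
    """Return the channel count after merging a skip connection.

    Hypothetical helper for illustration; not part of tiatoolbox.
    """
    if skip_type == "add":
        # Element-wise addition needs identical channel counts.
        if encoder_ch != decoder_ch:
            raise ValueError(
                f"'add' needs equal channels, got {encoder_ch} and {decoder_ch}"
            )
        return encoder_ch
    if skip_type == "concat":
        # Concatenation stacks both maps along the channel axis.
        return encoder_ch + decoder_ch
    raise ValueError(f"Unknown skip_type: {skip_type!r}")


print(merged_channels(64, 64, "add"))     # 64
print(merged_channels(64, 64, "concat"))  # 128
```

With concatenation, the first convolution after the merge sees more input channels, which is why the choice affects the decoder's parameter count.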
- Parameters:
num_input_channels (int) – Number of channels in input images.
num_output_channels (int) – Number of channels in output images.
encoder (str) –
- Name of the encoder. Currently supported:
“resnet50”: The well-known ResNet50; this is not the pre-activation model.
“unet”: The vanilla UNet encoder, where each down-sampling level contains 2 blocks of Convolution-BatchNorm-ReLU.
encoder_levels (list) – A list of integers that configures the “unet” encoder levels. Each number defines the number of output channels at one down-sampling level (2 convolutions). The number of integers defines the number of down-sampling levels in the unet encoder. This is only applicable when encoder=“unet”.
decoder_block (list) – A list describing the convolution layers in each up-sampling block. Each item is an integer that gives the kernel size of one layer.
skip_type (str) – Method used to combine feature maps from the encoder and decoder at skip connections, either “add” or “concat”. Default is “add”.
- Returns:
A PyTorch model.
- Return type:
torch.nn.Module
Examples
>>> from tiatoolbox.models.architecture.unet import UNetModel
>>> # instantiate a UNet with a resnet50 encoder and
>>> # only one 3x3 convolution per up-sampling block in the decoder
>>> model = UNetModel(
...     2, 2,
...     encoder="resnet50",
...     decoder_block=(3,),
... )
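The constraints listed under Parameters can be checked before a model is built. The following is a minimal sketch, not tiatoolbox code: check_unet_config is a hypothetical helper, it is deliberately stricter than the library may be (it rejects encoder_levels for any encoder other than “unet”), and the rule that every entry must be a positive integer is an assumption.

```python
SUPPORTED_ENCODERS = ("resnet50", "unet")
SUPPORTED_SKIP_TYPES = ("add", "concat")


def check_unet_config(encoder="resnet50", encoder_levels=None, skip_type="add"):
    """Validate arguments against the documented options.

    Hypothetical helper for illustration; not part of tiatoolbox.
    Returns the number of down-sampling levels implied by
    ``encoder_levels``, or None when the list is not given.
    """
    if encoder not in SUPPORTED_ENCODERS:
        raise ValueError(
            f"encoder must be one of {SUPPORTED_ENCODERS}, got {encoder!r}"
        )
    if skip_type not in SUPPORTED_SKIP_TYPES:
        raise ValueError(
            f"skip_type must be one of {SUPPORTED_SKIP_TYPES}, got {skip_type!r}"
        )
    if encoder_levels is None:
        return None
    if encoder != "unet":
        # Documented as only applicable to the "unet" encoder.
        raise ValueError('encoder_levels only applies when encoder="unet"')
    # Assumption: channel counts must be positive integers.
    if not all(isinstance(c, int) and c > 0 for c in encoder_levels):
        raise ValueError("encoder_levels must contain positive integers")
    # One list entry per down-sampling level.
    return len(encoder_levels)


print(check_unet_config("unet", [32, 64, 128], "concat"))  # 3
```

A check like this fails early with a readable message instead of surfacing later as a shape or key error during model construction.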
Initialize UNetModel.
Methods
forward(imgs, *args, **kwargs) – Logic for using layers defined in init.
infer_batch(model, batch_data, *, device) – Run inference on an input batch.
Attributes
training
- forward(imgs, *args, **kwargs)
Logic for using layers defined in init.
This method defines how layers are used in forward operation.
- Parameters:
imgs (torch.Tensor) – Input images. The tensor has shape NCHW.
args (list) – List of input arguments. Not used here. Provided for consistency with the API.
kwargs (dict) – Keyword arguments. Not used here. Provided for consistency with the API.
self (UNetModel)
- Returns:
The inference output. The tensor is of the shape NCHW. However, height and width may not be the same as the input images.
- Return type:
torch.Tensor
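Because the output height and width may differ from the input, downstream code sometimes needs to align the two. The sketch below rests on an assumption that this documentation does not state: that a smaller output corresponds to the center of the input. center_crop_offsets is a hypothetical helper that operates on NCHW shape tuples only and does not touch any tensor data.

```python
def center_crop_offsets(input_shape, output_shape):
    """Return (top, left) offsets that align an NCHW output with its input.

    Hypothetical helper for illustration; assumes the output covers the
    central region of the input.
    """
    _, _, in_h, in_w = input_shape
    _, _, out_h, out_w = output_shape
    if out_h > in_h or out_w > in_w:
        raise ValueError("Output is larger than input; cannot center-crop.")
    return (in_h - out_h) // 2, (in_w - out_w) // 2


top, left = center_crop_offsets((4, 3, 256, 256), (4, 2, 164, 164))
print(top, left)  # 46 46
```

The offsets can then be used to slice the input or a ground-truth mask so that it matches the prediction pixel for pixel.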
- static infer_batch(model, batch_data, *, device)
Run inference on an input batch.
This contains logic for forward operation as well as i/o aggregation.
- Parameters:
model (nn.Module) – PyTorch defined model.
batch_data (torch.Tensor) – A batch of data generated by torch.utils.data.DataLoader.
device (str) – Transfers model to the specified device. Default is “cpu”.
- Returns:
List of network output heads; each output is a numpy.ndarray.
- Return type:
list