ResNetEncoder
tiatoolbox.models.architecture.unet.ResNetEncoder
- class ResNetEncoder(block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None)
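A subclass of ResNet defined in torchvision.
This class overrides torchvision's ResNet forward implementation so that it returns the feature maps from each downsampling level instead of classification logits. These multi-scale features are required for segmentation, for example by a U-Net decoder.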
Attributes
training
- Parameters:
- static resnet(num_input_channels, downsampling_levels)
Shortcut method to create a customised ResNetEncoder.
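- Parameters:
num_input_channels (int): Number of channels in the input images.
downsampling_levels (list of int): Number of residual blocks at each downsampling level, e.g. [3, 4, 6, 3] for a ResNet-50 layout.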
- Returns:
A PyTorch model.
- Return type:
torch.nn.Module
Examples
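>>> from tiatoolbox.models.architecture.unet import ResNetEncoder
>>> # instantiate a ResNet-50-style encoder for RGB input
>>> num_input_channels = 3
>>> ResNetEncoder.resnet(
...     num_input_channels,
...     [3, 4, 6, 3],
... )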