ResNetEncoder
tiatoolbox.models.architecture.unet.ResNetEncoder
- class ResNetEncoder(block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None)
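A subclass of the ResNet class defined in torchvision.
This class overrides the forward method of the PyTorch implementation so that it returns the features from every downsampling level instead of the classification output. Segmentation networks need these multi-scale features for their decoder.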
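To make "features of each downsampling level" concrete, the following minimal, dependency-free sketch computes the shape of each feature map for a given input size. It assumes torchvision's standard ResNet layout (7x7 stride-2 stem convolution, 3x3 stride-2 max-pool, four stages with strides 1, 2, 2, 2) with Bottleneck blocks, and it assumes the encoder returns five levels: the stem output followed by layer1 to layer4. The helpers conv_out and encoder_feature_shapes are hypothetical and not part of tiatoolbox.

```python
from __future__ import annotations


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a convolution or pooling window (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_feature_shapes(height: int, width: int) -> list[tuple[int, int, int]]:
    """Return (channels, height, width) for each assumed downsampling level."""
    # Level 0 (assumption): 7x7 stem convolution, stride 2, padding 3, 64 channels.
    h, w = conv_out(height, 7, 2, 3), conv_out(width, 7, 2, 3)
    shapes = [(64, h, w)]
    # 3x3 max-pool, stride 2, padding 1, applied before layer1.
    h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
    # Levels 1-4: Bottleneck stages with base widths 64..512, expanded x4.
    # layer1 keeps the resolution; layer2-layer4 halve it with a stride-2 3x3 conv.
    for stage, base_width in enumerate((64, 128, 256, 512)):
        if stage > 0:
            h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
        shapes.append((base_width * 4, h, w))
    return shapes


# Print the per-level shapes for a 256x256 input tile.
for level, shape in enumerate(encoder_feature_shapes(256, 256)):
    print(f"level {level}: {shape}")
```

For a 256x256 input this gives 64 channels at 128x128, then 256, 512, 1024 and 2048 channels at 64, 32, 16 and 8 pixels per side, so the deepest feature map is 1/32 of the input resolution. The number of blocks per stage does not change these shapes, because only the first block of each stage downsamples.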
Methods
- resnet: Shortcut method to create a customised ResNet.
- resnet50: Shortcut method to create a ResNet50.
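- Parameters
Identical to those of the parent torchvision ResNet class: block, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation and norm_layer.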
- Return type
None
- static resnet(num_input_channels, downsampling_levels)
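Shortcut method to create a customised ResNet.
- Parameters
num_input_channels (int): number of channels in the input images.
downsampling_levels (list of int): number of Bottleneck blocks at each downsampling level, e.g. [3, 4, 6, 3] for a ResNet50.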
- Returns
A PyTorch model.
- Return type
torch.nn.Module
Examples
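>>> from tiatoolbox.models.architecture.unet import ResNetEncoder
>>> # instantiate a ResNet50-style encoder for 3-channel (RGB) images
>>> num_input_channels = 3
>>> model = ResNetEncoder.resnet(
...     num_input_channels,
...     [3, 4, 6, 3],
... )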
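The downsampling_levels list sets how many Bottleneck blocks each of the four stages contains, and this is what distinguishes the standard deep ResNet variants. The dependency-free sketch below, built around a hypothetical resnet_depth helper (not part of tiatoolbox), shows why [3, 4, 6, 3] corresponds to a ResNet50. It assumes three convolutions per Bottleneck block; the conventional depth counts those convolutions plus the stem convolution and the final fully connected layer of the original classification network. Requiring exactly four positive entries is an assumption of this sketch, chosen to match the four stages of torchvision's ResNet.

```python
from __future__ import annotations

# Assumption: a Bottleneck block holds three convolutions (1x1 reduce, 3x3, 1x1 expand).
CONVS_PER_BOTTLENECK = 3


def resnet_depth(downsampling_levels: list[int]) -> int:
    """Return the conventional layer count of a Bottleneck ResNet (hypothetical helper)."""
    if len(downsampling_levels) != 4 or any(n < 1 for n in downsampling_levels):
        raise ValueError("expected four positive block counts, one per stage")
    # 1 stem convolution + 3 convolutions per block + 1 fully connected layer.
    return 1 + CONVS_PER_BOTTLENECK * sum(downsampling_levels) + 1


# Standard block layouts for the deep ResNet variants.
variants = {
    "ResNet50": [3, 4, 6, 3],
    "ResNet101": [3, 4, 23, 3],
    "ResNet152": [3, 8, 36, 3],
}
for name, levels in variants.items():
    print(name, resnet_depth(levels))
```

Passing one of these layouts as downsampling_levels therefore selects the matching backbone depth, while num_input_channels only affects the stem convolution.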