get_pretrained_model
- get_pretrained_model(pretrained_model=None, pretrained_weights=None, *, overwrite=False)
Load a predefined PyTorch model with the appropriate pretrained weights.
- Parameters:
pretrained_model (str) –
Name of an existing model supported by tiatoolbox for processing the data. The currently supported models are:
alexnet
resnet18
resnet34
resnet50
resnet101
resnext50_32x4d
resnext101_32x8d
wide_resnet50_2
wide_resnet101_2
densenet121
densenet161
densenet169
densenet201
mobilenet_v2
mobilenet_v3_large
mobilenet_v3_small
googlenet
Each model has been trained on the Kather100K and PCam datasets. The format of pretrained_model is <model_name>-<dataset_name>. For example, to use a resnet18 model trained on Kather100K, use resnet18-kather100k, and to use an alexnet model trained on PCam, use alexnet-pcam.
By default, the corresponding pretrained weights are also downloaded. However, you can override them with your own set of weights via the pretrained_weights argument. This argument is case-insensitive.
pretrained_weights (str) – Path to the weights for the corresponding pretrained_model.
overwrite (bool) – Whether to always overwrite previously downloaded weights.
- Return type:
Examples
>>> # get mobilenet pretrained on Kather100K dataset by the TIA team
>>> model = get_pretrained_model(pretrained_model='mobilenet_v2-kather100k')
>>> # get mobilenet defined by TIA team, but loaded with user defined weights
>>> model = get_pretrained_model(
...     pretrained_model='mobilenet_v2-kather100k',
...     pretrained_weights='/A/B/C/my_weights.tar',
... )
>>> # get resnet34 pretrained on PCam dataset by TIA team
>>> model = get_pretrained_model(pretrained_model='resnet34-pcam')
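The <model_name>-<dataset_name> format described under pretrained_model can be composed programmatically. The sketch below is illustrative only: build_pretrained_name, SUPPORTED_MODELS and SUPPORTED_DATASETS are hypothetical names that are not part of tiatoolbox, and lowercasing the result is an assumption made here to obtain one canonical spelling.

```python
# Hypothetical helper, NOT part of tiatoolbox: composes the
# "<model_name>-<dataset_name>" string expected by pretrained_model.

# Model names copied from the list in the parameter description.
SUPPORTED_MODELS = frozenset({
    "alexnet",
    "resnet18", "resnet34", "resnet50", "resnet101",
    "resnext50_32x4d", "resnext101_32x8d",
    "wide_resnet50_2", "wide_resnet101_2",
    "densenet121", "densenet161", "densenet169", "densenet201",
    "mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small",
    "googlenet",
})

# Dataset suffixes taken from the examples (kather100k, pcam).
SUPPORTED_DATASETS = frozenset({"kather100k", "pcam"})


def build_pretrained_name(model_name: str, dataset_name: str) -> str:
    """Return '<model_name>-<dataset_name>' in lowercase.

    Raises ValueError if either part is not in the lists above.
    """
    # Assumption: normalise to lowercase so one spelling is always produced.
    model = model_name.strip().lower()
    dataset = dataset_name.strip().lower()
    if model not in SUPPORTED_MODELS:
        raise ValueError(f"Unsupported model name: {model_name!r}")
    if dataset not in SUPPORTED_DATASETS:
        raise ValueError(f"Unsupported dataset name: {dataset_name!r}")
    return f"{model}-{dataset}"


if __name__ == "__main__":
    # The resulting string could then be passed as pretrained_model.
    print(build_pretrained_name("ResNet18", "Kather100K"))
```

Validating the name before calling get_pretrained_model is a design choice of this sketch: it surfaces a misspelled model or dataset before any weights are downloaded.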