get_pretrained_model

get_pretrained_model(pretrained_model=None, pretrained_weights=None, *, overwrite=False)

Load a predefined PyTorch model with the appropriate pretrained weights.

Parameters:
  • pretrained_model (str) –

    Name of an existing model supported by tiatoolbox for processing the data. The models currently supported are:

    • alexnet

    • resnet18

    • resnet34

    • resnet50

    • resnet101

    • resnext50_32x4d

    • resnext101_32x8d

    • wide_resnet50_2

    • wide_resnet101_2

    • densenet121

    • densenet161

    • densenet169

    • densenet201

    • mobilenet_v2

    • mobilenet_v3_large

    • mobilenet_v3_small

    • googlenet

    Each model is available with weights trained on the Kather100K dataset and, separately, on the PCam dataset. The format of pretrained_model is <model_name>-<dataset_name>. For example, to use a resnet18 model trained on Kather100K, use resnet18-kather100k; to use an alexnet model trained on PCam, use alexnet-pcam.

    By default, the corresponding pretrained weights are also downloaded. You can instead supply your own weights via the pretrained_weights argument. The pretrained_model argument is case-insensitive.

  • pretrained_weights (str) – Path to the weights file for the corresponding pretrained_model.

  • overwrite (bool) – Whether to always overwrite previously downloaded weights.
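The naming convention and the weight-loading precedence described above can be sketched in plain Python. This is a minimal illustration under stated assumptions: the helpers make_pretrained_model_name, split_pretrained_model_name, and choose_weights_source, and the cached_weights path they take, are hypothetical and are not part of the tiatoolbox API. tiatoolbox's actual download and caching logic may differ; the sketch only encodes the rules stated in this documentation (lower-cased <model_name>-<dataset_name>, user weights overriding the defaults, overwrite forcing a fresh download).

```python
from pathlib import Path
from typing import Optional, Tuple

# Architectures listed in the parameter description (hypothetical constant).
SUPPORTED_ARCHITECTURES = frozenset({
    "alexnet", "resnet18", "resnet34", "resnet50", "resnet101",
    "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2",
    "wide_resnet101_2", "densenet121", "densenet161", "densenet169",
    "densenet201", "mobilenet_v2", "mobilenet_v3_large",
    "mobilenet_v3_small", "googlenet",
})

# Datasets named in the documentation, lower-cased since the argument is case-insensitive.
SUPPORTED_DATASETS = frozenset({"kather100k", "pcam"})


def make_pretrained_model_name(model_name: str, dataset_name: str) -> str:
    """Build a lower-cased, validated '<model_name>-<dataset_name>' string (hypothetical helper)."""
    model = model_name.strip().lower()
    dataset = dataset_name.strip().lower()
    if model not in SUPPORTED_ARCHITECTURES:
        raise ValueError(f"Unsupported architecture: {model_name!r}")
    if dataset not in SUPPORTED_DATASETS:
        raise ValueError(f"Unsupported dataset: {dataset_name!r}")
    return f"{model}-{dataset}"


def split_pretrained_model_name(pretrained_model: str) -> Tuple[str, str]:
    """Split e.g. 'ResNet18-Kather100K' into ('resnet18', 'kather100k') (hypothetical helper)."""
    # rpartition splits on the last '-'; the listed architecture names use '_' instead of '-'.
    model, sep, dataset = pretrained_model.strip().lower().rpartition("-")
    if not sep:
        raise ValueError(
            f"Expected '<model_name>-<dataset_name>', got {pretrained_model!r}"
        )
    make_pretrained_model_name(model, dataset)  # raises ValueError if either part is unknown
    return model, dataset


def choose_weights_source(
    pretrained_weights: Optional[str],
    cached_weights: Path,
    overwrite: bool = False,
) -> Tuple[str, Path]:
    """Return ('user' | 'download' | 'cache', path) per the documented precedence (hypothetical)."""
    # User-supplied weights take priority over the default pretrained weights.
    if pretrained_weights is not None:
        return "user", Path(pretrained_weights)
    # overwrite=True forces a fresh download even when a cached copy already exists.
    if overwrite or not cached_weights.exists():
        return "download", cached_weights
    return "cache", cached_weights
```

For example, make_pretrained_model_name("ResNet18", "Kather100K") yields "resnet18-kather100k", the same string accepted by get_pretrained_model. Validating against an explicit set keeps typos such as "resnet-18" from silently reaching the download step.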

Return type:

tuple[torch.nn.Module, IOConfigABC]

Examples

>>> from tiatoolbox.models.architecture import get_pretrained_model
>>> # get mobilenet_v2 pretrained on the Kather100K dataset by the TIA team
>>> model, ioconfig = get_pretrained_model(pretrained_model='mobilenet_v2-kather100k')
>>> # get mobilenet_v2 as defined by the TIA team, but loaded with user-defined weights
>>> model, ioconfig = get_pretrained_model(
...     pretrained_model='mobilenet_v2-kather100k',
...     pretrained_weights='/A/B/C/my_weights.tar',
... )
>>> # get resnet34 pretrained on the PCam dataset by the TIA team
>>> model, ioconfig = get_pretrained_model(pretrained_model='resnet34-pcam')