# Source code for saliency_metrics.models.build_classifier

import warnings
from copy import deepcopy
from typing import Dict, Optional

import timm
import torch
import torch.nn as nn
from mmcv import Registry
from torchvision import models

# Public API of this module: the entry-point builder plus the three scope registries it dispatches to.
__all__ = ["build_classifier", "TIMM_CLASSIFIERS", "TORCHVISION_CLASSIFIERS", "CUSTOM_CLASSIFIERS"]


def _preprocess_cfg(cfg: Dict, default_args: Optional[Dict] = None) -> Dict:
    """Override the `cfg` with `default_args`."""
    cfg = deepcopy(cfg)
    if default_args is not None:
        for name, value in default_args.items():
            cfg.setdefault(name, value)
    return cfg


def _build_timm_classifier(registry: Registry, cfg: Dict, default_args: Optional[Dict] = None) -> nn.Module:
    """Instantiate a model from the ``timm`` library as described by ``cfg``.

    The ``type`` entry is the ``timm`` model name; all other entries (after merging in
    ``default_args``) become keyword arguments of ``timm.create_model``.

    Args:
        registry: The registry that dispatched this call; not used here.
        cfg: Config dict containing at least ``type``.
        default_args: Defaults for keys missing from ``cfg``.

    Returns:
        The model returned by ``timm.create_model``.
    """
    merged_kwargs = _preprocess_cfg(cfg, default_args=default_args)
    arch_name = merged_kwargs.pop("type")
    model = timm.create_model(arch_name, **merged_kwargs)
    return model


def _build_torchvision_classifier(registry: Registry, cfg: Dict, default_args: Optional[Dict] = None) -> nn.Module:
    """Instantiate a model from ``torchvision.models`` as described by ``cfg``.

    The ``type`` entry names a builder function in ``torchvision.models`` (e.g. ``resnet18``). The optional
    ``checkpoint_path`` entry is consumed here; every other entry (after merging in ``default_args``) is
    forwarded to the builder as a keyword argument. When ``checkpoint_path`` is given, the loaded object is
    applied with ``load_state_dict(..., strict=False)``. Any missing or unexpected keys are reported via
    :func:`warnings.warn` instead of being silently ignored.

    Args:
        registry: The registry that dispatched this call; not used here.
        cfg: Config dict containing at least ``type``.
        default_args: Defaults for keys missing from ``cfg``.

    Returns:
        The constructed (and optionally checkpoint-initialized) model.
    """
    cfg = _preprocess_cfg(cfg, default_args=default_args)
    model_name = cfg.pop("type")
    ckpt_path = cfg.pop("checkpoint_path", None)
    _builder = getattr(models, model_name)
    model = _builder(**cfg)

    if ckpt_path is not None:
        # NOTE(review): ``torch.load`` unpickles the file; only load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        # ``strict=False`` tolerates partial checkpoints (e.g. a different classification head). However, a
        # wrapped or prefixed state dict would then load nothing at all, so surface the mismatch.
        result = model.load_state_dict(ckpt, strict=False)
        missing, unexpected = result.missing_keys, result.unexpected_keys
        if missing or unexpected:
            warnings.warn(
                f"Checkpoint '{ckpt_path}' does not fully match torchvision model '{model_name}': "
                f"{len(missing)} missing key(s) (e.g. {missing[:5]}), "
                f"{len(unexpected)} unexpected key(s) (e.g. {unexpected[:5]})."
            )
    return model


# One registry per supported scope. No modules are registered in the ``timm`` and ``torchvision``
# registries in this file; their custom ``build_func`` looks the model up by name in the library itself.
TIMM_CLASSIFIERS = Registry(name="timm_classifiers", build_func=_build_timm_classifier, scope="timm")
TORCHVISION_CLASSIFIERS = Registry(
    name="torchvision_classifiers",
    build_func=_build_torchvision_classifier,
    scope="torchvision",
)
# User-defined classifiers are registered here (e.g. with ``CUSTOM_CLASSIFIERS.register_module()``);
# no ``build_func`` is given, so the registry's default build function is used.
CUSTOM_CLASSIFIERS = Registry(name="custom_classifiers", scope="custom")


def build_classifier(cfg: Dict, default_args: Optional[Dict] = None) -> nn.Module:
    """Build a classifier.

    This function supports building a classifier from timm_, torchvision_ or **custom-defined** models.
    ``cfg`` must contain the field "type", which should have the format ``<scope>.<model_name>``.
    Specifically, ``scope`` can be one of ``timm``, ``torchvision``, or ``custom``. When building a
    custom-defined classifier, the model should already be registered in the
    ``saliency_metrics.models.CUSTOM_CLASSIFIERS`` registry. When building a classifier from
    ``torchvision`` or ``timm``, ``model_name`` should be the name of the corresponding builder function,
    e.g., ``resnet18``. I.e., ``build_classifier(cfg=dict(type="torchvision.resnet18"))`` is equivalent to
    calling ``torchvision.models.resnet18()``.

    .. _timm: https://rwightman.github.io/pytorch-image-models/
    .. _torchvision: https://pytorch.org/vision/stable/index.html

    Args:
        cfg: A config dict that contains the arguments for building a classifier. It should at least
            contain the field "type". The dict passed in is not modified.
        default_args: Other default arguments, forwarded to the build function of the selected registry.

    Returns:
        The classifier.

    Raises:
        ValueError: If "type" is missing, lacks a ``<scope>.`` prefix, has an empty scope, or names an
            unknown scope.
        TypeError: If "type" is not a string.

    Examples:
        Build a `torchvision` classifier:

        .. code-block:: python

            from torchvision.models.resnet import ResNet
            from saliency_metrics.models import build_classifier

            cfg_1 = dict(type="torchvision.resnet18", num_classes=2, pretrained=False)
            model = build_classifier(cfg_1)
            assert isinstance(model, ResNet)

        Build a `timm` classifier:

        .. code-block:: python

            from timm.models.efficientnet import EfficientNet
            from saliency_metrics.models import build_classifier

            cfg_2 = dict(type="timm.efficientnet_b0", num_classes=2)
            model = build_classifier(cfg_2)
            assert isinstance(model, EfficientNet)

        Build a `custom` classifier:

        .. code-block:: python

            import torch
            import torch.nn as nn
            from saliency_metrics.models import CUSTOM_CLASSIFIERS, build_classifier

            # First register the class
            @CUSTOM_CLASSIFIERS.register_module()
            class MLP(nn.Module):

                def __init__(self, hidden_size: int = 10) -> None:
                    super().__init__()
                    self.layers = nn.Sequential(
                        nn.Linear(10, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 2))

                def forward(self, x: torch.Tensor) -> torch.Tensor:
                    return self.layers(x)

            # Now the instance of the class can be built from a config
            cfg_3 = dict(type="custom.MLP")
            model = build_classifier(cfg_3, default_args=dict(hidden_size=5))
            assert isinstance(model, MLP)
    """
    cfg = deepcopy(cfg)
    if "type" not in cfg:
        raise ValueError("Key 'type' must be contained in the config.")
    full_type = cfg["type"]
    # ``split_scope_key`` only works on strings; fail with a clear message instead of an AttributeError.
    if not isinstance(full_type, str):
        raise TypeError(f"type must be a string in format <scope>.<model_name>, but got {type(full_type)}.")

    # The timm and torchvision registries are actually empty. Using the default ``build_from_cfg`` on them
    # would raise a KeyError, so resolve the scope here and call the matching registry's build function.
    scope, model_name = Registry.split_scope_key(full_type)
    if scope is None:
        raise ValueError(f"type must be in format <scope>.<model_name>, but got {full_type}.")
    if scope == "":
        raise ValueError("scope must not be an empty string.")
    cfg["type"] = model_name

    registries = {
        "timm": TIMM_CLASSIFIERS,
        "torchvision": TORCHVISION_CLASSIFIERS,
        "custom": CUSTOM_CLASSIFIERS,
    }
    registry = registries.get(scope)
    if registry is None:
        raise ValueError(f"Invalid scope name, should be one of 'timm', 'torchvision', 'custom', but got {scope}.")
    return registry.build(cfg=cfg, default_args=default_args)