Models

build_classifier

Build a classifier.

freeze_module

Freeze a specific module of the model.

get_module

Get a specific layer in a model.

saliency_metrics.models.build_classifier(cfg, default_args=None)

Build a classifier.

This function builds a classifier from timm, torchvision, or a custom-defined model. cfg must contain the field “type”, in the format <scope>.<model_name>, where scope is one of timm, torchvision, or custom. A custom-defined model must already be registered in the saliency_metrics.models.CUSTOM_CLASSIFIERS registry. For torchvision or timm, model_name is the name of the corresponding builder function, e.g., resnet18. For example, build_classifier(cfg=dict(type="torchvision.resnet18")) is equivalent to calling torchvision.models.resnet18.
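The <scope>.<model_name> format described above can be illustrated with a small parser. This is only a sketch: split_type and KNOWN_SCOPES are hypothetical names that are not part of saliency_metrics, and the library's own validation may differ.

```python
from typing import Tuple

# Scopes named in the description above.
KNOWN_SCOPES = ("timm", "torchvision", "custom")


def split_type(type_str: str) -> Tuple[str, str]:
    """Hypothetical helper: split "<scope>.<model_name>" at the first dot."""
    scope, sep, model_name = type_str.partition(".")
    if not sep or not model_name:
        raise ValueError(f"expected '<scope>.<model_name>', got {type_str!r}")
    if scope not in KNOWN_SCOPES:
        raise ValueError(f"unknown scope {scope!r}; expected one of {KNOWN_SCOPES}")
    return scope, model_name


scope, model_name = split_type("torchvision.resnet18")
print(scope, model_name)
```

Splitting at the first dot only is a choice made for this sketch; it leaves any further dots inside the model name untouched.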

Parameters
  • cfg (Dict) – A config dict that contains the arguments for building a classifier. It should at least contain the field “type”.

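  • default_args (Optional[Dict]) – Default arguments supplied to the model builder in addition to those in cfg (see the custom classifier example, where hidden_size is passed this way).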

Return type

Module

Returns

The classifier.

Examples

Build a torchvision classifier:

from torchvision.models.resnet import ResNet
from saliency_metrics.models import build_classifier

cfg_1 = dict(type="torchvision.resnet18", num_classes=2, pretrained=False)
model = build_classifier(cfg_1)
assert isinstance(model, ResNet)

Build a timm classifier:

from timm.models.efficientnet import EfficientNet
from saliency_metrics.models import build_classifier

cfg_2 = dict(type="timm.efficientnet_b0", num_classes=2)
model = build_classifier(cfg_2)
assert isinstance(model, EfficientNet)

Build a custom classifier:

import torch
import torch.nn as nn
from saliency_metrics.models import CUSTOM_CLASSIFIERS, build_classifier

# First register the class
@CUSTOM_CLASSIFIERS.register_module()
class MLP(nn.Module):
    def __init__(self, hidden_size: int = 10) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(10, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

# Now the instance of the class can be built from a config
cfg_3 = dict(type="custom.MLP")
model = build_classifier(cfg_3, default_args=dict(hidden_size=5))
assert isinstance(model, MLP)

saliency_metrics.models.freeze_module(model, module=None, eval_mode=True)

Freeze a specific module of the model.

This function freezes a specific layer of the model by setting the requires_grad flag of its parameters to False. If eval_mode is True, it also puts the whole model into evaluation mode.
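The behavior described above can be approximated in plain PyTorch. The following is a minimal sketch, not the library's implementation: freeze_sketch is a hypothetical name, and looking the target up through named_modules() is an assumption about how the module name is resolved.

```python
from typing import Optional

import torch.nn as nn


def freeze_sketch(model: nn.Module, module: Optional[str] = None, eval_mode: bool = True) -> None:
    """Illustrative stand-in for freeze_module (hypothetical, not library code)."""
    # Use the whole model, or look the submodule up by its dotted name.
    # An unknown name raises KeyError in this sketch.
    target = model if module is None else dict(model.named_modules())[module]
    for param in target.parameters():
        param.requires_grad = False
    if eval_mode:
        # eval() is applied to the entire model, not only the frozen part.
        model.eval()


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
freeze_sketch(net, "2")  # freeze only the last linear layer
```

Note that the two effects have different scopes: only the named submodule stops receiving gradients, while evaluation mode applies to every layer.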

Parameters
  • model (Module) – Model to be processed.

  • module (Optional[str]) – The name of the target module. If None, the target module to be frozen is the entire model.

  • eval_mode (bool) – If True, turns the entire model into eval mode.

Return type

None

Returns

None

Examples

from saliency_metrics.models import build_classifier, freeze_module

model_1 = build_classifier(dict(type="timm.resnet18", num_classes=2))
freeze_module(model_1, "fc", eval_mode=True)
assert not model_1.training
assert not model_1.fc.weight.requires_grad

model_2 = build_classifier(dict(type="timm.resnet18", num_classes=2))
freeze_module(model_2, None)
for p in model_2.parameters():
    assert not p.requires_grad

saliency_metrics.models.get_module(model, module)

Get a specific layer in a model.

This function is adapted from TorchRay. The module argument is the name of a module, as given by the named_modules() method of torch.nn.Module objects. The function searches for a module with that name and returns the matching torch.nn.Module if found; otherwise, it returns None.
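The lookup described above amounts to a linear search over named_modules(). A minimal sketch in plain PyTorch follows; find_module is a hypothetical name used for illustration and is not the library's or TorchRay's code.

```python
from typing import Optional

import torch.nn as nn


def find_module(model: nn.Module, name: str) -> Optional[nn.Module]:
    """Illustrative stand-in for get_module (hypothetical, not library code)."""
    # named_modules() yields (dotted_name, module) pairs for every submodule.
    for candidate_name, candidate in model.named_modules():
        if candidate_name == name:
            return candidate
    return None


net = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)),
    nn.Linear(4, 2),
)
bn = find_module(net, "0.1")  # second layer of the first block
missing = find_module(net, "does.not.exist")  # no such module
```

Returning None instead of raising lets callers check for a missing layer explicitly before using the result.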

Parameters
  • model (Module) – Model in which to search for layer.

  • module (str) – Name of layer.

Return type

Optional[Module]

Returns

The requested nn.Module layer, or None if no layer with that name is found.

Examples

from saliency_metrics.models import build_classifier, get_module

cfg = dict(type="torchvision.resnet18", num_classes=2)
model = build_classifier(cfg)

# get the last block
_ = get_module(model, "layer4.1")
# get the last BN layer
_ = get_module(model, "layer4.1.bn2")