Source code for saliency_metrics.models.model_utils

from typing import Optional

import torch.nn as nn


def get_module(model: nn.Module, module: str) -> Optional[nn.Module]:
    r"""Look up a sub-module of ``model`` by its qualified name.

    This function is adapted from
    `TorchRay <https://github.com/facebookresearch/TorchRay>`_.

    :attr:`module` is compared against the names yielded by
    :func:`named_modules` of the given :class:`torch.nn.Module`. The first
    module whose name equals :attr:`module` is returned; if no name matches,
    the result is ``None``. The empty string selects ``model`` itself.

    Args:
        model: Model in which to search for the layer.
        module: Name of the layer.

    Returns:
        The matching ``nn.Module`` layer, or ``None`` if the layer isn't found.

    Raises:
        TypeError: If ``module`` is not a ``str``.

    Examples:
        .. code-block:: python

            from saliency_metrics.models import build_classifier, get_module

            cfg = dict(type="torchvision.resnet18", num_classes=2)
            model = build_classifier(cfg)
            # get the last block
            _ = get_module(model, "layer4.1")
            # get the last BN layer
            _ = get_module(model, "layer4.1.bn2")
    """
    if not isinstance(module, str):
        raise TypeError(f"module can only be a str, but got {module.__class__.__name__}")

    # An empty name refers to the root model.
    if not module:
        return model

    # Lazily scan the named sub-modules and stop at the first exact name match.
    matches = (layer for layer_name, layer in model.named_modules() if layer_name == module)
    return next(matches, None)
def freeze_module(model: nn.Module, module: Optional[str] = None, eval_mode: bool = True) -> None:
    """Freeze a specific module of the model.

    This function freezes a specific layer of the model by setting the
    ``requires_grad`` flag of its parameters to False. It also converts the
    whole model into evaluation_ mode, if ``eval_mode`` is True.

    The target module is resolved before anything is modified, so a call that
    fails because of a bad ``module`` argument leaves ``model`` untouched.

    .. _evaluation: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval

    Args:
        model: Model to be processed.
        module: The name of the target module. If None, the target module to
            be frozen is the entire model.
        eval_mode: If True, turns the **entire** model into `eval` mode.

    Returns:
        None

    Raises:
        TypeError: If ``module`` is neither None nor a str (raised by
            ``get_module``).
        ValueError: If ``model`` has no sub-module named ``module``.

    Examples:
        .. code-block:: python

            from saliency_metrics.models import build_classifier
            from saliency_metrics.models.model_utils import freeze_module

            model_1 = build_classifier(dict(type="timm.resnet18", num_classes=2))
            freeze_module(model_1, "fc", eval_mode=True)
            assert not model_1.training
            assert not model_1.fc.weight.requires_grad

            model_2 = build_classifier(dict(type="timm.resnet18", num_classes=2))
            freeze_module(model_2, None)
            for p in model_2.parameters():
                assert not p.requires_grad
    """
    # Look the target up first: get_module returns None for an unknown name,
    # and we want to fail loudly before mutating the model in any way.
    target_module = model if module is None else get_module(model, module)
    if target_module is None:
        raise ValueError(f"module {module!r} was not found in the model")

    if eval_mode:
        model.eval()

    for param in target_module.parameters():
        param.requires_grad = False