Source code for saliency_metrics.metrics.build_metric

from abc import abstractmethod
from typing import Any, Dict, Optional, Protocol, Union, runtime_checkable

from mmcv import Config, Registry
from numpy import ndarray
from torch import Tensor

from .serializable_result import SerializableResult

# Public API of this module.
__all__ = [
    "ReInferenceMetric",
    "ReTrainingMetric",
    "METRICS",
    "build_metric",
]


@runtime_checkable
class ReInferenceMetric(Protocol):
    """Protocol for re-inference based metrics.

    A metric implementing this protocol evaluates each sample at inference
    time: it first perturbs the input image according to the saliency map and
    then measures the degradation of the model's prediction. Per-sample
    results are folded into the cached dataset-level result via
    :meth:`update`, which :meth:`get_result` returns.
    """

    # Cached evaluation result for the whole dataset; returned by ``get_result``.
    _result: SerializableResult

    @abstractmethod
    def evaluate(
        self,
        img: Union[Tensor, ndarray],
        smap: Union[Tensor, ndarray],
        target: Union[Tensor, int],
        **kwargs: Any,
    ) -> Dict:
        """Evaluate the metric on a single sample.

        Args:
            img: The input image.
            smap: The saliency map for ``img``.
            target: The ground-truth target.
            kwargs: Other optional arguments, for example the image path or
                the original image size.

        Returns:
            The evaluation result for this single sample.
        """
        raise NotImplementedError()

    @abstractmethod
    def update(self, single_result: Dict) -> None:
        """Fold a single-sample evaluation result into the cached result.

        Args:
            single_result: The evaluation result for one sample.

        Returns:
            None
        """
        raise NotImplementedError()

    def get_result(self) -> SerializableResult:
        """Return the cached evaluation result.

        Returns:
            The final evaluation result for the whole dataset.
        """
        cached = self._result
        return cached
@runtime_checkable
class ReTrainingMetric(Protocol):
    """Protocol for re-training based metrics.

    A metric implementing this protocol re-trains a model on a perturbed
    dataset and evaluates the resulting performance degradation. The
    dataset-level outcome is cached and exposed through :meth:`get_result`.
    """

    # Cached evaluation result for the whole dataset; returned by ``get_result``.
    _result: SerializableResult

    def get_result(self) -> SerializableResult:
        """Return the cached evaluation result.

        Returns:
            The final evaluation result for the whole dataset.
        """
        cached = self._result
        return cached

    @abstractmethod
    def evaluate(self, cfg: Config, dist_args: Optional[Dict] = None) -> None:
        """Run the re-training evaluation over the whole dataset.

        Args:
            cfg: Config specifying the hyper-parameters of, e.g., the dataset,
                model, optimizer, lr-scheduler and max epochs.
            dist_args: DDP training hyper-parameters, e.g. ``nproc_per_node``
                or ``backend``. See also: `Parallel`_.

        Returns:
            None

        .. _Parallel: https://pytorch.org/ignite/generated/ignite.distributed.launcher.Parallel.html
        """
        raise NotImplementedError()
# Registry of evaluation metrics; ``build_metric`` builds instances from it.
METRICS = Registry(name="Metrics")
def build_metric(
    cfg: Dict, default_args: Optional[Dict] = None
) -> Union[ReInferenceMetric, ReTrainingMetric]:
    """Build an evaluation metric from a config dictionary.

    Construction is delegated entirely to the ``METRICS`` registry.

    Args:
        cfg: Config dictionary describing the metric (presumably with a
            ``type`` key naming a registered metric — confirm against mmcv's
            ``build_from_cfg``).
        default_args: Other default arguments forwarded to the registry's
            builder together with ``cfg``.

    Returns:
        An evaluation metric.
    """
    metric = METRICS.build(cfg=cfg, default_args=default_args)
    return metric