Source code for saliency_metrics.metrics.attribution_method

from abc import abstractmethod
from typing import Any, Dict, Optional, Protocol, Sequence, Tuple, Union, runtime_checkable

import torch.nn as nn
from numpy import ndarray
from torch import Tensor

from ..models import get_module
from ..utils import resize_img

# Public names exported by this module: the attribution protocol and its Captum Grad-CAM wrapper.
__all__ = [
    "AttributionMethod",
    "CaptumGradCAM",
]


@runtime_checkable
class AttributionMethod(Protocol):
    """Protocol of attribution (also known as explanation) methods.

    This protocol is mainly used in the Sanity Check metric, where an attribution method must implement this
    protocol. The protocol is decorated with :func:`typing.runtime_checkable`, so ``isinstance(obj,
    AttributionMethod)`` can be used; note that such a check only verifies that ``obj`` has an ``attribute``
    member, not that its signature matches.
    """

    @abstractmethod
    def attribute(
        self,
        img: Tensor,
        target: Union[int, Sequence[int], Tensor],
        as_ndarray: bool = True,
        interpolate: bool = True,
        interpolate_args: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Union[Tensor, ndarray, Tuple[Union[Tensor, ndarray], ...]]:
        """Attribute and produce the saliency map.

        .. note::
            The method performs attribution on a single image, i.e., the local explanation.

        Args:
            img: Input image with shape ``(1, num_channels, height, width)``.
            target: Ground-truth target.
            as_ndarray: If True, then convert the saliency map into a ndarray.
            interpolate: If True, then resize and interpolate the saliency map to the image's spatial size.
            interpolate_args: Other arguments for interpolation. See also :func:`saliency_metrics.utils.resize_img`.
            **kwargs: Other keyword arguments for attribution.

        Returns:
            A saliency map, or a tuple (of any length) of saliency maps.
        """
        raise NotImplementedError
class CaptumGradCAM(AttributionMethod):
    """A wrapper class for ``captum.attr.LayerGradCam``.

    This class is only for internal testing.

    Args:
        classifier: The model to explain; it is passed to Captum as ``forward_func``.
        layer: Name of the sub-module of ``classifier`` to attribute, resolved with
            :func:`saliency_metrics.models.get_module`.
        **kwargs: Other keyword arguments forwarded to ``captum.attr.LayerGradCam``.
    """

    def __init__(self, classifier: nn.Module, layer: str, **kwargs: Any) -> None:
        # Imported here rather than at module level, so captum is only needed when this class is instantiated.
        from captum.attr import LayerGradCam

        _layer: nn.Module = get_module(classifier, layer)
        self._grad_cam = LayerGradCam(forward_func=classifier, layer=_layer, **kwargs)

    def attribute(
        self,
        img: Tensor,
        target: Union[int, Sequence[int], Tensor],
        as_ndarray: bool = True,
        interpolate: bool = True,
        interpolate_args: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Union[Tensor, ndarray, Tuple[Union[Tensor, ndarray], ...]]:
        """Compute Grad-CAM saliency map(s) for a single image.

        Args:
            img: Input image with shape ``(1, num_channels, height, width)``.
            target: Ground-truth target, forwarded to ``LayerGradCam.attribute``.
            as_ndarray: If True, then convert the saliency map(s) into ndarray(s).
            interpolate: If True, then resize the saliency map(s) to the image's spatial size.
            interpolate_args: Other arguments for interpolation. See also :func:`saliency_metrics.utils.resize_img`.
            **kwargs: Other keyword arguments forwarded to ``LayerGradCam.attribute``.

        Returns:
            A single saliency map when Captum yields one map, otherwise a tuple of saliency maps.
        """
        height, width = img.shape[-2:]
        smap: Union[Tensor, Tuple[Tensor, ...]] = self._grad_cam.attribute(img, target=target, **kwargs)
        # Captum may return a single tensor or a tuple of tensors; normalize to a list and cut the autograd graph.
        smap_list = [s.detach() for s in smap] if isinstance(smap, (list, tuple)) else [smap.detach()]

        if interpolate:
            interpolate_args = {} if interpolate_args is None else interpolate_args
            smap_list = [resize_img(s, output_shape=(height, width), **interpolate_args) for s in smap_list]

        if as_ndarray:
            # ``Tensor.numpy()`` only supports CPU tensors, so move maps to the CPU first (no-op if already there).
            smap_list = [s.detach().cpu().numpy() for s in smap_list]

        return smap_list[0] if len(smap_list) == 1 else tuple(smap_list)