Source code for saliency_metrics.utils.resize_img
from typing import Any, Tuple, TypeVar
import cv2
import numpy as np
import torch
from torch.nn.functional import interpolate
T = TypeVar("T", torch.Tensor, np.ndarray)
[docs]def resize_img(img: T, output_shape: Tuple[int, int], **kwargs: Any) -> T:
"""Resize and interpolate the image to the given shape.
This function simply calls ``torch.nn.functional.interpolate`` if ``img`` is a ``torch.Tensor``,
or ``cv2.resize`` if ``img`` is a ``numpy.ndarray``.
.. note::
If ``img`` is a ``numpy.ndarray``, then its data type must be ``numpy.uint8``.
Args:
img: Input image. Can be ``torch.Tensor`` with shape ``(num_samples, num_channels, height, width)`` or
``numpy.ndarray`` with shape ``(height, width, 3)`` or ``(height, width)`` .
output_shape: output shape in the format of ``(out_height, out_width)``.
**kwargs: other interpolation arguments. See also `interpolate`_ or `resize`_.
Returns:
The interpolated image.
.. _interpolate: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html#torch.nn.functional.interpolate # noqa
.. _resize: https://docs.opencv.org/4.x/da/d54/group__imgproc__transform.html#ga47a974309e9102f5f08231edc7e7529d
"""
if isinstance(img, torch.Tensor):
return interpolate(img, size=output_shape, **kwargs)
else:
return cv2.resize(img, dsize=output_shape, **kwargs)