# Source: resefa/utils/image_utils.py
# Uploaded by akhaliq (HF staff); commit 8ca3a29: "add files".
# python3.7
"""Contains utility functions for image processing.
The module is primarily built on `cv2`. Unlike `cv2`, however, all color
images are assumed to use `RGB` channel order by default. Also, all gray-scale
images are assumed to have shape [height, width, 1].
"""
import os
import cv2
import numpy as np
from .misc import IMAGE_EXTENSIONS
from .misc import check_file_ext
# Public API of this module: only these names are exported by a star-import.
# The private helper `_check_2d_image` is intentionally not listed.
__all__ = [
'get_blank_image', 'load_image', 'save_image', 'resize_image',
'add_text_to_image', 'preprocess_image', 'postprocess_image',
'parse_image_size', 'get_grid_shape', 'list_images_from_dir'
]
def _check_2d_image(image):
    """Validates that the input is a single image in the expected layout.

    The image must be a `np.ndarray` with dtype `uint8` and one of the
    following shapes:

    (1) (height, width, 1)  # gray-scale image.
    (2) (height, width, 3)  # color image.
    (3) (height, width, 4)  # color image with transparency (RGBA).

    Args:
        image: The object to validate.

    Raises:
        AssertionError: If any of the requirements above is violated.
    """
    assert isinstance(image, np.ndarray)
    assert image.dtype == np.uint8
    # The rank is verified first, so the last axis is the channel axis.
    assert image.ndim == 3
    assert image.shape[-1] in (1, 3, 4)
def get_blank_image(height, width, channels=3, use_black=True):
    """Creates a blank image filled with either black or white.

    NOTE: The returned image has dtype `uint8` and pixel range [0, 255]. Since
    all channels share the same value, the channel order is irrelevant.

    Args:
        height: Height of the returned image.
        width: Width of the returned image.
        channels: Number of channels. (default: 3)
        use_black: Whether to fill the image with black (0) instead of white
            (255). (default: True)

    Returns:
        A `np.ndarray` with shape (height, width, channels).
    """
    fill_value = 0 if use_black else 255
    return np.full((height, width, channels), fill_value, dtype=np.uint8)
def load_image(path):
    """Reads an image file from disk.

    NOTE: Color images are returned with `RGB` (or `RGBA`) channel order, and
    gray-scale images are returned with shape (height, width, 1). The pixel
    range is [0, 255].

    Args:
        path: Path to load the image from.

    Returns:
        The loaded image as a `np.ndarray`, or `None` if `cv2.imread()` could
            not read anything from `path` (e.g., the file does not exist).

    Raises:
        AssertionError: If the decoded image is not `uint8` or does not have
            1, 3, or 4 channels.
    """
    decoded = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if decoded is None:
        return None

    # Gray-scale images are decoded without a channel axis; add it back.
    image = decoded[:, :, np.newaxis] if decoded.ndim == 2 else decoded
    _check_2d_image(image)

    # `cv2` decodes color images as BGR(A), while this module uses RGB(A).
    num_channels = image.shape[2]
    if num_channels == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    if num_channels == 3:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image
def save_image(path, image):
    """Writes an image to disk.

    NOTE: A color image is assumed to be with `RGB` (or `RGBA`) channel order
    and pixel range [0, 255]. Nothing is written if `image` is `None`.

    Args:
        path: Path to save the image to.
        image: Image to save.

    Raises:
        AssertionError: If `image` is not a valid `uint8` image with 1, 3, or 4
            channels.
    """
    if image is None:
        return

    _check_2d_image(image)
    num_channels = image.shape[2]
    if num_channels == 1:
        # Gray-scale images need no channel reordering.
        cv2.imwrite(path, image)
        return

    # `cv2` expects BGR(A) order on write, so swap the channels back.
    conversions = {3: cv2.COLOR_RGB2BGR, 4: cv2.COLOR_RGBA2BGRA}
    if num_channels in conversions:
        cv2.imwrite(path, cv2.cvtColor(image, conversions[num_channels]))
def resize_image(image, *args, **kwargs):
    """Resizes an image with `cv2.resize()`.

    NOTE: The channel order of the input image is left untouched, and a
    gray-scale image keeps its trailing channel axis of size 1.

    Args:
        image: Image to resize.
        *args: Additional positional arguments forwarded to `cv2.resize()`.
        **kwargs: Additional keyword arguments forwarded to `cv2.resize()`.

    Returns:
        The resized image as a `np.ndarray`, or `None` if `image` is `None`.

    Raises:
        AssertionError: If `image` is not a valid `uint8` image with 1, 3, or 4
            channels.
    """
    if image is None:
        return None

    _check_2d_image(image)
    is_gray = image.shape[2] == 1
    resized = cv2.resize(image, *args, **kwargs)
    # The channel axis of a gray-scale image gets squeezed by the resize, so
    # it is re-expanded to keep the (height, width, 1) convention.
    return resized[:, :, np.newaxis] if is_gray else resized
def add_text_to_image(image,
                      text='',
                      position=None,
                      font=cv2.FONT_HERSHEY_TRIPLEX,
                      font_size=1.0,
                      line_type=cv2.LINE_8,
                      line_width=1,
                      color=(255, 255, 255)):
    """Overlays text on given image.

    NOTE: The input image is assumed to be with `RGB` channel order. The text
    is drawn directly onto `image`, and that same array is returned.

    Args:
        image: The image to overlay text on.
        text: Text content to overlay on the image. (default: empty)
        position: Target position, as (x, y), of the bottom-left corner of the
            text. If not set, the center of the image will be used, i.e., the
            text starts at the image center. (default: None)
        font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
        font_size: Font size of the text added. (default: 1.0)
        line_type: Line type used to depict the text. (default: cv2.LINE_8)
        line_width: Line width used to depict the text. (default: 1)
        color: Color of the text added in `RGB` channel order. (default:
            (255, 255, 255))

    Returns:
        An image with target text overlaid on. The input is returned unchanged
            if `image` is `None` or `text` is empty.

    Raises:
        AssertionError: If `image` is not a valid `uint8` image with 1, 3, or 4
            channels.
    """
    if image is None or not text:
        return image

    _check_2d_image(image)
    if position is None:
        # Honor the documented default. Passing `None` as `org` makes
        # `cv2.putText()` fail, so resolve it to the image center here. The
        # origin is given as (x, y), i.e., (column, row).
        position = (image.shape[1] // 2, image.shape[0] // 2)
    cv2.putText(img=image,
                text=text,
                org=position,
                fontFace=font,
                fontScale=font_size,
                color=color,
                thickness=line_width,
                lineType=line_type,
                bottomLeftOrigin=False)
    return image
def preprocess_image(image, min_val=-1.0, max_val=1.0):
    """Maps image pixels from [0, 255] to a target range, in `NCHW` layout.

    This function is particularly used to convert an image, or a batch of
    images, to the `NCHW` format commonly consumed by deep models.

    NOTE: The input image is assumed to be with pixel range [0, 255] and with
    format `HWC` or `NHWC`. The returned image is always with format `NCHW`
    (a single `HWC` image becomes a batch of size 1) and with dtype `float64`.
    Cast the result explicitly if single precision is required.

    Args:
        image: The input image for pre-processing.
        min_val: Minimum value of the output image. (default: -1.0)
        max_val: Maximum value of the output image. (default: 1.0)

    Returns:
        The pre-processed image.

    Raises:
        AssertionError: If `image` is not a `np.ndarray`, or is not `HWC` or
            `NHWC` with 1, 3, or 4 channels.
    """
    assert isinstance(image, np.ndarray)

    value_range = max_val - min_val
    scaled = image.astype(np.float64) / 255.0 * value_range + min_val

    # Promote a single image to a batch of one.
    batch = scaled[np.newaxis] if scaled.ndim == 3 else scaled
    assert batch.ndim == 4 and batch.shape[3] in [1, 3, 4]
    return batch.transpose(0, 3, 1, 2)
def postprocess_image(image, min_val=-1.0, max_val=1.0):
    """Maps model outputs back to `uint8` pixels in range [0, 255].

    This function is particularly used to handle the results produced by deep
    models.

    NOTE: The input image is assumed to be with format `NCHW`, and the returned
    image is always with format `NHWC`. Values outside [`min_val`, `max_val`]
    are clipped.

    Args:
        image: The input image for post-processing.
        min_val: Expected minimum value of the input image. (default: -1.0)
        max_val: Expected maximum value of the input image. (default: 1.0)

    Returns:
        The post-processed image.

    Raises:
        AssertionError: If `image` is not a `np.ndarray`, or is not `NCHW`
            with 1, 3, or 4 channels.
    """
    assert isinstance(image, np.ndarray)

    pixels = image.astype(np.float64)
    pixels = (pixels - min_val) / (max_val - min_val) * 255
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    pixels = np.clip(pixels + 0.5, 0, 255).astype(np.uint8)

    assert pixels.ndim == 4 and pixels.shape[1] in [1, 3, 4]
    return pixels.transpose(0, 2, 3, 1)
def parse_image_size(obj):
    """Parses an object to a pair of image size, i.e., (height, width).

    Supported inputs:

    (1) `None` or an empty (or spaces-only) string: parsed as (0, 0).
    (2) An integer (Python `int` or NumPy integer): used as both height and
        width.
    (3) A comma-separated string, e.g., `'256'` or `'256, 512'`. Spaces are
        ignored.
    (4) A list, tuple, or `np.ndarray` with zero, one, or two elements.

    Negative values are clamped to 0.

    Args:
        obj: The input object to parse image size from.

    Returns:
        A two-element tuple of Python `int`, indicating image height and width
            respectively.

    Raises:
        ValueError: If the input has an unsupported type, contains more than
            two elements, or contains a string element that cannot be
            converted to an integer.
    """
    if obj is None:
        return (0, 0)

    # The string case is handled by type rather than with `obj == ''`, since
    # comparing a `np.ndarray` against a string yields an element-wise result
    # whose truth value is ambiguous on recent NumPy versions.
    if isinstance(obj, str):
        compact = obj.replace(' ', '')
        if not compact:
            return (0, 0)
        numbers = tuple(int(item) for item in compact.split(','))
    elif isinstance(obj, (int, np.integer)):
        numbers = (int(obj),)
    elif isinstance(obj, (list, tuple, np.ndarray)):
        numbers = tuple(obj)
    else:
        raise ValueError(f'Invalid type of input: `{type(obj)}`!')

    if len(numbers) > 2:
        raise ValueError('At most two elements for image size.')
    if not numbers:
        return (0, 0)

    # A single number describes a square; otherwise the order is (H, W).
    height = int(numbers[0])
    width = int(numbers[-1])
    return (max(0, height), max(0, width))
def get_grid_shape(size, height=0, width=0, is_portrait=False):
    """Computes the (height, width) of a grid holding `size` cells.

    If neither `height` nor `width` is usable, the grid is made as square as
    possible. For example, `size = 16` gives `(4, 4)` and `size = 15` gives
    `(3, 5)`. In that case, `is_portrait` decides the orientation: with `False`
    the height is never larger than the width, and with `True` the height is
    never smaller than the width.

    Args:
        size: Size (height * width) of the target grid.
        height: Expected height. If `size % height != 0`, this field will be
            ignored. (default: 0)
        width: Expected width. If `size % width != 0`, this field will be
            ignored. (default: 0)
        is_portrait: Whether to return a portrait shape instead of a landscape
            shape. Only used when the shape is derived automatically.
            (default: False)

    Returns:
        A two-element tuple, representing height and width respectively.
            `(0, 0)` is returned if `size` is not positive.

    Raises:
        AssertionError: If `size`, `height`, or `width` is not an integer.
    """
    assert isinstance(size, int)
    assert isinstance(height, int)
    assert isinstance(width, int)

    if size <= 0:
        return (0, 0)

    use_height = height > 0
    use_width = width > 0
    if use_height and use_width:
        if height * width == size:
            return (height, width)
        # The two hints contradict `size`, so both of them are dropped.
        use_height = use_width = False
    if use_height and size % height == 0:
        return (height, size // height)
    if use_width and size % width == 0:
        return (size // width, width)

    # Walk down from sqrt(size) to find the largest divisor not exceeding it.
    # The search always succeeds since 1 divides every positive integer.
    for short_side in range(int(np.sqrt(size)), 0, -1):
        if size % short_side == 0:
            break
    long_side = size // short_side
    if is_portrait:
        return (long_side, short_side)
    return (short_side, long_side)
def list_images_from_dir(directory):
    """Collects the image files located directly under a directory.

    A file counts as an image if `check_file_ext()` accepts its name for any of
    the extensions in `IMAGE_EXTENSIONS`.

    NOTE: Do NOT support finding images recursively.

    Args:
        directory: The directory to find images from.

    Returns:
        A list of sorted filenames, with the directory as prefix.
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if check_file_ext(name, *IMAGE_EXTENSIONS)
    )