"""Utility functions for image_to_fen module."""
import base64
import contextlib
import hashlib
from io import BytesIO
import os
from pathlib import Path
from typing import Union
from urllib.request import urlretrieve

import numpy as np
from PIL import Image
import smart_open
from tqdm import tqdm


def to_categorical(y, num_classes):
    """1-hot encode a tensor.

    Parameters
    ----------
    y: int or array-like of int
        Class indices, expected to lie in ``[0, num_classes)``.
    num_classes: int
        Number of classes, i.e. the length of each one-hot vector.

    Returns
    -------
    np.ndarray
        ``uint8`` array of shape ``np.shape(y) + (num_classes,)``.
    """
    # Row ``i`` of the identity matrix is the one-hot vector for class ``i``;
    # fancy-indexing with ``y`` gathers one row per index.
    return np.eye(num_classes, dtype="uint8")[y]


def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image.Image:
    """Open ``image_uri`` with ``smart_open`` and load it as a PIL image.

    Parameters
    ----------
    image_uri: Path or str
        Local path, or a URI understood by ``smart_open``.
    grayscale: bool, optional
        If True, convert the image to single-channel ``"L"`` mode [default: False].

    Returns
    -------
    PIL.Image.Image
        The converted image (see ``read_image_pil_file``).
    """
    with smart_open.open(image_uri, "rb") as image_file:
        return read_image_pil_file(image_file, grayscale)


def read_image_pil_file(image_file, grayscale=False) -> Image.Image:
    """Load a PIL image from an open binary file object.

    Parameters
    ----------
    image_file: file-like
        Binary file object (or path) accepted by ``PIL.Image.open``.
    grayscale: bool, optional
        If True, convert the image to ``"L"`` mode; otherwise keep its
        original mode [default: False].

    Returns
    -------
    PIL.Image.Image
        A new image returned by ``Image.convert``.
    """
    with Image.open(image_file) as image:
        mode = "L" if grayscale else image.mode
        # ``convert`` produces a new, loaded image, so the result stays usable
        # after the source image/file is closed when the ``with`` block exits.
        return image.convert(mode=mode)


@contextlib.contextmanager
def temporary_working_directory(working_dir: Union[str, Path]):
    """Temporarily switches to a directory, then returns to the original directory on exit."""
    curdir = os.getcwd()
    # chdir happens before ``try``: if it fails there is nothing to restore.
    os.chdir(working_dir)
    try:
        yield
    finally:
        os.chdir(curdir)


def encode_b64_image(image, format="png"):
    """Encode a PIL image as a base64 string.

    Parameters
    ----------
    image: PIL.Image.Image
        Image to encode.
    format: str, optional
        Image format passed to ``Image.save`` [default: "png"].

    Returns
    -------
    str
        Base64 text of the encoded image bytes.
    """
    _buffer = BytesIO()  # bytes that live in memory
    image.save(_buffer, format=format)  # but which we write to like a file
    encoded_image = base64.b64encode(_buffer.getvalue()).decode("utf8")
    return encoded_image


def compute_sha256(filename: Union[Path, str]):
    """Return SHA256 checksum of a file.

    The file is hashed in fixed-size chunks so large files are never read
    into memory all at once; the digest is identical to hashing the whole
    contents in one call.
    """
    chunk_size = 1 << 16  # 64 KiB per read
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


class TqdmUpTo(tqdm):
    """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py"""

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """
        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize
        self.update(blocks * bsize - self.n)  # will also set self.n = b * bsize


def download_url(url, filename):
    """Download a file from url to filename, with a progress bar."""
    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        # ``update_to`` matches urlretrieve's reporthook(blocks, bsize, tsize) signature.
        urlretrieve(url, filename, reporthook=t.update_to, data=None)  # noqa: S310


def torch_to_pil(img):
    """Convert a torch tensor image back to an RGB PIL image.

    Parameters
    ----------
    img:
        Image accepted by ``torchvision.transforms.ToPILImage`` (presumably a
        C x H x W tensor coming out of the model pipeline -- confirm with callers).

    Returns
    -------
    PIL.Image.Image
        The image converted to ``"RGB"`` mode.
    """
    # torchvision was referenced here but never imported, so every call raised
    # NameError. Import it lazily so the rest of this utility module does not
    # require torchvision to be installed.
    import torchvision

    return torchvision.transforms.ToPILImage()(img).convert("RGB")