Spaces:
Runtime error
Runtime error
#! /usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
# File : patch_match.py | |
# Author : Jiayuan Mao | |
# Email : [email protected] | |
# Date : 01/09/2020 | |
# | |
# Distributed under terms of the MIT license. | |
import ctypes | |
import os.path as osp | |
from typing import Optional, Union | |
import numpy as np | |
from PIL import Image | |
import os | |
if os.name != "nt":
    # On non-Windows platforms the native library is rebuilt with ``make``
    # every time this module is imported (the ``nt`` case is skipped here).
    import subprocess

    # Resolve the source directory once so that the printed path and the
    # working directory of ``make`` always agree. ``osp.dirname(__file__)``
    # alone can be an empty string when ``__file__`` is a bare file name,
    # which is not a valid ``cwd`` for a subprocess.
    _PM_SOURCE_DIR = osp.realpath(osp.dirname(__file__))
    print('Compiling and loading c extensions from "{}".'.format(_PM_SOURCE_DIR))

    # Same effect as ``make clean && make`` but without going through a shell:
    # ``check_call`` raises ``CalledProcessError`` on the first failing step,
    # so ``make`` only runs when ``make clean`` succeeded.
    subprocess.check_call(['make', 'clean'], cwd=_PM_SOURCE_DIR)
    subprocess.check_call(['make'], cwd=_PM_SOURCE_DIR)

# Public API of this module.
__all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
class CShapeT(ctypes.Structure):
    """ctypes structure describing an image shape.

    It consists of three C ``int`` members laid out in this order:
    ``width``, ``height``, ``channels``.
    """

    _fields_ = [
        (field_name, ctypes.c_int)
        for field_name in ('width', 'height', 'channels')
    ]
class CMatT(ctypes.Structure):
    """ctypes structure describing a matrix exchanged with the native library.

    Members, in layout order:
        data_ptr: untyped pointer (``void *``) to the element buffer.
        shape: a ``CShapeT`` holding width, height and channels.
        dtype: C ``int`` element-type code; on the Python side it is used as
            an index into ``dtype_pymat_to_ctypes``.
    """

    _fields_ = list(zip(
        ('data_ptr', 'shape', 'dtype'),
        (ctypes.c_void_p, CShapeT, ctypes.c_int),
    ))
import tempfile | |
from urllib.request import urlopen, Request | |
import shutil | |
from pathlib import Path | |
from tqdm import tqdm | |
def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    r"""Download the object at ``url`` and store it at the local path ``dst``.

    The payload is streamed into a temporary file created next to ``dst`` and
    is only moved onto ``dst`` after the download (and the optional checksum
    verification) has succeeded, so an existing file at ``dst`` is never
    replaced by a partial or corrupt download.

    Adapted from ``torch.hub``:
    https://pytorch.org/docs/stable/_modules/torch/hub.html#load_state_dict_from_url

    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``.
            A leading ``~`` is expanded.
        hash_prefix (string, optional): If not None, the hexadecimal SHA256 digest of the
            downloaded data must start with ``hash_prefix``.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Raises:
        RuntimeError: if ``hash_prefix`` is given and the digest of the
            downloaded data does not start with it. ``dst`` is left untouched.
        Exceptions raised by ``urlopen`` or by the file operations propagate
        unchanged.
    """
    # Local import: only this function needs hashlib.
    import hashlib

    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    sha256 = hashlib.sha256() if hash_prefix is not None else None

    # The response is a context manager; closing it releases the connection
    # even when the download is interrupted by an exception.
    with urlopen(Request(url)) as response:
        file_size = None
        meta = response.info()
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # We deliberately save it in a temp file and move it after
        # download is complete. This prevents a local working checkpoint
        # being overridden by a broken download.
        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
        try:
            with tqdm(total=file_size, disable=not progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = response.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))
            # Close before moving so all data is flushed to disk.
            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(
                        'invalid hash value (expected "{}", got "{}")'.format(hash_prefix, digest)
                    )
            shutil.move(f.name, dst)
        finally:
            # Closing twice is harmless; after a successful move the temp
            # file no longer exists, otherwise it is cleaned up here.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
# Directory containing this module; the native binaries are looked up here.
_PM_DIR = osp.dirname(__file__)

if os.name == "nt":
    # Windows: the two DLLs are downloaded into this directory when absent,
    # then the PatchMatch DLL is loaded.
    _PATCHMATCH_DLL = osp.join(_PM_DIR, 'libpatchmatch.dll')
    _OPENCV_DLL = osp.join(_PM_DIR, 'opencv_world460.dll')

    if not os.path.exists(_PATCHMATCH_DLL):
        download_url_to_file(
            url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll",
            dst=_PATCHMATCH_DLL,
        )
    if not os.path.exists(_OPENCV_DLL):
        download_url_to_file(
            url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll",
            dst=_OPENCV_DLL,
        )

    # Tell the user what to fetch by hand if a file is still missing.
    if not os.path.exists(_PATCHMATCH_DLL):
        print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll and put it into the PyPatchMatch folder")
    if not os.path.exists(_OPENCV_DLL):
        print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll and put it into the PyPatchMatch folder")

    PMLIB = ctypes.CDLL(_PATCHMATCH_DLL)
else:
    # Other platforms: load the shared object located in this directory.
    PMLIB = ctypes.CDLL(osp.join(_PM_DIR, 'libpatchmatch.so'))

# --- ctypes prototypes of the native entry points ---------------------------
# Utility functions (no restype declared, so ctypes' default applies).
PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
PMLIB.PM_free_pymat.argtypes = [CMatT]

# Inpainting with an image and a mask.
PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
PMLIB.PM_inpaint.restype = CMatT
PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
PMLIB.PM_inpaint_regularity.restype = CMatT

# Variants taking one additional matrix (the global mask).
PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
PMLIB.PM_inpaint2.restype = CMatT
PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
PMLIB.PM_inpaint2_regularity.restype = CMatT
def set_random_seed(seed: int):
    """Forward ``seed`` to the native ``PM_set_random_seed`` entry point.

    The value is converted to a C ``unsigned int`` before the call.
    """
    c_seed = ctypes.c_uint(seed)
    PMLIB.PM_set_random_seed(c_seed)
def set_verbose(verbose: bool):
    """Forward ``verbose`` to the native ``PM_set_verbose`` entry point.

    The flag is converted to a C ``int`` before the call.
    """
    c_flag = ctypes.c_int(verbose)
    PMLIB.PM_set_verbose(c_flag)
def inpaint(
    image: Union[np.ndarray, Image.Image],
    mask: Optional[Union[np.ndarray, Image.Image]] = None,
    *,
    global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
    patch_size: int = 15
) -> np.ndarray:
    """
    PatchMatch based inpainting proposed in:

        PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
        C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
        SIGGRAPH 2009

    Args:
        image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR
            with dtype uint8.
        mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel
            uint8 (a 2-D uint8 array is accepted and gets a channel axis appended).
            If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
        global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
            Same format requirements as ``mask``.
        patch_size (int): the patch size for the inpainting algorithm.

    Return:
        result (np.ndarray): the repaired image, of the same size as the input image.

    Raises:
        AssertionError: if ``image`` or one of the masks does not have the
            expected rank, channel count or dtype.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)
    # The native code receives a raw pointer, so the buffer must be contiguous.
    image = np.ascontiguousarray(image)
    assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'

    if mask is None:
        # No mask given: every pure-white pixel is marked as a hole.
        mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
        mask = np.ascontiguousarray(mask)
    else:
        mask = _canonize_mask_array(mask)

    # ``image``/``mask``/``global_mask`` stay referenced by these locals for
    # the duration of the native call, keeping their buffers alive.
    if global_mask is None:
        ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
    else:
        global_mask = _canonize_mask_array(global_mask)
        ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))

    try:
        # Copy the result into NumPy-owned memory ...
        return pymat_to_np(ret_pymat)
    finally:
        # ... and always release the native buffer, even if the copy failed.
        PMLIB.PM_free_pymat(ret_pymat)
def inpaint_regularity(
    image: Union[np.ndarray, Image.Image],
    mask: Optional[Union[np.ndarray, Image.Image]],
    ijmap: np.ndarray,
    *,
    global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
    patch_size: int = 15, guide_weight: float = 0.25
) -> np.ndarray:
    """
    Inpaint ``image`` through the native ``PM_inpaint_regularity`` /
    ``PM_inpaint2_regularity`` entry points, which take an extra guidance map.

    Args:
        image (Union[np.ndarray, Image.Image]): the input image, 3-channel uint8.
        mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, 1-channel uint8
            (a 2-D uint8 array is accepted and gets a channel axis appended).
            If None, all purely white pixels (255, 255, 255) are treated as the holes.
        ijmap (np.ndarray): 3-D float32 array whose last axis has length 3. It is
            handed to the native library unchanged.
            NOTE(review): the meaning of its three channels is defined by the
            native code and is not visible here -- confirm before relying on it.
        global_mask (Union[np.array, Image.Image], optional): additional mask forwarded to
            ``PM_inpaint2_regularity``; same format requirements as ``mask``.
        patch_size (int): the patch size, passed to the native code as a C ``int``.
        guide_weight (float): passed to the native code as a C ``float``.

    Return:
        result (np.ndarray): a NumPy copy of the matrix returned by the native call.

    Raises:
        AssertionError: if ``image``, ``ijmap`` or one of the masks does not
            have the expected type, rank, channel count or dtype.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)
    # The native code receives raw pointers, so buffers must be contiguous.
    image = np.ascontiguousarray(image)

    assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
    ijmap = np.ascontiguousarray(ijmap)

    assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'

    if mask is None:
        # No mask given: every pure-white pixel is marked as a hole.
        mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
        mask = np.ascontiguousarray(mask)
    else:
        mask = _canonize_mask_array(mask)

    # The arrays stay referenced by these locals for the duration of the
    # native call, keeping their buffers alive.
    if global_mask is None:
        ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
    else:
        global_mask = _canonize_mask_array(global_mask)
        ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))

    try:
        # Copy the result into NumPy-owned memory ...
        return pymat_to_np(ret_pymat)
    finally:
        # ... and always release the native buffer, even if the copy failed.
        PMLIB.PM_free_pymat(ret_pymat)
def _canonize_mask_array(mask):
    """Return ``mask`` as a C-contiguous uint8 array with a single channel.

    PIL images are converted to arrays first. A 2-D uint8 array gets a
    trailing channel axis appended. Anything that is not a 3-D, 1-channel
    uint8 array afterwards fails the assertion.
    """
    arr = np.array(mask) if isinstance(mask, Image.Image) else mask

    needs_channel_axis = arr.ndim == 2 and arr.dtype == 'uint8'
    if needs_channel_axis:
        arr = arr[:, :, np.newaxis]

    assert arr.ndim == 3 and arr.shape[2] == 1 and arr.dtype == 'uint8'
    return np.ascontiguousarray(arr)
# Single ordered table of supported element types. The position of an entry
# is the integer dtype code stored in ``CMatT.dtype``; each entry pairs the
# NumPy dtype name with the matching ctypes scalar type.
_PYMAT_DTYPE_TABLE = (
    ('uint8', ctypes.c_uint8),
    ('int8', ctypes.c_int8),
    ('uint16', ctypes.c_uint16),
    ('int16', ctypes.c_int16),
    ('int32', ctypes.c_int32),
    ('float32', ctypes.c_float),
    ('float64', ctypes.c_double),
)

# dtype code -> ctypes scalar type (indexed by code).
dtype_pymat_to_ctypes = [c_type for _, c_type in _PYMAT_DTYPE_TABLE]

# NumPy dtype name -> dtype code.
dtype_np_to_pymat = {
    np_name: code for code, (np_name, _) in enumerate(_PYMAT_DTYPE_TABLE)
}
def np_to_pymat(npmat):
    """Describe a 3-D NumPy array as a ``CMatT`` without copying its data.

    The descriptor only stores the raw address of ``npmat``'s buffer, so the
    caller has to keep ``npmat`` alive for as long as the descriptor is used.
    The array's axes ``(shape[0], shape[1], shape[2])`` are recorded as
    ``(height, width, channels)``.

    Raises:
        AssertionError: if ``npmat`` is not 3-dimensional.
        KeyError: if the dtype of ``npmat`` has no entry in ``dtype_np_to_pymat``.
    """
    assert npmat.ndim == 3
    height, width, channels = npmat.shape
    buffer_ptr = ctypes.cast(npmat.ctypes.data, ctypes.c_void_p)
    dtype_code = dtype_np_to_pymat[str(npmat.dtype)]
    return CMatT(buffer_ptr, CShapeT(width, height, channels), dtype_code)
def pymat_to_np(pymat):
    """Copy the data described by a ``CMatT`` into a new NumPy array.

    The buffer behind ``pymat.data_ptr`` is viewed as an array of shape
    ``(height, width, channels)`` with the element type selected by
    ``pymat.dtype``, and that view is then copied, so the returned array does
    not reference the original buffer.
    """
    element_type = dtype_pymat_to_ctypes[pymat.dtype]
    typed_ptr = ctypes.cast(pymat.data_ptr, ctypes.POINTER(element_type))
    dims = (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
    # ``borrowed`` aliases the memory behind ``data_ptr``; hand back a copy.
    borrowed = np.ctypeslib.as_array(typed_ptr, dims)
    return borrowed.copy()