import sys

import cv2
import numpy as np
import torch
from PIL import Image

import utils
from models import restormer_arch
from utils import convert_state_dict
from data.preprocess.crop_merge_image import stride_integral

# Make the MBD inference helper importable before loading it.
sys.path.append("./data/MBD/")
from data.MBD.infer import net1_net2_infer_single_im


def dewarp_prompt(img):
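    """Build the dewarping prompt with the MBD segmentation model.

    Returns the background-masked image together with a (256, 256, 3) prompt
    stacking a normalized base coordinate grid and the resized document mask.
    """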
    mask = net1_net2_infer_single_im(img, "data/MBD/checkpoint/mbd.pkl")
    base_coord = utils.getBasecoord(256, 256) / 256
    img[mask == 0] = 0
    mask = cv2.resize(mask, (256, 256)) / 255
    return img, np.concatenate((base_coord, np.expand_dims(mask, -1)), -1)


def deshadow_prompt(img):
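    """Build the deshadowing prompt: a per-channel background estimate.

    Each channel is dilated and median-blurred to suppress text strokes; the
    merged background estimate is returned at the original resolution.
    """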
    h, w = img.shape[:2]
    img = cv2.resize(img, (1024, 1024))
    rgb_planes = cv2.split(img)
    result_planes = []
    result_norm_planes = []
    bg_imgs = []
    for plane in rgb_planes:
        # Dilation followed by a median blur wipes out text strokes and
        # leaves an estimate of the page background for this channel.
        dilated_img = cv2.dilate(plane, np.ones((7, 7), np.uint8))
        bg_img = cv2.medianBlur(dilated_img, 21)
        bg_imgs.append(bg_img)
        diff_img = 255 - cv2.absdiff(plane, bg_img)
        norm_img = cv2.normalize(
            diff_img,
            None,
            alpha=0,
            beta=255,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_8UC1,
        )
        result_planes.append(diff_img)
        result_norm_planes.append(norm_img)
    bg_imgs = cv2.merge(bg_imgs)
    bg_imgs = cv2.resize(bg_imgs, (w, h))

    # The shadow map below is computed but never used: the function returns
    # the blurred background estimate as the prompt instead.
    result_norm = cv2.merge(result_norm_planes)
    result_norm[result_norm == 0] = 1
    shadow_map = np.clip(
        img.astype(float) / result_norm.astype(float) * 255, 0, 255
    ).astype(np.uint8)
    shadow_map = cv2.resize(shadow_map, (w, h))
    shadow_map = cv2.cvtColor(shadow_map, cv2.COLOR_BGR2GRAY)
    shadow_map = cv2.cvtColor(shadow_map, cv2.COLOR_GRAY2BGR)

    return bg_imgs


def deblur_prompt(img):
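    """Build the deblurring prompt: a 3-channel Sobel high-frequency map."""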
    x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    absX = cv2.convertScaleAbs(x)
    absY = cv2.convertScaleAbs(y)
    high_frequency = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
    high_frequency = cv2.cvtColor(high_frequency, cv2.COLOR_BGR2GRAY)
    high_frequency = cv2.cvtColor(high_frequency, cv2.COLOR_GRAY2BGR)
    return high_frequency


def appearance_prompt(img):
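    """Build the appearance prompt: a background-normalized image.

    Uses the same dilation/median-blur background removal as deshadow_prompt,
    but returns the min-max normalized difference image instead.
    """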
    h, w = img.shape[:2]
    img = cv2.resize(img, (1024, 1024))
    rgb_planes = cv2.split(img)
    result_planes = []  # collected but unused; only the normalized planes matter
    result_norm_planes = []
    for plane in rgb_planes:
        dilated_img = cv2.dilate(plane, np.ones((7, 7), np.uint8))
        bg_img = cv2.medianBlur(dilated_img, 21)
        diff_img = 255 - cv2.absdiff(plane, bg_img)
        norm_img = cv2.normalize(
            diff_img,
            None,
            alpha=0,
            beta=255,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_8UC1,
        )
        result_planes.append(diff_img)
        result_norm_planes.append(norm_img)
    result_norm = cv2.merge(result_norm_planes)
    result_norm = cv2.resize(result_norm, (w, h))
    return result_norm


def binarization_promptv2(img):
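    """Build the 3-channel binarization prompt.

    Channels: the Sauvola threshold surface, a Sobel high-frequency map, and
    the hard-thresholded Sauvola binarization result.
    """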
    result, thresh = utils.SauvolaModBinarization(img)
    thresh = thresh.astype(np.uint8)
    # Harden the Sauvola output into a strictly binary map.
    result[result > 155] = 255
    result[result <= 155] = 0

    x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    absX = cv2.convertScaleAbs(x)
    absY = cv2.convertScaleAbs(y)
    high_frequency = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
    high_frequency = cv2.cvtColor(high_frequency, cv2.COLOR_BGR2GRAY)
    return np.concatenate(
        (
            np.expand_dims(thresh, -1),
            np.expand_dims(high_frequency, -1),
            np.expand_dims(result, -1),
        ),
        -1,
    )


def dewarping(model, im_org, device):
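    """Rectify geometric distortion.

    Predicts a normalized backward map at 256x256, smooths and upsamples it
    to the input resolution, and remaps the original image through it.

    Returns the three prompt channels and the dewarped image.
    """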
    INPUT_SIZE = 256
    im_masked, prompt_org = dewarp_prompt(im_org.copy())

    h, w = im_masked.shape[:2]
    im_masked = im_masked.copy()
    im_masked = cv2.resize(im_masked, (INPUT_SIZE, INPUT_SIZE))
    im_masked = im_masked / 255.0
    im_masked = torch.from_numpy(im_masked.transpose(2, 0, 1)).unsqueeze(0)
    im_masked = im_masked.float().to(device)

    prompt = torch.from_numpy(prompt_org.transpose(2, 0, 1)).unsqueeze(0)
    prompt = prompt.float().to(device)

    in_im = torch.cat((im_masked, prompt), dim=1)

    base_coord = utils.getBasecoord(INPUT_SIZE, INPUT_SIZE) / INPUT_SIZE
    model = model.float()
    with torch.no_grad():
        pred = model(in_im)
        pred = pred[0][:2].permute(1, 2, 0).cpu().numpy()
        pred = pred + base_coord

    # Smooth the predicted flow field, upsample it to the input resolution,
    # and convert the normalized coordinates to pixel positions for remapping.
    for _ in range(15):
        pred = cv2.blur(pred, (3, 3), borderType=cv2.BORDER_REPLICATE)
    pred = cv2.resize(pred, (w, h)) * (w, h)
    pred = pred.astype(np.float32)
    out_im = cv2.remap(im_org, pred[:, :, 0], pred[:, :, 1], cv2.INTER_LINEAR)

    prompt_org = (prompt_org * 255).astype(np.uint8)
    prompt_org = cv2.resize(prompt_org, im_org.shape[:2][::-1])

    return prompt_org[:, :, 0], prompt_org[:, :, 1], prompt_org[:, :, 2], out_im


def appearance(model, im_org, device):
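    """Restore the document's overall appearance (illumination and color).

    Small inputs are padded to a stride-8 multiple and enhanced directly;
    inputs larger than MAX_SIZE are enhanced at 1600x1600 and the result is
    transferred to full resolution through a derived shadow map.

    Returns the three prompt channels and the enhanced image.
    """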
    MAX_SIZE = 1600

    h, w = im_org.shape[:2]
    prompt = appearance_prompt(im_org)
    in_im = np.concatenate((im_org, prompt), -1)

    if max(w, h) < MAX_SIZE:
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        in_im = cv2.resize(in_im, (MAX_SIZE, MAX_SIZE))

    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)

    in_im = in_im.half().to(device)
    model = model.half()
    with torch.no_grad():
        pred = model(in_im)
        pred = torch.clamp(pred, 0, 1)
        pred = pred[0].permute(1, 2, 0).cpu().numpy()
        pred = (pred * 255).astype(np.uint8)

    if max(w, h) < MAX_SIZE:
        out_im = pred[padding_h:, padding_w:]
    else:
        # The model ran at reduced resolution, so derive a shadow map from
        # the low-resolution prediction and apply it at full resolution.
        pred[pred == 0] = 1
        shadow_map = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(
            float
        ) / pred.astype(float)
        shadow_map = cv2.resize(shadow_map, (w, h))
        shadow_map[shadow_map == 0] = 0.00001
        out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def deshadowing(model, im_org, device):
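    """Remove shadows, conditioning on the background-estimate prompt.

    Uses the same small/large input handling as appearance().

    Returns the three prompt channels and the deshadowed image.
    """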
    MAX_SIZE = 1600

    h, w = im_org.shape[:2]
    prompt = deshadow_prompt(im_org)
    in_im = np.concatenate((im_org, prompt), -1)

    if max(w, h) < MAX_SIZE:
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        in_im = cv2.resize(in_im, (MAX_SIZE, MAX_SIZE))

    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)

    in_im = in_im.half().to(device)
    model = model.half()
    with torch.no_grad():
        pred = model(in_im)
        pred = torch.clamp(pred, 0, 1)
        pred = pred[0].permute(1, 2, 0).cpu().numpy()
        pred = (pred * 255).astype(np.uint8)

    if max(w, h) < MAX_SIZE:
        out_im = pred[padding_h:, padding_w:]
    else:
        # The model ran at reduced resolution, so derive a shadow map from
        # the low-resolution prediction and apply it at full resolution.
        pred[pred == 0] = 1
        shadow_map = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(
            float
        ) / pred.astype(float)
        shadow_map = cv2.resize(shadow_map, (w, h))
        shadow_map[shadow_map == 0] = 0.00001
        out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def deblurring(model, im_org, device):
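    """Sharpen the image, conditioning on the high-frequency prompt.

    Returns the three prompt channels and the deblurred image.
    """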
    in_im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = deblur_prompt(in_im)
    in_im = np.concatenate((in_im, prompt), -1)
    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
    in_im = in_im.half().to(device)

    model.to(device)
    model.eval()
    model = model.half()
    with torch.no_grad():
        pred = model(in_im)
        pred = torch.clamp(pred, 0, 1)
        pred = pred[0].permute(1, 2, 0).cpu().numpy()
        pred = (pred * 255).astype(np.uint8)
        out_im = pred[padding_h:, padding_w:]

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def binarization(model, im_org, device):
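    """Binarize the image.

    The first two output channels are treated as per-pixel class logits and
    the argmax gives the binary map.

    Returns the three prompt channels and the binarized image.
    """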
    im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = binarization_promptv2(im)
    h, w = im.shape[:2]
    in_im = np.concatenate((im, prompt), -1)

    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
    in_im = in_im.to(device)
    model = model.half()
    in_im = in_im.half()
    with torch.no_grad():
        pred = model(in_im)
        # Keep only the first two output channels and take the per-pixel
        # argmax as the binary class prediction.
        pred = pred[:, :2, :, :]
        pred = torch.max(torch.softmax(pred, 1), 1)[1]
        pred = pred[0].cpu().numpy()
        pred = (pred * 255).astype(np.uint8)
        pred = cv2.resize(pred, (w, h))
        out_im = pred[padding_h:, padding_w:]

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def model_init(model_path, device):
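    """Build the 6-channel-input Restormer backbone, load the checkpoint
    weights, and return the model on `device` in eval mode."""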
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        dual_pixel_task=True,
    )

    if device == "cpu":
        state = convert_state_dict(
            torch.load(model_path, map_location="cpu")["model_state"]
        )
    else:
        state = convert_state_dict(
            torch.load(model_path, map_location="cuda:0")["model_state"]
        )
    model.load_state_dict(state)

    model.eval()
    model = model.to(device)
    return model


def resize(image, max_size):
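    """Downscale so the longer side is at most max_size, keeping the aspect
    ratio; smaller images are returned unchanged."""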
    h, w = image.shape[:2]
    if max(h, w) > max_size:
        if h > w:
            h_new = max_size
            w_new = int(w * h_new / h)
        else:
            w_new = max_size
            h_new = int(h * w_new / w)
        # Image.Resampling requires Pillow >= 9.1.
        pil_image = Image.fromarray(image)
        pil_image = pil_image.resize((w_new, h_new), Image.Resampling.LANCZOS)
        image = np.array(pil_image)
    return image


def inference_one_image(model, image, tasks, device):
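    """Run the requested restoration tasks sequentially on one image.

    Dewarping runs first at full resolution; any remaining tasks run on a
    copy whose longer side is capped at 1536 pixels.
    """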
    if "dewarping" in tasks:
        *_, image = dewarping(model, image, device)

    # If dewarping is the only task, skip the resize and return immediately.
    if len(tasks) == 1 and "dewarping" in tasks:
        return image

    image = resize(image, 1536)

    if "deshadowing" in tasks:
        *_, image = deshadowing(model, image, device)
    if "appearance" in tasks:
        *_, image = appearance(model, image, device)
    if "deblurring" in tasks:
        *_, image = deblurring(model, image, device)
    if "binarization" in tasks:
        *_, image = binarization(model, image, device)

    return image
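

# --- Minimal usage sketch (not part of the original module) -----------------
# Illustrates how the pieces above fit together. The checkpoint and image
# paths below are placeholder assumptions, not files shipped with this code.
if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = model_init("checkpoints/docres.pkl", device)  # hypothetical path
    image = cv2.imread("input.jpg")  # OpenCV loads BGR, as the prompts expect
    restored = inference_one_image(
        model, image, tasks=["dewarping", "appearance"], device=device
    )
    cv2.imwrite("restored.jpg", restored)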