import functools
from typing import Optional

import gradio as gr
import torch
import torch.nn.functional as F
import cv2
import kornia
import numpy as np


def min_(items):
    """Return the element-wise minimum of a non-empty sequence of tensors."""
    return functools.reduce(torch.minimum, items)


def max_(items):
    """Return the element-wise maximum of a non-empty sequence of tensors."""
    return functools.reduce(torch.maximum, items)


def _to_unit_float(image: np.ndarray) -> torch.Tensor:
    """Convert an 8-bit HxWxC image array to a float32 tensor scaled to [0, 1].

    NOTE(review): assumes uint8 RGB input as delivered by ``gr.Image`` in
    numpy mode — confirm if other callers are added.
    """
    return torch.from_numpy(image).float() / 255.0


def _to_uint8(image: torch.Tensor) -> np.ndarray:
    """Convert a float HxWxC tensor in [0, 1] back to a uint8 numpy array.

    Values are clamped to [0, 1] and then rounded to the nearest integer
    rather than truncated. Float round-off can land a pixel at e.g.
    127.99998, which a bare ``astype(np.uint8)`` would floor to 127 and so
    bias the output one step darker.
    """
    scaled = image.clamp(0, 1).numpy() * 255
    return np.rint(scaled).astype(np.uint8)


def apply_cas(image: Optional[np.ndarray], amount: float) -> Optional[np.ndarray]:
    """Sharpen an image with Contrast Adaptive Sharpening (CAS).

    Args:
        image: HxWxC image array (assumed uint8 RGB from ``gr.Image``), or
            None when the input widget is empty.
        amount: Sharpening strength. It interpolates the negative lobe weight
            from ``-amp/8`` (amount=0) to ``-amp/5`` (amount=1).

    Returns:
        The sharpened image as an HxWxC uint8 array, or None if ``image`` is None.
    """
    if image is None:
        return None

    # HWC -> BCHW with a batch of one, for the shifted-view arithmetic below.
    img = _to_unit_float(image).unsqueeze(0).permute(0, 3, 1, 2)
    epsilon = 1e-5

    # Replicate edge pixels instead of zero-padding. With zeros, every pixel on
    # the 1-pixel frame has a neighbourhood minimum of 0, which forces the
    # adaptive weight below to 0 and leaves the image border unsharpened.
    padded = F.pad(img, pad=(1, 1, 1, 1), mode="replicate")

    # Nine shifted views give each pixel its 3x3 neighbourhood:
    #   a b c
    #   d e f
    #   g h i
    a = padded[..., :-2, :-2]
    b = padded[..., :-2, 1:-1]
    c = padded[..., :-2, 2:]
    d = padded[..., 1:-1, :-2]
    e = padded[..., 1:-1, 1:-1]
    f = padded[..., 1:-1, 2:]
    g = padded[..., 2:, :-2]
    h = padded[..., 2:, 1:-1]
    i = padded[..., 2:, 2:]

    # Local contrast: min/max over the cross (b, d, e, f, h) plus min/max over
    # the diagonal corners (a, c, g, i). Each sum lies in [0, 2].
    cross = (b, d, e, f, h)
    diag = (a, c, g, i)
    mn = min_(cross) + min_(diag)
    mx = max_(cross) + max_(diag)

    # Adaptive amplitude: it shrinks toward 0 as the neighbourhood minimum
    # approaches 0 or the maximum approaches 2 (near black/white). This limits
    # sharpening where it would push values out of range.
    inv_mx = torch.reciprocal(mx + epsilon)
    amp = torch.sqrt(inv_mx * torch.minimum(mn, 2 - mx))

    # Negative weight for the 4-connected neighbours; `div` renormalises so
    # that the centre weight (1) plus 4*w sums to 1.
    w = -amp * (amount * (1 / 5 - 1 / 8) + 1 / 8)
    div = torch.reciprocal(1 + 4 * w)
    output = ((b + d + f + h) * w + e) * div

    # BCHW -> HWC, then back to uint8 (clamped and rounded).
    return _to_uint8(output.squeeze(0).permute(1, 2, 0))


def apply_smart_sharpen(
    image: Optional[np.ndarray],
    noise_radius: float,
    preserve_edges: float,
    sharpen: float,
    ratio: float,
) -> Optional[np.ndarray]:
    """Blend an edge-preserving denoised copy with a sharpened copy of the image.

    Args:
        image: HxWxC image array (assumed uint8 RGB from ``gr.Image``), or
            None when the input widget is empty.
        noise_radius: Pixel-neighbourhood diameter for ``cv2.bilateralFilter``.
            Values <= 1 skip denoising. The value is coerced to ``int``
            because OpenCV requires an integer diameter.
        preserve_edges: In [0, 1]. Values > 0 set the bilateral colour sigma
            to ``max(1 - preserve_edges, 0.05)``, so higher values smooth
            fewer edges. 0 is passed through unchanged.
        sharpen: Factor for ``kornia.enhance.sharpness``. 0 disables sharpening.
        ratio: Weight of the sharpened copy; ``1 - ratio`` weights the
            denoised copy.

    Returns:
        The blended image as an HxWxC uint8 array, or None if ``image`` is None.
    """
    if image is None:
        return None

    img = _to_unit_float(image)

    sigma_color = max(1 - preserve_edges, 0.05) if preserve_edges > 0 else preserve_edges

    # cv2.bilateralFilter rejects a non-integer diameter; coerce in case the
    # slider delivers a float such as 7.0.
    diameter = int(noise_radius)
    if diameter > 1:
        # Spatial sigma derived from the diameter with the same formula OpenCV's
        # getGaussianKernel uses to pick sigma from a kernel size.
        sigma_space = 0.3 * ((diameter - 1) * 0.5 - 1) + 0.8
        blurred = torch.from_numpy(
            cv2.bilateralFilter(img.numpy(), diameter, sigma_color, sigma_space)
        )
    else:
        blurred = img

    if sharpen > 0:
        # kornia operates on BCHW: HWC -> 1xCxHxW, sharpen, then back to HWC.
        img_bchw = img.permute(2, 0, 1).unsqueeze(0)
        sharpened = kornia.enhance.sharpness(img_bchw, sharpen).squeeze(0).permute(1, 2, 0)
    else:
        sharpened = img

    result = ratio * sharpened + (1 - ratio) * blurred
    return _to_uint8(result)


def create_sharpen_tab():
    """Build the "Sharpening" tab and wire its buttons to the two filters.

    The left column holds the input image plus one tab per algorithm (Smart
    Sharpen, CAS), each with its own controls and apply button. The right
    column holds the shared output image.

    NOTE(review): presumably called inside an enclosing ``gr.Blocks`` (and
    likely ``gr.Tabs``) context by the app — confirm at the call site.
    """
    with gr.Tab("Sharpening"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", height=256)

                # Both algorithms share a single tab group, so the user switches
                # between them instead of seeing two separate one-tab bars.
                with gr.Tabs():
                    with gr.Tab("Smart Sharpen"):
                        noise_radius = gr.Slider(
                            minimum=1,
                            maximum=25,
                            value=7,
                            step=1,
                            label="Noise Radius",
                        )
                        preserve_edges = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.75,
                            step=0.05,
                            label="Preserve Edges",
                        )
                        sharpen = gr.Slider(
                            minimum=0.0,
                            maximum=25.0,
                            value=5.0,
                            step=0.5,
                            label="Sharpen Amount",
                        )
                        ratio = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.5,
                            step=0.1,
                            label="Blend Ratio",
                        )
                        smart_btn = gr.Button("Apply Smart Sharpen")

                    with gr.Tab("CAS"):
                        cas_amount = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.8,
                            step=0.05,
                            label="Amount",
                        )
                        cas_btn = gr.Button("Apply CAS")

            with gr.Column():
                output_image = gr.Image(label="Sharpened Image")

        smart_btn.click(
            fn=apply_smart_sharpen,
            inputs=[input_image, noise_radius, preserve_edges, sharpen, ratio],
            outputs=output_image,
        )

        cas_btn.click(
            fn=apply_cas,
            inputs=[input_image, cas_amount],
            outputs=output_image,
        )