dikdimon's picture
Upload extensions using SD-Hub extension
f4a41d8 verified
raw
history blame
24.8 kB
import os
import subprocess
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
from rembg import remove, new_session
from tqdm import tqdm

from modules import script_callbacks, shared
class GeekyRemB:
    """Background removal and compositing helper built on rembg, OpenCV and Pillow.

    The rembg session is created lazily and reused for as long as the same
    model name is requested, so keeping one instance alive avoids reloading
    the model on every call.
    """

    def __init__(self):
        # Created on the first remove_background() call (see there).
        self.session = None

    @staticmethod
    def _hex_to_rgb(hex_color):
        """Convert a ``#RRGGBB`` (or ``RRGGBB``) string to an ``(r, g, b)`` tuple of ints."""
        digits = hex_color.lstrip('#')
        return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

    def apply_chroma_key(self, image, color, threshold, color_tolerance=20):
        """Build a keep-mask that excludes pixels matching a chroma-key colour.

        Args:
            image: RGB image as a NumPy array (converted with ``COLOR_RGB2HSV``).
            color: ``"green"``, ``"blue"`` or ``"red"``; any other value yields
                an all-zero mask.
            threshold: Threshold passed to ``cv2.threshold`` on the in-range mask.
            color_tolerance: Number of hue units by which the key range is widened.

        Returns:
            A ``uint8`` mask with the image's height and width. Pixels inside
            the key range become 0 and all others 255 (after the binary
            threshold is applied and the result inverted).
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        tolerance = int(color_tolerance)

        if color == "green":
            hue_ranges = [(40 - tolerance, 80 + tolerance)]
        elif color == "blue":
            hue_ranges = [(90 - tolerance, 130 + tolerance)]
        elif color == "red":
            # 8-bit OpenCV hue spans 0..179 and red straddles the wrap-around,
            # so the tolerance is also applied at the top end of the range.
            hue_ranges = [(0, 20 + tolerance), (180 - tolerance, 179)]
        else:
            return np.zeros(image.shape[:2], dtype=np.uint8)

        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for hue_low, hue_high in hue_ranges:
            # Clamp to the valid hue interval so large tolerances stay in range.
            lower = np.array([max(hue_low, 0), 40, 40])
            upper = np.array([min(hue_high, 179), 255, 255])
            mask = cv2.bitwise_or(mask, cv2.inRange(hsv, lower, upper))

        mask = 255 - cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)[1]
        return mask

    def process_mask(self, mask, invert_mask, feather_amount, mask_blur, mask_expansion):
        """Apply inversion, expansion/contraction and Gaussian softening to a mask.

        Args:
            mask: ``uint8`` mask array.
            invert_mask: When true, the mask is replaced by ``255 - mask``.
            feather_amount: Gaussian sigma applied when greater than zero.
            mask_blur: Second Gaussian sigma applied when greater than zero.
            mask_expansion: Kernel size in pixels; positive dilates, negative
                erodes, zero leaves the mask unchanged.

        Returns:
            The processed mask.
        """
        if invert_mask:
            mask = 255 - mask

        # Sliders may deliver floats; np.ones needs integer dimensions.
        expansion = int(mask_expansion)
        if expansion != 0:
            size = abs(expansion)
            kernel = np.ones((size, size), np.uint8)
            if expansion > 0:
                mask = cv2.dilate(mask, kernel, iterations=1)
            else:
                mask = cv2.erode(mask, kernel, iterations=1)

        if feather_amount > 0:
            mask = cv2.GaussianBlur(mask, (0, 0), sigmaX=feather_amount)
        if mask_blur > 0:
            mask = cv2.GaussianBlur(mask, (0, 0), sigmaX=mask_blur)
        return mask

    def remove_background(self, image, background_image, model, alpha_matting, alpha_matting_foreground_threshold,
                          alpha_matting_background_threshold, post_process_mask, chroma_key, chroma_threshold,
                          color_tolerance, background_mode, background_color, output_format="RGBA",
                          invert_mask=False, feather_amount=0, edge_detection=False,
                          edge_thickness=1, edge_color="#FFFFFF", shadow=False, shadow_blur=5,
                          shadow_opacity=0.5, color_adjustment=False, brightness=1.0, contrast=1.0,
                          saturation=1.0, x_position=0, y_position=0, rotation=0, opacity=1.0,
                          flip_horizontal=False, flip_vertical=False, mask_blur=0, mask_expansion=0,
                          foreground_scale=1.0, foreground_aspect_ratio=None, remove_bg=True,
                          use_custom_dimensions=False, custom_width=None, custom_height=None,
                          output_dimension_source="Foreground"):
        """Mask the foreground of ``image`` and composite it onto a new canvas.

        Args:
            image: A PIL image, or an object whose ``image[0].cpu().numpy()``
                gives float pixel data that is scaled by 255 (assumed to be a
                batched tensor in the 0..1 range - TODO confirm).
            background_image: PIL image used when ``background_mode`` is neither
                ``"transparent"`` nor ``"color"``; may be ``None``.
            model: rembg model name; a new session is created when it changes.
            alpha_matting, alpha_matting_foreground_threshold,
            alpha_matting_background_threshold, post_process_mask: Forwarded
                to ``rembg.remove``.
            chroma_key, chroma_threshold, color_tolerance: Chroma-key settings;
                ``"none"`` disables the chroma mask.
            background_mode: ``"transparent"``, ``"color"``, or anything else
                for an image background.
            background_color, edge_color: ``#RRGGBB`` strings.
            output_format: ``"RGB"`` converts the result to RGB; otherwise RGBA.
            invert_mask, feather_amount, mask_blur, mask_expansion: See
                :meth:`process_mask`.
            edge_detection, edge_thickness: Draw Canny edges of the mask.
            shadow, shadow_blur, shadow_opacity: Draw a blurred black overlay
                shaped like the mask.
            color_adjustment, brightness, contrast, saturation: Enhancement
                factors applied to the final composite.
            x_position, y_position: Offset of the foreground from the centre.
            rotation, flip_horizontal, flip_vertical, opacity,
            foreground_scale: Foreground transforms.
            foreground_aspect_ratio: When not ``None`` the foreground height is
                ``new_width / foreground_aspect_ratio``.
                NOTE(review): a caller that always passes 1.0 forces a square
                foreground - confirm that this is intended.
            remove_bg: When false, rembg is skipped and the mask starts fully opaque.
            use_custom_dimensions, custom_width, custom_height,
            output_dimension_source: Select the output canvas size.

        Returns:
            ``(result, fg_mask)``: the composited PIL image and the transformed
            foreground mask as a PIL image.
        """
        if self.session is None or self.session.model_name != model:
            self.session = new_session(model)

        bg_color = self._hex_to_rgb(background_color) + (255,)
        edge_color = self._hex_to_rgb(edge_color)

        if isinstance(image, Image.Image):
            pil_image = image
        else:
            pil_image = Image.fromarray(np.clip(255. * image[0].cpu().numpy(), 0, 255).astype(np.uint8))

        if chroma_key != "none":
            # The HSV conversion needs exactly three channels, so drop any alpha.
            rgb_pixels = np.array(pil_image.convert("RGB"))
            input_mask = self.apply_chroma_key(rgb_pixels, chroma_key, chroma_threshold, color_tolerance)
        else:
            input_mask = None

        if remove_bg:
            removed_bg = remove(
                pil_image,
                session=self.session,
                alpha_matting=alpha_matting,
                alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                alpha_matting_background_threshold=alpha_matting_background_threshold,
                post_process_mask=post_process_mask,
            )
            rembg_mask = np.array(removed_bg)[:, :, 3]
        else:
            rembg_mask = np.full(pil_image.size[::-1], 255, dtype=np.uint8)

        if input_mask is not None:
            final_mask = cv2.bitwise_and(rembg_mask, input_mask)
        else:
            final_mask = rembg_mask
        final_mask = self.process_mask(final_mask, invert_mask, feather_amount, mask_blur, mask_expansion)

        orig_width, orig_height = pil_image.size
        if background_image is not None:
            bg_width, bg_height = background_image.size
        else:
            bg_width, bg_height = orig_width, orig_height

        if use_custom_dimensions and custom_width and custom_height:
            output_width, output_height = int(custom_width), int(custom_height)
        elif output_dimension_source == "Background" and background_image is not None:
            output_width, output_height = bg_width, bg_height
        else:
            output_width, output_height = orig_width, orig_height

        # Keep both sides at least one pixel so resize never gets a zero size.
        new_width = max(1, int(orig_width * foreground_scale))
        if foreground_aspect_ratio is not None:
            new_height = max(1, int(new_width / foreground_aspect_ratio))
        else:
            new_height = max(1, int(orig_height * foreground_scale))

        fg_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
        fg_mask = Image.fromarray(final_mask).resize((new_width, new_height), Image.LANCZOS)

        if background_mode == "transparent":
            result = Image.new("RGBA", (output_width, output_height), (0, 0, 0, 0))
        elif background_mode == "color":
            result = Image.new("RGBA", (output_width, output_height), bg_color)
        else:  # image (or per-frame video) background
            if background_image is not None:
                result = background_image.resize((output_width, output_height), Image.LANCZOS).convert("RGBA")
            else:
                result = Image.new("RGBA", (output_width, output_height), (0, 0, 0, 0))

        if flip_horizontal:
            fg_image = fg_image.transpose(Image.FLIP_LEFT_RIGHT)
            fg_mask = fg_mask.transpose(Image.FLIP_LEFT_RIGHT)
        if flip_vertical:
            fg_image = fg_image.transpose(Image.FLIP_TOP_BOTTOM)
            fg_mask = fg_mask.transpose(Image.FLIP_TOP_BOTTOM)

        fg_image = fg_image.rotate(rotation, resample=Image.BICUBIC, expand=True)
        fg_mask = fg_mask.rotate(rotation, resample=Image.BICUBIC, expand=True)

        # Centre the foreground, then apply the offsets (paste needs ints).
        paste_x = int(x_position) + (output_width - fg_image.width) // 2
        paste_y = int(y_position) + (output_height - fg_image.height) // 2

        # Scale the alpha channel with a lookup table instead of a per-pixel
        # Python loop; the resulting pixels are the same.
        fg_rgba = fg_image.convert("RGBA")
        fg_with_opacity = fg_rgba.copy()
        fg_with_opacity.putalpha(fg_rgba.getchannel("A").point(lambda a: int(a * opacity)))
        fg_mask_with_opacity = fg_mask.point(lambda p: int(p * opacity))

        result.paste(fg_with_opacity, (paste_x, paste_y), fg_mask_with_opacity)

        if edge_detection:
            thickness = max(1, int(edge_thickness))
            edge_mask = cv2.Canny(np.array(fg_mask), 100, 200)
            edge_mask = cv2.dilate(edge_mask, np.ones((thickness, thickness), np.uint8), iterations=1)
            edge_overlay = Image.new("RGBA", (output_width, output_height), (0, 0, 0, 0))
            edge_overlay.paste(Image.new("RGB", fg_image.size, edge_color), (paste_x, paste_y),
                               Image.fromarray(edge_mask))
            result = Image.alpha_composite(result, edge_overlay)

        if shadow:
            # NOTE(review): the shadow is composited after the foreground, so
            # it is drawn on top of it - confirm that this is intended.
            shadow_mask = fg_mask.filter(ImageFilter.GaussianBlur(shadow_blur))
            shadow_image = Image.new("RGBA", (output_width, output_height), (0, 0, 0, 0))
            shadow_image.paste((0, 0, 0, int(255 * shadow_opacity)), (paste_x, paste_y), shadow_mask)
            result = Image.alpha_composite(result, shadow_image.filter(ImageFilter.GaussianBlur(shadow_blur)))

        if color_adjustment:
            result = ImageEnhance.Brightness(result).enhance(brightness)
            result = ImageEnhance.Contrast(result).enhance(contrast)
            result = ImageEnhance.Color(result).enhance(saturation)

        if output_format == "RGB":
            result = result.convert("RGB")

        return result, fg_mask

    def process_frame(self, frame, *args):
        """Run :meth:`remove_background` on one BGR video frame.

        Args:
            frame: BGR frame as a NumPy array.
            *args: Positional arguments forwarded to ``remove_background``
                after the image, i.e. starting with ``background_image``.

        Returns:
            The processed frame as a 3-channel BGR array.
        """
        pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        processed_frame, _ = self.remove_background(pil_frame, *args)
        # Drop any alpha channel explicitly: the video writer takes BGR frames.
        return cv2.cvtColor(np.array(processed_frame.convert("RGB")), cv2.COLOR_RGB2BGR)

    def process_video(self, input_path, output_path, background_video_path, *args, background_image=None):
        """Process every frame of a video and write the result to ``output_path``.

        Args:
            input_path: Path of the foreground video.
            output_path: Path of the MP4 file to write.
            background_video_path: Optional video whose frames (looped) are
                used as the per-frame background.
            *args: ``remove_background`` arguments starting at ``model``. The
                background image is supplied by this method, so it must not be
                included here.
            background_image: Optional static PIL background used when no
                background video frame is available.
        """
        cap = cv2.VideoCapture(input_path)
        bg_cap = None
        out = None
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            bg_total_frames = 0
            if background_video_path:
                bg_cap = cv2.VideoCapture(background_video_path)
                bg_total_frames = int(bg_cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # The last successfully read background stays in use when a
            # background frame cannot be read.
            current_background = background_image

            for frame_num in tqdm(range(total_frames), desc="Processing video"):
                ret, frame = cap.read()
                if not ret:
                    break

                # Guard against an unreadable background video (zero frames).
                if bg_cap is not None and bg_total_frames > 0:
                    bg_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num % bg_total_frames)
                    bg_ret, bg_frame = bg_cap.read()
                    if bg_ret:
                        bg_frame_resized = cv2.resize(bg_frame, (width, height))
                        current_background = Image.fromarray(
                            cv2.cvtColor(bg_frame_resized, cv2.COLOR_BGR2RGB))

                # The background is passed explicitly as remove_background's
                # second argument so the remaining args line up with ``model``.
                processed_frame = self.process_frame(frame, current_background, *args)

                # The writer was opened for (width, height), so bring frames of
                # any other size (e.g. custom dimensions) back to that size.
                if processed_frame.shape[1] != width or processed_frame.shape[0] != height:
                    processed_frame = cv2.resize(processed_frame, (width, height))
                out.write(processed_frame)
        finally:
            cap.release()
            if bg_cap is not None:
                bg_cap.release()
            if out is not None:
                out.release()

        # Convert output video to MP4 container
        temp_output = output_path + "_temp.mp4"
        os.rename(output_path, temp_output)
        try:
            completed = subprocess.run(
                ["ffmpeg", "-y", "-i", temp_output, "-c", "copy", output_path],
                stdin=subprocess.DEVNULL,
                check=False,
            )
            remuxed = completed.returncode == 0 and os.path.exists(output_path)
        except OSError:
            # ffmpeg is not installed or could not be started.
            remuxed = False

        if remuxed:
            os.remove(temp_output)
        else:
            # Keep the OpenCV-written file rather than losing the output.
            os.replace(temp_output, output_path)
def on_ui_tabs():
    """Build the GeekyRemB tab and return it in the form the web UI expects.

    Returns:
        A one-element list holding ``(blocks, "GeekyRemB", "geeky_remb_tab")``.
    """
    # One shared processor so the rembg session it caches survives across runs
    # instead of being rebuilt on every button press.
    geeky_remb = GeekyRemB()

    with gr.Blocks(analytics_enabled=False) as geeky_remb_tab:
        gr.Markdown("# GeekyRemB: Background Removal and Image/Video Manipulation")

        with gr.Row():
            with gr.Column(scale=1):
                input_type = gr.Radio(["Image", "Video"], label="Input Type", value="Image")
                foreground_input = gr.Image(label="Foreground Image", type="pil", visible=True)
                foreground_video = gr.Video(label="Foreground Video", visible=False)

                with gr.Group():
                    gr.Markdown("### Foreground Adjustments")
                    foreground_scale = gr.Slider(label="Scale", minimum=0.1, maximum=5.0, value=1.0, step=0.1)
                    foreground_aspect_ratio = gr.Slider(label="Aspect Ratio", minimum=0.1, maximum=10.0, value=1.0, step=0.1)
                    x_position = gr.Slider(label="X Position", minimum=-1000, maximum=1000, value=0, step=1)
                    y_position = gr.Slider(label="Y Position", minimum=-1000, maximum=1000, value=0, step=1)
                    rotation = gr.Slider(label="Rotation", minimum=-360, maximum=360, value=0, step=0.1)
                    opacity = gr.Slider(label="Opacity", minimum=0.0, maximum=1.0, value=1.0, step=0.01)
                    flip_horizontal = gr.Checkbox(label="Flip Horizontal", value=False)
                    flip_vertical = gr.Checkbox(label="Flip Vertical", value=False)

            with gr.Column(scale=1):
                result_type = gr.Radio(["Image", "Video"], label="Output Type", value="Image")
                result_image = gr.Image(label="Result Image", type="pil", visible=True)
                result_video = gr.Video(label="Result Video", visible=False)

                with gr.Group():
                    gr.Markdown("### Background Options")
                    remove_background = gr.Checkbox(label="Remove Background", value=True)
                    background_mode = gr.Radio(label="Background Mode", choices=["transparent", "color", "image", "video"], value="transparent")
                    background_color = gr.ColorPicker(label="Background Color", value="#000000", visible=False)
                    background_image = gr.Image(label="Background Image", type="pil", visible=False)
                    background_video = gr.Video(label="Background Video", visible=False)

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Removal Settings")
                    model = gr.Dropdown(label="Model", choices=["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use", "isnet-anime"], value="u2net")
                    output_format = gr.Radio(label="Output Format", choices=["RGBA", "RGB"], value="RGBA")
                    alpha_matting = gr.Checkbox(label="Alpha Matting", value=False)
                    alpha_matting_foreground_threshold = gr.Slider(label="Alpha Matting Foreground Threshold", minimum=0, maximum=255, value=240, step=1)
                    alpha_matting_background_threshold = gr.Slider(label="Alpha Matting Background Threshold", minimum=0, maximum=255, value=10, step=1)
                    post_process_mask = gr.Checkbox(label="Post Process Mask", value=False)

                with gr.Column():
                    gr.Markdown("### Chroma Key Settings")
                    chroma_key = gr.Dropdown(label="Chroma Key", choices=["none", "green", "blue", "red"], value="none")
                    chroma_threshold = gr.Slider(label="Chroma Threshold", minimum=0, maximum=255, value=30, step=1)
                    color_tolerance = gr.Slider(label="Color Tolerance", minimum=0, maximum=255, value=20, step=1)

                with gr.Column():
                    gr.Markdown("### Effects")
                    invert_mask = gr.Checkbox(label="Invert Mask", value=False)
                    feather_amount = gr.Slider(label="Feather Amount", minimum=0, maximum=100, value=0, step=1)
                    edge_detection = gr.Checkbox(label="Edge Detection", value=False)
                    edge_thickness = gr.Slider(label="Edge Thickness", minimum=1, maximum=10, value=1, step=1)
                    edge_color = gr.ColorPicker(label="Edge Color", value="#FFFFFF")
                    shadow = gr.Checkbox(label="Shadow", value=False)
                    shadow_blur = gr.Slider(label="Shadow Blur", minimum=0, maximum=20, value=5, step=1)
                    shadow_opacity = gr.Slider(label="Shadow Opacity", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
                    color_adjustment = gr.Checkbox(label="Color Adjustment", value=False)
                    brightness = gr.Slider(label="Brightness", minimum=0.0, maximum=2.0, value=1.0, step=0.1)
                    contrast = gr.Slider(label="Contrast", minimum=0.0, maximum=2.0, value=1.0, step=0.1)
                    saturation = gr.Slider(label="Saturation", minimum=0.0, maximum=2.0, value=1.0, step=0.1)
                    mask_blur = gr.Slider(label="Mask Blur", minimum=0, maximum=100, value=0, step=1)
                    mask_expansion = gr.Slider(label="Mask Expansion", minimum=-100, maximum=100, value=0, step=1)

        with gr.Row():
            gr.Markdown("### Output Settings")
            image_format = gr.Dropdown(label="Image Format", choices=["PNG", "JPEG", "WEBP"], value="PNG")
            video_format = gr.Dropdown(label="Video Format", choices=["MP4", "AVI", "MOV"], value="MP4")
            video_quality = gr.Slider(label="Video Quality", minimum=0, maximum=100, value=95, step=1)
            use_custom_dimensions = gr.Checkbox(label="Use Custom Dimensions", value=False)
            custom_width = gr.Number(label="Custom Width", value=512, visible=False)
            custom_height = gr.Number(label="Custom Height", value=512, visible=False)
            output_dimension_source = gr.Radio(
                label="Output Dimension Source",
                choices=["Foreground", "Background"],
                value="Foreground",
                visible=True
            )

        # The button caption is the component's value, not a label.
        run_button = gr.Button(value="Run GeekyRemB")

        def update_input_type(choice):
            """Show the foreground widget that matches the selected input type."""
            return {
                foreground_input: gr.update(visible=choice == "Image"),
                foreground_video: gr.update(visible=choice == "Video"),
            }

        def update_output_type(choice):
            """Show the result widget that matches the selected output type."""
            return {
                result_image: gr.update(visible=choice == "Image"),
                result_video: gr.update(visible=choice == "Video"),
            }

        def update_background_mode(mode):
            """Show only the background input that belongs to the selected mode."""
            return {
                background_color: gr.update(visible=mode == "color"),
                background_image: gr.update(visible=mode == "image"),
                background_video: gr.update(visible=mode == "video"),
            }

        def update_custom_dimensions(use_custom):
            """Toggle between the custom size fields and the dimension-source radio."""
            return {
                custom_width: gr.update(visible=use_custom),
                custom_height: gr.update(visible=use_custom),
                output_dimension_source: gr.update(visible=not use_custom)
            }

        def process_image(image, background_image, *args):
            """Composite a single image and return the resulting PIL image."""
            result, _ = geeky_remb.remove_background(image, background_image, *args)
            return result

        def process_video(video_path, background_video_path, *args):
            """Process a video into a new temporary MP4 and return its path."""
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                output_path = temp_file.name
            geeky_remb.process_video(video_path, output_path, background_video_path, *args)
            return output_path

        def save_image_result(result, image_format):
            """Save ``result`` to a temporary file in ``image_format``; return the path."""
            if image_format != "PNG":
                result = result.convert("RGB")
            # Only JPEG gets an explicit quality; passing ``quality=None`` would
            # reach encoders (e.g. WEBP) that expect a number.
            save_options = {"quality": 95} if image_format == "JPEG" else {}
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{image_format.lower()}") as temp_file:
                output_path = temp_file.name
            # Save after the handle is closed so the path can be reopened on
            # platforms that forbid opening an already-open temporary file.
            result.save(output_path, format=image_format, **save_options)
            return output_path

        def run_geeky_remb(input_type, foreground_input, foreground_video, result_type, model, output_format,
                           alpha_matting, alpha_matting_foreground_threshold, alpha_matting_background_threshold,
                           post_process_mask, chroma_key, chroma_threshold, color_tolerance, background_mode,
                           background_color, background_image, background_video, invert_mask, feather_amount,
                           edge_detection, edge_thickness, edge_color, shadow, shadow_blur, shadow_opacity,
                           color_adjustment, brightness, contrast, saturation, x_position, y_position, rotation,
                           opacity, flip_horizontal, flip_vertical, mask_blur, mask_expansion, foreground_scale,
                           foreground_aspect_ratio, remove_background, image_format, video_format, video_quality,
                           use_custom_dimensions, custom_width, custom_height, output_dimension_source):
            """Dispatch on the input/output type pair.

            Returns:
                ``(image_path_or_None, video_path_or_None)`` matching the
                ``[result_image, result_video]`` outputs.
            """
            # Order matches remove_background's parameters after background_image.
            args = (model, alpha_matting, alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold, post_process_mask, chroma_key, chroma_threshold,
                    color_tolerance, background_mode, background_color, output_format,
                    invert_mask, feather_amount, edge_detection, edge_thickness, edge_color, shadow, shadow_blur,
                    shadow_opacity, color_adjustment, brightness, contrast, saturation, x_position,
                    y_position, rotation, opacity, flip_horizontal, flip_vertical, mask_blur,
                    mask_expansion, foreground_scale, foreground_aspect_ratio, remove_background,
                    use_custom_dimensions, custom_width, custom_height, output_dimension_source)

            # Nothing to do until the matching foreground input has been provided.
            if input_type == "Image" and foreground_input is None:
                return None, None
            if input_type == "Video" and not foreground_video:
                return None, None

            bg_video = background_video if background_mode == "video" else None

            if input_type == "Image" and result_type == "Image":
                result = process_image(foreground_input, background_image, *args)
                return save_image_result(result, image_format), None

            if input_type == "Video" and result_type == "Video":
                output_video = process_video(foreground_video, bg_video, *args)
                if video_format != "MP4":
                    temp_output = output_video + f"_temp.{video_format.lower()}"
                    crf = int(20 - (video_quality / 5))
                    try:
                        completed = subprocess.run(
                            ["ffmpeg", "-y", "-i", output_video, "-c:v", "libx264",
                             "-crf", str(crf), temp_output],
                            stdin=subprocess.DEVNULL,
                            check=False,
                        )
                        converted = completed.returncode == 0 and os.path.exists(temp_output)
                    except OSError:
                        # ffmpeg is not installed or could not be started.
                        converted = False
                    # Only discard the MP4 once the converted file really exists.
                    if converted:
                        os.remove(output_video)
                        output_video = temp_output
                return None, output_video

            if input_type == "Image" and result_type == "Video":
                with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                    still_clip_path = temp_file.name
                frame = cv2.cvtColor(np.array(foreground_input), cv2.COLOR_RGB2BGR)
                height, width = frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(still_clip_path, fourcc, 24, (width, height))
                for _ in range(24 * 5):  # 5 seconds at 24 fps
                    out.write(frame)
                out.release()
                try:
                    return None, process_video(still_clip_path, bg_video, *args)
                finally:
                    # The still clip is only an intermediate; don't leave it behind.
                    if os.path.exists(still_clip_path):
                        os.remove(still_clip_path)

            if input_type == "Video" and result_type == "Image":
                cap = cv2.VideoCapture(foreground_video)
                ret, frame = cap.read()
                cap.release()
                if not ret:
                    return None, None
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                result = process_image(pil_frame, background_image, *args)
                return save_image_result(result, image_format), None

            return None, None

        input_type.change(update_input_type, inputs=[input_type], outputs=[foreground_input, foreground_video])
        result_type.change(update_output_type, inputs=[result_type], outputs=[result_image, result_video])
        background_mode.change(update_background_mode, inputs=[background_mode], outputs=[background_color, background_image, background_video])
        use_custom_dimensions.change(update_custom_dimensions, inputs=[use_custom_dimensions], outputs=[custom_width, custom_height, output_dimension_source])

        run_button.click(
            fn=run_geeky_remb,
            inputs=[
                input_type, foreground_input, foreground_video, result_type,
                model, output_format, alpha_matting, alpha_matting_foreground_threshold,
                alpha_matting_background_threshold, post_process_mask, chroma_key, chroma_threshold,
                color_tolerance, background_mode, background_color, background_image, background_video,
                invert_mask, feather_amount, edge_detection, edge_thickness, edge_color,
                shadow, shadow_blur, shadow_opacity, color_adjustment, brightness, contrast,
                saturation, x_position, y_position, rotation, opacity, flip_horizontal,
                flip_vertical, mask_blur, mask_expansion, foreground_scale, foreground_aspect_ratio,
                remove_background, image_format, video_format, video_quality,
                use_custom_dimensions, custom_width, custom_height, output_dimension_source
            ],
            outputs=[result_image, result_video]
        )

    return [(geeky_remb_tab, "GeekyRemB", "geeky_remb_tab")]
# Hand the tab builder to script_callbacks.on_ui_tabs; presumably the host web
# UI invokes it when assembling its tabs (verify against the host's callbacks).
script_callbacks.on_ui_tabs(on_ui_tabs)