import math
import secrets
from pathlib import Path
from typing import cast

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from gradio.components.image_editor import EditorValue
from PIL import Image, ImageFilter, ImageOps

DEVICE = "cuda"

EXAMPLES_DIR = Path(__file__).parent / "examples"
MAX_SEED = np.iinfo(np.int32).max

SYSTEM_PROMPT = r"""This two-panel split-frame image showcases a furniture in as a product shot versus styled in a room. [LEFT] standalone product shot image the furniture on a white background. [RIGHT] integrated example within a room scene."""

MASK_CONTEXT_PADDING = 16 * 8

if not torch.cuda.is_available():

    def _dummy_pipe(image: Image.Image, *args, **kwargs):  # noqa: ARG001
        # return {"images": [image]}  # noqa: ERA001
        blue_image = Image.new("RGB", image.size, (0, 0, 255))
        return {"images": [blue_image]}

    pipe = _dummy_pipe
else:
    state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
        pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
        weight_name="pytorch_lora_weights3.safetensors",
        return_alphas=True,
    )
    if not all(("lora" in key or "dora_scale" in key) for key in state_dict):
        msg = "Invalid LoRA checkpoint."
        raise ValueError(msg)

    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
    ).to(DEVICE)
    FluxFillPipeline.load_lora_into_transformer(
        state_dict=state_dict,
        network_alphas=network_alphas,
        transformer=pipe.transformer,
    )
    pipe.to(DEVICE)


def make_example(image_path: Path, mask_path: Path) -> EditorValue:
    background_image = Image.open(image_path)
    background_image = background_image.convert("RGB")
    background = np.array(background_image)

    mask_image = Image.open(mask_path)
    mask_image = mask_image.convert("RGB")
    mask = np.array(mask_image)
    mask = mask[:, :, 0]
    mask = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    if background.shape[0] != mask.shape[0] or background.shape[1] != mask.shape[1]:
        msg = "Background and mask must have the same shape"
        raise ValueError(msg)

    layer = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    layer[:, :, 3] = mask

    composite = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    composite[:, :, :3] = background
    composite[:, :, 3] = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }


def pad(
    image: Image.Image,
    size: tuple[int, int],
    method: int = Image.Resampling.BICUBIC,
    color: str | int | tuple[int, ...] | None = None,
    centering: tuple[float, float] = (1, 1),
) -> tuple[Image.Image, tuple[int, int]]:
    resized = ImageOps.contain(image, size, method)
    resized_size = resized.size
    if resized_size == size:
        out = resized
    else:
        out = Image.new(image.mode, size, color)
        if resized.palette:
            palette = resized.getpalette()
            if palette is not None:
                out.putpalette(palette)
        if resized.width != size[0]:
            x = round((size[0] - resized.width) * max(0, min(centering[0], 1)))
            out.paste(resized, (x, 0))
        else:
            y = round((size[1] - resized.height) * max(0, min(centering[1], 1)))
            out.paste(resized, (0, y))
    return out, resized_size


def unpad(
    padded_image: Image.Image,
    padded_size: tuple[int, int],
    original_size: tuple[int, int],
    centering: tuple[float, float] = (1, 1),
    method: int = Image.Resampling.BICUBIC,
) -> Image.Image:
    width, height = padded_image.size
    padded_width, padded_height = padded_size

    # Calculate the cropping box based on centering
    left = round((width - padded_width) * centering[0])
    top = round((height - padded_height) * centering[1])
    right = left + padded_width
    bottom = top + padded_height

    # Crop the image to remove the padding
    cropped_image = padded_image.crop((left, top, right, bottom))

    # Resize the cropped image to match the original size
    resized_image = cropped_image.resize(original_size, method)

    return resized_image


def adjust_bbox_to_divisible_16(
    x_min: int,
    y_min: int,
    x_max: int,
    y_max: int,
    width: int,
    height: int,
    padding: int = MASK_CONTEXT_PADDING,
) -> tuple[int, int, int, int]:
    # Add context padding
    x_min = max(x_min - padding, 0)
    y_min = max(y_min - padding, 0)
    x_max = min(x_max + padding, width)
    y_max = min(y_max + padding, height)

    # Ensure bbox dimensions are divisible by 16
    def make_divisible_16(val_min, val_max, max_limit):
        size = val_max - val_min
        if size % 16 != 0:
            adjustment = 16 - (size % 16)
            val_min = max(val_min - adjustment // 2, 0)
            val_max = min(val_max + adjustment // 2, max_limit)
        return val_min, val_max

    x_min, x_max = make_divisible_16(x_min, x_max, width)
    y_min, y_max = make_divisible_16(y_min, y_max, height)

    # Clamp to the image bounds
    x_min = max(x_min, 0)
    y_min = max(y_min, 0)
    x_max = min(x_max, width)
    y_max = min(y_max, height)

    # Final divisibility check (in case clamping pushed it off again)
    x_min, x_max = make_divisible_16(x_min, x_max, width)
    y_min, y_max = make_divisible_16(y_min, y_max, height)

    return x_min, y_min, x_max, y_max


@spaces.GPU(duration=150)
def infer(
    furniture_image_input: Image.Image,
    room_image_input: EditorValue,
    furniture_prompt: str = "",
    seed: int = 42,
    randomize_seed: bool = False,
    guidance_scale: float = 3.5,
    num_inference_steps: int = 20,
    max_dimension: int = 720,
    num_images_per_prompt: int = 2,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
):
    # Ensure max_dimension is a multiple of 16 (for the VAE)
    max_dimension = (max_dimension // 16) * 16

    room_image = room_image_input["background"]
    if room_image is None:
        msg = "Room image is required"
        raise ValueError(msg)
    room_image = cast("Image.Image", room_image)

    room_mask = room_image_input["layers"][0]
    if room_mask is None:
        msg = "Room mask is required"
        raise ValueError(msg)
    room_mask = cast("Image.Image", room_mask)

    mask_bbox_x_min, mask_bbox_y_min, mask_bbox_x_max, mask_bbox_y_max = (
        adjust_bbox_to_divisible_16(
            *room_mask.getbbox(alpha_only=False),
            width=room_mask.width,
            height=room_mask.height,
            padding=MASK_CONTEXT_PADDING,
        )
    )

    room_image_cropped = room_image.crop((
        mask_bbox_x_min,
        mask_bbox_y_min,
        mask_bbox_x_max,
        mask_bbox_y_max,
    ))
    room_image_padded, room_image_padded_size = pad(
        room_image_cropped,
        (max_dimension, max_dimension),
    )

    # Grow and blur the mask so the generated region blends into the room
    grow_pixels = 10
    sigma_grow = grow_pixels / 4
    kernel_size_grow = math.ceil(sigma_grow * 1.5 + 1)
    room_mask_grow = room_mask.filter(
        ImageFilter.MaxFilter(size=2 * kernel_size_grow + 1)
    )

    blur_pixels = 33
    sigma_blur = blur_pixels / 4
    kernel_size_blur = math.ceil(sigma_blur * 1.5 + 1)
    room_mask_blurred = room_mask_grow.filter(
        ImageFilter.GaussianBlur(radius=kernel_size_blur)
    )

    room_mask_cropped = room_mask_blurred.crop((
        mask_bbox_x_min,
        mask_bbox_y_min,
        mask_bbox_x_max,
        mask_bbox_y_max,
    ))
    room_mask_padded, _ = pad(
        room_mask_cropped,
        (max_dimension, max_dimension),
    )

    room_image_padded.save("room_image_padded.png")
    room_mask_padded.save("room_mask_padded.png")

    furniture_image, _ = pad(
        furniture_image_input,
        (max_dimension, max_dimension),
    )
    furniture_mask = Image.new("RGB", (max_dimension, max_dimension), (255, 255, 255))

    image = Image.new(
        "RGB",
        (max_dimension * 2, max_dimension),
        (255, 255, 255),
    )
    # Paste the furniture on the left panel and the room on the right panel
    image.paste(furniture_image, (0, 0))
    image.paste(room_image_padded, (max_dimension, 0))

    mask = Image.new(
        "RGB",
        (max_dimension * 2, max_dimension),
        (255, 255, 255),
    )
    mask.paste(furniture_mask, (0, 0))
    mask.paste(room_mask_padded, (max_dimension, 0), room_mask_padded)

    # Invert the mask
    mask = ImageOps.invert(mask)
    # Blur the mask
    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))
    # Convert to a single-channel mask
    mask = mask.convert("L")

    if randomize_seed:
        seed = secrets.randbelow(MAX_SEED)

    prompt = (
        furniture_prompt + ".\n" + SYSTEM_PROMPT if furniture_prompt else SYSTEM_PROMPT
    )

    image.save("image.png")
    mask.save("mask.png")

    results_images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,
        height=max_dimension,
        width=max_dimension * 2,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=torch.Generator("cpu").manual_seed(seed),
    )["images"]

    final_images = []
    for result_image in results_images:
        final_image = room_image.copy()
        image_generated = unpad(
            result_image,
            room_image_padded_size,
            (
                mask_bbox_x_max - mask_bbox_x_min,
                mask_bbox_y_max - mask_bbox_y_min,
            ),
        )
        # Paste the generated crop back onto the room image at the original
        # crop location, using the blurred mask for a soft blend
        final_image.paste(
            image_generated,
            (mask_bbox_x_min, mask_bbox_y_min),
            room_mask_cropped,
        )
        final_images.append(final_image)

    return final_images, seed


intro_markdown = r"""