# FurnitureDemo / app.py
# Author: blanchon
# Commit 87608e2: "really log all"
import csv
import json
import math
import os
import secrets
from pathlib import Path
from typing import cast
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from gradio.components.gallery import GalleryMediaType
from gradio.components.image_editor import EditorValue
from huggingface_hub import HfApi
from PIL import Image, ImageFilter, ImageOps
DEVICE = "cuda"

# Credentials come from the environment; both must be non-empty.
# NOTE(review): USER is also a standard POSIX login variable, so it is often
# set even when not configured for this app — confirm that is acceptable.
USER = os.getenv("USER")
PASSWORD = os.getenv("PASSWORD")
if not (USER and PASSWORD):
    msg = "USER and PASSWORD must be set"
    raise ValueError(msg)

# Largest int32 value; used as the upper bound for random seeds.
MAX_SEED = np.iinfo(np.int32).max

# Fixed prompt describing the two-panel layout given to the fill pipeline.
# infer() prepends the optional user prompt to it. Must stay byte-identical.
SYSTEM_PROMPT = r"""This two-panel split-frame image showcases a furniture in as a product shot versus styled in a room.
[LEFT] standalone product shot image the furniture on a white background.
[RIGHT] integrated example within a room scene."""

# Extra context, in pixels, added around the mask bounding box before cropping.
MASK_CONTEXT_PADDING = 8 * 16
api = HfApi()
# LoRA weight file loaded from the blanchon/FluxFillFurniture repository.
model_name = "2025-01-11_22-00-18-save-10359-55-129_patched.safetensors"

# Local directory mirroring the blanchon/FurnitureFlags dataset repository.
FLAG_PATH = Path(__file__).parent / "examples_dataset"
if torch.cuda.is_available():
    # With a GPU, pull the dataset from the Hub into FLAG_PATH.
    api.snapshot_download(
        repo_id="blanchon/FurnitureFlags",
        local_dir=FLAG_PATH,
        repo_type="dataset",
    )
else:
    # Without a GPU, only make sure the directory exists.
    FLAG_PATH.mkdir(parents=True, exist_ok=True)
def _load_examples(flag_dir: Path) -> dict[str, list]:
    """Read flagged rows from ``dataset*.csv`` files in *flag_dir*.

    Each data row is unpacked into five columns: furniture image, room editor
    value (JSON), results (JSON, unused here), flag option and timestamp.

    Returns:
        A dict keyed by the timestamp column. Each value is
        ``[furniture_image, {"background", "layers", "composite"}]``, ready to
        be used as an example for the furniture and room inputs.
    """
    examples: dict[str, list] = {}
    # Sort so the example order does not depend on directory listing order.
    for flag_file in sorted(flag_dir.glob("dataset*.csv")):
        # The csv module needs newline="" to read quoted fields that contain
        # newlines correctly. Pin the encoding instead of using the locale default.
        with flag_file.open("r", encoding="utf-8", newline="") as file:
            reader = csv.reader(file)
            # Skip the header row. An empty file has none, so skip the file.
            if next(reader, None) is None:
                continue
            for row in reader:
                furniture_image, room_image_json, _results, _flag, timestamp = row
                room_image = json.loads(room_image_json)
                examples[timestamp] = [
                    furniture_image,
                    {
                        "background": room_image["background"],
                        "layers": room_image["layers"],
                        "composite": room_image["composite"],
                    },
                ]
    return examples


EXAMPLES: dict[str, list] = _load_examples(FLAG_PATH)
if torch.cuda.is_available():
    # Fetch the furniture LoRA and check that it only holds adapter weights
    # before the base model is loaded.
    state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
        pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
        weight_name=model_name,
        return_alphas=True,
    )
    looks_like_lora = all(
        "lora" in key or "dora_scale" in key for key in state_dict
    )
    if not looks_like_lora:
        msg = "Invalid LoRA checkpoint."
        raise ValueError(msg)

    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
    ).to(DEVICE)
    FluxFillPipeline.load_lora_into_transformer(
        state_dict=state_dict,
        network_alphas=network_alphas,
        transformer=pipe.transformer,
    )
    # Second move after LoRA injection; presumably this ensures the adapter
    # weights also end up on DEVICE.
    pipe.to(DEVICE)
else:
    # Stand-in without CUDA: skips the model and returns three solid blue
    # images the size of the input image.
    def _dummy_pipe(image: Image.Image, *args, **kwargs):  # noqa: ARG001
        placeholder = Image.new("RGB", image.size, (0, 0, 255))
        return {"images": [placeholder] * 3}

    pipe = _dummy_pipe
# CSV logger used by the "Flag" button (see flag()). It is bound to the
# input/output components via callback.setup() inside the Blocks layout.
callback = gr.CSVLogger()
def make_example(image_path: Path, mask_path: Path) -> EditorValue:
    """Build an ``ImageEditor`` value from a room image file and a mask file.

    Args:
        image_path: Path of the room (background) image.
        mask_path: Path of the mask image. Only its first channel (after RGB
            conversion) is read. Pure-white (255) pixels get alpha 0 in the
            layer; every other value gets alpha 255.

    Returns:
        A dict with ``background`` (H x W x 3 uint8), ``layers`` (a single
        H x W x 4 uint8 layer with black RGB and the alpha described above)
        and ``composite`` (the background as RGBA, with alpha 0 wherever the
        layer alpha is 255 and alpha 255 elsewhere).

    Raises:
        ValueError: if the image and mask sizes differ.
    """
    # Context managers close the underlying files once the pixels are copied.
    with Image.open(image_path) as background_image:
        background = np.array(background_image.convert("RGB"))
    with Image.open(mask_path) as mask_image:
        mask_channel = np.array(mask_image.convert("RGB"))[:, :, 0]

    # Pure white -> 0, anything else -> 255.
    mask = np.where(mask_channel == 255, 0, 255)  # noqa: PLR2004

    if background.shape[:2] != mask.shape:
        msg = "Background and mask must have the same shape"
        raise ValueError(msg)

    height, width = mask.shape
    layer = np.zeros((height, width, 4), dtype=np.uint8)
    layer[:, :, 3] = mask

    composite = np.zeros((height, width, 4), dtype=np.uint8)
    composite[:, :, :3] = background
    # Complement of the layer alpha.
    composite[:, :, 3] = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }
def pad(
    image: Image.Image,
    size: tuple[int, int],
    method: int = Image.Resampling.BICUBIC,
    color: str | int | tuple[int, ...] | None = None,
    centering: tuple[float, float] = (1, 1),
) -> tuple[Image.Image, tuple[int, int]]:
    """Fit *image* inside *size* keeping its aspect ratio, then pad to *size*.

    Args:
        image: Image to fit.
        size: Target ``(width, height)`` of the returned canvas.
        method: Resampling filter passed to ``ImageOps.contain``.
        color: Fill color for the new canvas (passed to ``Image.new``).
        centering: Relative placement of the content along the padded axis,
            clamped to ``[0, 1]``. The default ``(1, 1)`` aligns the content
            to the right/bottom.

    Returns:
        The padded image and the size of the resized content inside it. The
        content size is what ``unpad`` needs to undo the operation.
    """
    contained = ImageOps.contain(image, size, method)
    content_size = contained.size
    if content_size == size:
        return contained, content_size

    canvas = Image.new(image.mode, size, color)
    if contained.palette:
        palette = contained.getpalette()
        if palette is not None:
            canvas.putpalette(palette)

    def _offset(free_space: int, center: float) -> int:
        # Clamp the centering factor to [0, 1] before placing the content.
        return round(free_space * max(0, min(center, 1)))

    if contained.width == size[0]:
        # Width already fills the canvas: pad vertically.
        position = (0, _offset(size[1] - contained.height, centering[1]))
    else:
        position = (_offset(size[0] - contained.width, centering[0]), 0)
    canvas.paste(contained, position)
    return canvas, content_size
def unpad(
    padded_image: Image.Image,
    padded_size: tuple[int, int],
    original_size: tuple[int, int],
    centering: tuple[float, float] = (1, 1),
    method: int = Image.Resampling.BICUBIC,
) -> Image.Image:
    """Undo ``pad``: crop the content region out and resize it back.

    Args:
        padded_image: Image laid out like ``pad``'s output.
        padded_size: Size of the content inside the padded image, i.e. the
            second value returned by ``pad``. Despite the name, this is not
            the size of the padded canvas.
        original_size: ``(width, height)`` to resize the cropped content to.
        centering: Must match the value given to ``pad``. It is clamped to
            ``[0, 1]`` exactly as ``pad`` does, so the crop box lines up with
            the paste offset.
        method: Resampling filter for the final resize.

    Returns:
        The cropped content resized to *original_size*.
    """
    width, height = padded_image.size
    content_width, content_height = padded_size

    # Clamp like pad() so out-of-range centering values give the same offset.
    center_x = max(0, min(centering[0], 1))
    center_y = max(0, min(centering[1], 1))
    left = round((width - content_width) * center_x)
    top = round((height - content_height) * center_y)

    cropped_image = padded_image.crop(
        (left, top, left + content_width, top + content_height)
    )
    return cropped_image.resize(original_size, method)
def adjust_bbox_to_divisible_16(
    x_min: int,
    y_min: int,
    x_max: int,
    y_max: int,
    width: int,
    height: int,
    padding: int = MASK_CONTEXT_PADDING,
) -> tuple[int, int, int, int]:
    """Grow a bounding box by *padding* and snap its size to a multiple of 16.

    The box is first expanded by ``padding`` pixels on every side and clamped
    to the image. Each axis is then enlarged to the next multiple of 16,
    split evenly around the box, and slid back inside ``[0, width]`` /
    ``[0, height]`` if it would overflow. The result always contains the
    padded box.

    If an axis needs a span larger than the image dimension, that axis falls
    back to the full image extent. This happens only when the dimension is
    not a multiple of 16 and the box nearly covers it, and the result may
    then not be divisible by 16.

    Returns:
        ``(x_min, y_min, x_max, y_max)``.
    """

    def _snap(lo: int, hi: int, limit: int, multiple: int = 16) -> tuple[int, int]:
        """Enlarge ``[lo, hi)`` to a multiple of *multiple* within ``[0, limit]``."""
        span = hi - lo
        target = -(-span // multiple) * multiple  # ceil to the next multiple
        if target > limit:
            return 0, limit
        # Spread the growth over both sides, then slide back inside the image.
        # The window keeps the same length, so it still contains [lo, hi).
        start = lo - (target - span) // 2
        start = min(max(start, 0), limit - target)
        return start, start + target

    # Add context padding, clamped to the image.
    x_min = max(x_min - padding, 0)
    y_min = max(y_min - padding, 0)
    x_max = min(x_max + padding, width)
    y_max = min(y_max + padding, height)

    x_min, x_max = _snap(x_min, x_max, width)
    y_min, y_max = _snap(y_min, y_max, height)
    return x_min, y_min, x_max, y_max
def flag(
    furniture_image_input: Image.Image,
    room_image_input: EditorValue,
    results: GalleryMediaType,
):
    """Log the current inputs/outputs with the CSV logger and, on GPU, upload them.

    Does nothing when there are no results.

    NOTE(review): this handler is wired with ``preprocess=False``, so the
    arguments are probably raw component payloads rather than the annotated
    types — confirm.
    """
    # `not results` covers both an empty list and None; len(None) would raise.
    if not results:
        return
    callback.flag(
        flag_data=[furniture_image_input, room_image_input, results],
        flag_option=model_name,
    )
    if torch.cuda.is_available():
        # Upload the flagged data points to the hub. A bare ".cache" pattern
        # only matches a path named exactly ".cache". ".cache/*" is needed to
        # skip the files underneath it.
        api.upload_folder(
            repo_id="blanchon/FurnitureFlags",
            repo_type="dataset",
            folder_path=FLAG_PATH,
            ignore_patterns=[".cache", ".cache/*"],
        )
@spaces.GPU(duration=150)
def infer(
    furniture_image_input: Image.Image,
    room_image_input: EditorValue,
    furniture_prompt: str = "",
    seed: int = 42,
    randomize_seed: bool = False,
    guidance_scale: float = 3.5,
    num_inference_steps: int = 20,
    max_dimension: int = 720,
    num_images_per_prompt: int = 2,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> tuple[GalleryMediaType, int]:
    """Generate room composites containing the furniture with the fill pipeline.

    Steps:
    1. Crop the masked region of the room, plus context padding, and pad it to
       a ``max_dimension`` square.
    2. Place it to the right of the padded furniture image.
    3. Let the pipeline fill the two-panel image.
    4. Crop each result back and paste it into the room through the grown,
       blurred mask.

    Returns:
        A list of images and the seed actually used. The list holds eight
        debug/intermediate images, followed by one final composite per
        generated image.

    Raises:
        ValueError: if the room image or its mask layer is missing, or if the
            mask is empty.
    """
    from PIL import ImageDraw

    # Slider values may arrive as floats. The image sizes, the generator seed
    # and the pipeline's batch/step arguments need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    num_images_per_prompt = int(num_images_per_prompt)
    # Ensure max_dimension is a multiple of 16 (for VAE)
    max_dimension = (int(max_dimension) // 16) * 16

    room_image = room_image_input["background"]
    if room_image is None:
        msg = "Room image is required"
        raise ValueError(msg)
    room_image = cast("Image.Image", room_image)

    layers = room_image_input["layers"]
    room_mask = layers[0] if layers else None
    if room_mask is None:
        msg = "Room mask is required"
        raise ValueError(msg)
    room_mask = cast("Image.Image", room_mask)

    raw_bbox = room_mask.getbbox(alpha_only=False)
    if raw_bbox is None:
        # getbbox() returns None when the layer has no non-zero pixels.
        msg = "Room mask is empty"
        raise ValueError(msg)
    x_min, y_min, x_max, y_max = adjust_bbox_to_divisible_16(
        *raw_bbox,
        width=room_mask.width,
        height=room_mask.height,
        padding=MASK_CONTEXT_PADDING,
    )
    bbox = (x_min, y_min, x_max, y_max)

    # Debug image: the room with the crop box outlined in red.
    bbox_debug = room_image.copy()
    ImageDraw.Draw(bbox_debug).rectangle(bbox, outline="red", width=3)

    room_image_cropped = room_image.crop(bbox)
    room_image_padded, room_image_padded_size = pad(
        room_image_cropped,
        (max_dimension, max_dimension),
    )

    # Grow the mask with a max filter...
    grow_pixels = 10
    sigma_grow = grow_pixels / 4
    kernel_size_grow = math.ceil(sigma_grow * 1.5 + 1)
    room_mask_grow = room_mask.filter(
        ImageFilter.MaxFilter(size=2 * kernel_size_grow + 1)
    )
    # ...then feather its edges with a Gaussian blur.
    blur_pixels = 33
    sigma_blur = blur_pixels / 4
    kernel_size_blur = math.ceil(sigma_blur * 1.5 + 1)
    room_mask_blurred = room_mask_grow.filter(
        ImageFilter.GaussianBlur(radius=kernel_size_blur)
    )
    room_mask_cropped = room_mask_blurred.crop(bbox)
    room_mask_padded, _ = pad(
        room_mask_cropped,
        (max_dimension, max_dimension),
    )

    furniture_image, _ = pad(
        furniture_image_input,
        (max_dimension, max_dimension),
    )
    furniture_mask = Image.new("RGB", (max_dimension, max_dimension), (255, 255, 255))

    # Two-panel input: furniture on the left, padded room crop on the right.
    image = Image.new("RGB", (max_dimension * 2, max_dimension), (255, 255, 255))
    image.paste(furniture_image, (0, 0))
    image.paste(room_image_padded, (max_dimension, 0))

    mask = Image.new("RGB", (max_dimension * 2, max_dimension), (255, 255, 255))
    mask.paste(furniture_mask, (0, 0))
    # NOTE(review): the room layer is pasted through its own alpha onto a
    # white canvas, and the whole mask is inverted below. If the painted
    # strokes are white (the editor brush is fixed to "#FFFFFF"), the painted
    # area ends up black after inversion, like the furniture panel. Confirm
    # this yields the intended inpainting region.
    mask.paste(room_mask_padded, (max_dimension, 0), room_mask_padded)
    mask = ImageOps.invert(mask)
    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))
    # Collapse to a single-channel ("L") mask for the pipeline.
    mask = mask.convert("L")

    if randomize_seed:
        seed = secrets.randbelow(MAX_SEED)

    prompt = (
        furniture_prompt + ".\n" + SYSTEM_PROMPT if furniture_prompt else SYSTEM_PROMPT
    )
    results_images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,
        height=max_dimension,
        width=max_dimension * 2,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=torch.Generator("cpu").manual_seed(seed),
    )["images"]

    # Intermediate images come first in the gallery, for debugging.
    final_images = [
        bbox_debug,
        room_image_padded,
        room_image_cropped,
        room_image,
        room_mask,
        furniture_image,
        image,
        mask,
    ]
    crop_size = (x_max - x_min, y_max - y_min)
    for generated in results_images:
        # unpad() runs on the whole two-panel output, presumably 2 * max_dimension
        # wide. pad()/unpad() default to right/bottom alignment, so the crop
        # taken from the right edge is the room panel's content.
        image_generated = unpad(generated, room_image_padded_size, crop_size)
        final_image = room_image.copy()
        # The crop came from the room image, so paste back at the crop origin,
        # blended through the grown and blurred mask crop.
        final_image.paste(image_generated, (x_min, y_min), room_mask_cropped)
        final_images.append(final_image)
    return final_images, seed
# Markdown header rendered at the top of the page.
intro_markdown = r"""
# FurnitureDemo
"""
# Page CSS. It centers the three app columns and the example showcase
# (matched by the elem_id values used in the Blocks layout) and caps their widths.
css = r"""
#col-left {
margin: 0 auto;
max-width: 430px;
}
#col-mid {
margin: 0 auto;
max-width: 430px;
}
#col-right {
margin: 0 auto;
max-width: 430px;
}
#col-showcase {
margin: 0 auto;
max-width: 1100px;
}
"""
def check_password(password: str) -> list:
    """Toggle the login widgets and the main content based on *password*.

    Returns:
        Three ``gr.update`` values for the (password box, submit button,
        content row) outputs. On a match the first two are hidden and the
        content is shown; otherwise the reverse.

    NOTE(review): this only changes visibility. The other event handlers are
    registered regardless, so they may still be reachable through the app's
    API without the password. Confirm whether that is acceptable.
    """
    # compare_digest runs in constant time, so timing does not reveal how many
    # leading characters matched. Comparing bytes avoids a TypeError on
    # non-ASCII input.
    is_valid = secrets.compare_digest(
        (password or "").encode("utf-8"), PASSWORD.encode("utf-8")
    )
    return [
        gr.update(visible=not is_valid),
        gr.update(visible=not is_valid),
        gr.update(visible=is_valid),
    ]
with gr.Blocks(css=css) as demo:
    gr.Markdown(intro_markdown)

    # Main three-column app, hidden until check_password accepts the password.
    with gr.Row(visible=False) as content:
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                """
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Step 1. Upload a furniture image ⬇️
</div>
</div>
""",
                max_height=50,
            )
            furniture_image_input = gr.Image(
                label="furniture",
                type="pil",
                sources=["upload"],
                image_mode="RGB",
                height=500,
            )
            # Deduplicate while keeping first-seen order. Iterating a set of
            # strings would shuffle the examples from one run to the next.
            furniture_examples = gr.Examples(
                examples=list(
                    dict.fromkeys(example[0] for example in EXAMPLES.values())
                ),
                label="Furniture examples",
                examples_per_page=6,
                inputs=[furniture_image_input],
            )

        with gr.Column(elem_id="col-mid"):
            gr.HTML(
                """
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Step 2. Upload a room image ⬇️
</div>
</div>
""",
                max_height=50,
            )
            room_image_input = gr.ImageEditor(
                label="room_image",
                type="pil",
                sources=["upload"],
                image_mode="RGBA",
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
                height=500,
            )
            room_examples = gr.Examples(
                examples=[example[1] for example in EXAMPLES.values()],
                label="Room examples",
                examples_per_page=6,
                inputs=[room_image_input],
            )

        with gr.Column(elem_id="col-right"):
            gr.HTML(
                """
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Step 3. Press Run to launch
</div>
</div>
""",
                max_height=50,
            )
            results = gr.Gallery(
                label="results",
                show_label=False,
                columns=[2],
                rows=[2],
                object_fit="contain",
                height=500,
                format="png",
                interactive=False,
            )
            run_button = gr.Button("Run")
            flag_button = gr.Button("Flag")

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                furniture_prompt = gr.Text(
                    label="Prompt",
                    max_lines=1,
                    placeholder="Enter a custom furniture description (optional)",
                    container=False,
                )
                with gr.Column():
                    max_dimension = gr.Slider(
                        label="Max Dimension",
                        minimum=512,
                        maximum=1024,
                        step=128,
                        value=720,
                    )
                    num_images_per_prompt = gr.Slider(
                        label="Number of images per prompt",
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=2,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=30,
                        step=0.5,
                        value=30,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )

    # NOTE(review): the showcase sits outside the password-gated `content`
    # row, so it is rendered before login. Confirm that is intended.
    with gr.Column(elem_id="col-showcase"):
        gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div> </div>
<br>
<div>
Examples in pairs of furniture and room images
</div>
</div>
""")
        # With cache_mode="eager", infer() outputs for these examples are
        # precomputed when the app starts.
        show_case = gr.Examples(
            examples=list(EXAMPLES.values()),
            inputs=[furniture_image_input, room_image_input],
            outputs=[results, seed],
            fn=infer,
            cache_examples=True,
            cache_mode="eager",
            label="Examples",
            examples_per_page=12,
        )

    # Password form. check_password toggles its visibility and that of `content`.
    with gr.Row():
        password = gr.Textbox(label="Password", type="password")
        submit = gr.Button("Submit")
    submit.click(
        fn=check_password,
        inputs=[password],
        outputs=[password, submit, content],
    )

    # This needs to be called at some point prior to the first call to callback.flag().
    # Log into FLAG_PATH, the same absolute directory the examples are read
    # from and that flag() uploads. The previous CWD-relative
    # "examples_dataset" diverged from it whenever the app was not started
    # from its own folder.
    callback.setup(
        [
            furniture_image_input,
            room_image_input,
            results,
        ],
        str(FLAG_PATH),
    )

    run_button.click(
        fn=infer,
        inputs=[
            furniture_image_input,
            room_image_input,
            furniture_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            max_dimension,
            num_images_per_prompt,
        ],
        outputs=[results, seed],
    )
    flag_button.click(
        fn=flag,
        inputs=[furniture_image_input, room_image_input, results],
        preprocess=False,
    )
# Gradio's built-in auth is left commented out. Access is instead gated by the
# in-page password form wired to check_password:
# demo.launch(auth=[(USER, PASSWORD)])
demo.launch()