# Hugging Face Space: Running on Zero (ZeroGPU)
import spaces | |
import gradio as gr | |
from gradio_imageslider import ImageSlider | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
import torch | |
from PIL import Image | |
from diffusers import DDPMScheduler | |
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler | |
from module.ip_adapter.utils import load_adapter_to_pipe | |
from pipelines.sdxl_instantir import InstantIRPipeline | |
import os | |
os.makedirs('./models', exist_ok=True)

# Download the InstantIR checkpoints into ./models if they are not present.
# hf_hub_download reuses files already in local_dir, so restarts do not
# re-fetch them; the previous force_download=True defeated that and
# re-downloaded every checkpoint on every startup.
for filename in ["adapter.pt", "aggregator.pt", "previewer_lora_weights.bin"]:
    hf_hub_download(
        repo_id="InstantX/InstantIR",
        filename=f"models/{filename}",
        local_dir=".",
    )
# Initialize the pipeline and models
def initialize_pipeline(
    base_model='stabilityai/stable-diffusion-xl-base-1.0',
    models_dir='./models',
    device='cuda',
):
    """Build the InstantIR restoration pipeline and its previewer scheduler.

    Loads the SDXL base pipeline in fp16 and attaches three components from
    ``models_dir``:

    - the InstantIR adapter, with a DINOv2-large image encoder;
    - the previewer LoRA;
    - the aggregator weights.

    It then casts the pipeline to fp16 and moves it to ``device``.

    Args:
        base_model: Hugging Face repo id of the SDXL base model. Also used
            for the DDPM scheduler config.
        models_dir: Directory holding ``adapter.pt``, ``aggregator.pt`` and
            the previewer LoRA weights.
        device: Device the pipeline is moved to.

    Returns:
        A ``(pipe, lcm_scheduler)`` tuple: the ready-to-run pipeline and the
        single-step LCM scheduler built from the pipeline's scheduler config.
    """
    pipe = InstantIRPipeline.from_pretrained(base_model, torch_dtype=torch.float16)

    # Attach the InstantIR adapter; its image encoder is DINOv2-large.
    load_adapter_to_pipe(
        pipe,
        os.path.join(models_dir, 'adapter.pt'),
        image_encoder_path='facebook/dinov2-large',
    )

    # Load the previewer LoRA, then set up the schedulers: DDPM for the main
    # sampler, and an LCM single-step scheduler derived from the same config.
    pipe.prepare_previewers(models_dir)
    pipe.scheduler = DDPMScheduler.from_pretrained(base_model, subfolder="scheduler")
    lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

    # Load the aggregator weights onto the CPU first. This works no matter
    # which device the checkpoint was saved from, and avoids needing CUDA at
    # load time. The pipeline is moved to `device` below.
    pretrained_state_dict = torch.load(
        os.path.join(models_dir, 'aggregator.pt'),
        map_location='cpu',
    )
    pipe.aggregator.load_state_dict(pretrained_state_dict)

    # Cast the pipeline to fp16 and move it to the target device.
    pipe.to(dtype=torch.float16)
    pipe.to(device)
    return pipe, lcm_scheduler


pipe, lcm_scheduler = initialize_pipeline()
# On a ZeroGPU Space a GPU is only attached while a @spaces.GPU-decorated
# function runs. Without this decorator, inference (and example caching)
# runs without a GPU.
@spaces.GPU
def process_image(input_image):
    """Restore a single image with InstantIR and return a before/after pair.

    Args:
        input_image: The array delivered by the ``gr.Image(type="numpy")``
            input, or ``None`` when nothing was uploaded.

    Returns:
        ``(input_image, restored)``, where ``restored`` is the restored image
        as a NumPy array, for the before/after ``ImageSlider``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please provide an image to restore.")

    # Defensive: always hand the pipeline a 3-channel PIL image, even if the
    # array arrives as grayscale or with an alpha channel.
    pil_image = Image.fromarray(input_image).convert("RGB")

    output = pipe(
        prompt='',
        image=pil_image,
        ip_adapter_image=[pil_image],
        negative_prompt='',
        guidance_scale=7.0,
        previewer_scheduler=lcm_scheduler,
        return_dict=False,
    )[0]

    # NOTE(review): assumes the diffusers convention, where
    # `return_dict=False` yields a tuple whose first element is a *list* of
    # images. Converting that list directly would give a (1, H, W, C) array,
    # so take the single image. The isinstance check keeps this correct if a
    # bare image is returned instead.
    restored_image = output[0] if isinstance(output, (list, tuple)) else output

    return (input_image, np.array(restored_image))
# HTML header rendered at the top of the demo. The model-page link
# previously pointed at "huggingface.co./…" (stray trailing dot in the host).
title = """<h1 align="center">InstantIR Image Restoration</h1>
<p><center>Restore and enhance your images</center></p>
<p><center>
<a href="https://huggingface.co/InstantX/InstantIR" target="_blank" rel="noopener noreferrer">[Model Page]</a>
</center></p>
"""
# Example inputs shown under the UI; Gradio is asked to cache their outputs.
_EXAMPLE_IMAGES = ["examples/image1.jpg", "examples/image2.jpg"]

# Layout: the upload widget and the restore button on the left, and a
# before/after slider showing the result on the right.
with gr.Blocks() as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="numpy")
            process_btn = gr.Button(value="Restore Image", variant="primary")
        with gr.Column(scale=1):
            output_slider = ImageSlider(label="Before / After", type="numpy")

    # process_image returns a (before, after) pair, which the slider renders.
    process_btn.click(fn=process_image, inputs=[input_image], outputs=output_slider)

    gr.Examples(
        examples=_EXAMPLE_IMAGES,
        inputs=input_image,
        outputs=output_slider,
        fn=process_image,
        cache_examples=True,
    )

demo.launch(debug=True)