# NOTE: `spaces` is intentionally imported first; on ZeroGPU Spaces it must be
# imported before CUDA-related packages initialise the GPU.
import spaces
import gradio as gr
from gradio_imageslider import ImageSlider
import numpy as np
from huggingface_hub import hf_hub_download
import torch
from PIL import Image
from diffusers import DDPMScheduler
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline
import os

os.makedirs('./models', exist_ok=True)

# Download model files only if they are not already present. force_download is
# deliberately not used, so restarts reuse the files already on disk.
for filename in ["adapter.pt", "aggregator.pt", "previewer_lora_weights.bin"]:
    if os.path.exists(os.path.join("models", filename)):
        continue
    # The files live under models/ in the Hub repo; with local_dir="." they are
    # written to ./models/<filename>, which is where the loaders below read them.
    hf_hub_download(
        repo_id="InstantX/InstantIR",
        filename=f"models/{filename}",
        local_dir=".",
    )


def initialize_pipeline():
    """Build the InstantIR restoration pipeline and its previewer scheduler.

    Loads the SDXL base pipeline in fp16 and attaches the IP-adapter, which uses
    a DINOv2-large image encoder. It then loads the previewer LoRA and the
    aggregator weights from ./models and moves everything to CUDA.

    Returns:
        tuple: ``(pipe, lcm_scheduler)``. ``pipe`` is the ready pipeline, with a
        DDPM scheduler. ``lcm_scheduler`` is an ``LCMSingleStepScheduler`` built
        from that DDPM config and passed to the pipeline as
        ``previewer_scheduler``.
    """
    pipe = InstantIRPipeline.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16
    )

    # Load the IP-adapter weights together with the DINOv2 image encoder.
    load_adapter_to_pipe(
        pipe,
        './models/adapter.pt',
        image_encoder_path='facebook/dinov2-large',
    )

    # Load the previewer LoRA and set up the schedulers.
    pipe.prepare_previewers('./models')
    pipe.scheduler = DDPMScheduler.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler"
    )
    lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

    # Load the aggregator weights. map_location="cpu" makes loading independent
    # of the device the checkpoint was saved on; the pipeline is moved to CUDA
    # below anyway.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    # This one comes from the InstantX/InstantIR Hub repo.
    pretrained_state_dict = torch.load('./models/aggregator.pt', map_location='cpu')
    pipe.aggregator.load_state_dict(pretrained_state_dict)

    # Make sure every component is fp16, then move the pipeline to the GPU.
    pipe.to(dtype=torch.float16)
    pipe.to('cuda')
    return pipe, lcm_scheduler


pipe, lcm_scheduler = initialize_pipeline()


@spaces.GPU
def process_image(input_image):
    """Restore a single image with the InstantIR pipeline.

    Args:
        input_image: Image array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing was provided.

    Returns:
        tuple: ``(input_image, result_array)``, the before/after pair shown by
        the ``ImageSlider``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please provide an image to restore.")

    # Feed the pipeline a 3-channel image. gr.Image normally delivers RGB
    # already; this guards against grayscale or RGBA arrays.
    pil_image = Image.fromarray(input_image).convert("RGB")

    output = pipe(
        prompt='',
        image=pil_image,
        ip_adapter_image=[pil_image],
        negative_prompt='',
        guidance_scale=7.0,
        previewer_scheduler=lcm_scheduler,
        return_dict=False,
    )[0]

    # With return_dict=False, diffusers pipelines conventionally return
    # (images, ...), where `images` is a list. Unwrap it so the slider gets one
    # HxWxC image rather than a 1xHxWxC batch.
    # NOTE(review): this assumes InstantIRPipeline follows that convention. A
    # bare image is also accepted, and a batched array of one is squeezed.
    restored_image = output[0] if isinstance(output, (list, tuple)) else output
    result_array = np.array(restored_image)
    if result_array.ndim == 4:
        result_array = result_array[0]

    return (input_image, result_array)
# Header shown above the demo. gr.HTML renders markup, so the heading and the
# model-page link are written as HTML. As plain text, the lines would collapse
# into one unformatted run and "Model Page" would not be a link.
# NOTE(review): the link target is assumed to be the InstantX/InstantIR Hub repo
# that the model files are downloaded from — confirm.
title = """
<h1 style="text-align: center;">InstantIR Image Restoration</h1>
<p style="text-align: center;">Restore and enhance your images</p>
<p style="text-align: center;"><a href="https://huggingface.co/InstantX/InstantIR" target="_blank">[Model Page]</a></p>
"""

with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="numpy")
            process_btn = gr.Button(value="Restore Image", variant="primary")
        with gr.Column(scale=1):
            output_slider = ImageSlider(label="Before / After", type="numpy")

    # Run restoration on click; process_image returns the (before, after) pair
    # expected by the slider.
    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=output_slider,
    )

    # Add examples. cache_examples=True precomputes their outputs by running
    # process_image when the app is set up.
    gr.Examples(
        examples=[
            "examples/image1.jpg",
            "examples/image2.jpg",
        ],
        inputs=input_image,
        outputs=output_slider,
        fn=process_image,
        cache_examples=True,
    )

# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch(debug=True)