Spaces: Running on Zero
File size: 3,332 Bytes
109512f 0324143 25ce5ea 0324143 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import spaces
import gradio as gr
from gradio_imageslider import ImageSlider
import numpy as np
from huggingface_hub import hf_hub_download
import torch
from PIL import Image
from diffusers import DDPMScheduler
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline
import os
os.makedirs('./models', exist_ok=True)
# Download the InstantIR checkpoints into ./models, but only when a file is not
# already present locally, so a restart does not re-fetch every checkpoint.
# (hf_hub_download with local_dir="." stores "models/<name>" at ./models/<name>.)
for filename in ["adapter.pt", "aggregator.pt", "previewer_lora_weights.bin"]:
    if os.path.exists(os.path.join('./models', filename)):
        continue
    hf_hub_download(repo_id="InstantX/InstantIR", filename=f"models/{filename}", local_dir=".")
# Initialize the pipeline and models
def initialize_pipeline(
    base_model='stabilityai/stable-diffusion-xl-base-1.0',
    model_dir='./models',
    device='cuda',
):
    """Build the InstantIR restoration pipeline and its previewer scheduler.

    Loads the SDXL base pipeline in fp16, attaches the InstantIR adapter (with
    a DINOv2-large image encoder), prepares the previewer LoRA, swaps in a DDPM
    scheduler, loads the aggregator weights and finally moves the whole
    pipeline to ``device`` in fp16.

    Args:
        base_model: Hub id (or local path) of the SDXL base model. Used both
            for the pipeline weights and for the DDPM scheduler config.
        model_dir: Directory containing ``adapter.pt``, ``aggregator.pt`` and
            the previewer LoRA weights.
        device: Device the pipeline is moved to.

    Returns:
        A ``(pipe, lcm_scheduler)`` tuple: the ready pipeline, and an
        ``LCMSingleStepScheduler`` built from the pipeline's scheduler config
        (used as the previewer scheduler at inference time).
    """
    pipe = InstantIRPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
    # load adapter
    load_adapter_to_pipe(
        pipe,
        os.path.join(model_dir, 'adapter.pt'),
        image_encoder_path='facebook/dinov2-large',
    )
    # load previewer lora and schedulers
    pipe.prepare_previewers(model_dir)
    pipe.scheduler = DDPMScheduler.from_pretrained(base_model, subfolder="scheduler")
    lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
    # load aggregator weights.
    # map_location='cpu' makes loading independent of the device the checkpoint
    # was saved on; load_state_dict copies the tensors into the aggregator's
    # existing parameters, and the pipe.to(...) calls below move them anyway.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    pretrained_state_dict = torch.load(
        os.path.join(model_dir, 'aggregator.pt'), map_location='cpu'
    )
    pipe.aggregator.load_state_dict(pretrained_state_dict)
    # send to GPU and fp16 (the adapter/aggregator weights loaded above may not
    # be fp16, so cast the whole pipeline again)
    pipe.to(dtype=torch.float16)
    pipe.to(device)
    return pipe, lcm_scheduler
# Build the pipeline once at import time; process_image reuses these
# module-level objects for every request.
(pipe, lcm_scheduler) = initialize_pipeline()
@spaces.GPU
def process_image(input_image):
    """Restore ``input_image`` with the InstantIR pipeline.

    Args:
        input_image: The uploaded image as a numpy array, as delivered by the
            ``gr.Image(type="numpy")`` input; ``None`` when nothing was given.

    Returns:
        ``(input_image, restored)``: the original array and the restored image
        as a single numpy array, i.e. the before/after pair for the slider.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please provide an image to restore.")
    # Convert to an RGB PIL image; .convert is a no-op copy for RGB input and
    # guards against alpha or single-channel arrays.
    pil_image = Image.fromarray(input_image).convert("RGB")
    # Process image
    output = pipe(
        prompt='',
        image=pil_image,
        ip_adapter_image=[pil_image],
        negative_prompt='',
        guidance_scale=7.0,
        previewer_scheduler=lcm_scheduler,
        return_dict=False,
    )[0]
    # With return_dict=False, diffusers-style pipelines return a tuple whose
    # first element is typically a *list* of images; unwrap it so the slider
    # gets one (H, W, C) image instead of a (1, H, W, C) stack.
    if isinstance(output, (list, tuple)):
        output = output[0]
    result_array = np.array(output)
    if result_array.ndim == 4:  # still batched (e.g. ndarray output) -> first image
        result_array = result_array[0]
    return (input_image, result_array)
# HTML header rendered at the top of the demo page (fixed: the model link had
# a stray trailing dot in the host name, "huggingface.co./").
title = """<h1 align="center">InstantIR Image Restoration</h1>
<p><center>Restore and enhance your images</center></p>
<p><center>
<a href="https://huggingface.co/InstantX/InstantIR" target="_blank">[Model Page]</a>
</center></p>
"""
# Assemble the UI: an input column (upload + button) beside a before/after
# slider, a click handler wired to process_image, and cached examples.
demo = gr.Blocks()
with demo:
    gr.HTML(title)
    with gr.Row():
        # Left: image upload and the trigger button.
        with gr.Column(scale=1):
            input_image = gr.Image(type="numpy", label="Input Image")
            process_btn = gr.Button(variant="primary", value="Restore Image")
        # Right: before/after comparison of the restoration result.
        with gr.Column(scale=1):
            output_slider = ImageSlider(type="numpy", label="Before / After")

    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=output_slider,
    )

    # Example inputs; results are precomputed because cache_examples=True.
    gr.Examples(
        examples=["examples/image1.jpg", "examples/image2.jpg"],
        fn=process_image,
        inputs=input_image,
        outputs=output_slider,
        cache_examples=True,
    )

demo.launch(debug=True)