File size: 3,332 Bytes
109512f
0324143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25ce5ea
0324143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107

import spaces
import gradio as gr
from gradio_imageslider import ImageSlider
import numpy as np
from huggingface_hub import hf_hub_download
import torch
from PIL import Image
from diffusers import DDPMScheduler
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline
import os

os.makedirs('./models', exist_ok=True)
# Download model files if not present. No force_download here: with
# force_download=True every checkpoint was re-fetched on each startup, which
# contradicts the intent above. hf_hub_download reuses an up-to-date copy
# already present under local_dir.
for filename in ["adapter.pt", "aggregator.pt", "previewer_lora_weights.bin"]:
    # filename carries the "models/" prefix of the hub repo layout, so with
    # local_dir="." each file lands in ./models/<filename>.
    hf_hub_download(repo_id="InstantX/InstantIR", filename=f"models/{filename}", local_dir=".")

# Initialize the pipeline and models
def initialize_pipeline(
    base_model='stabilityai/stable-diffusion-xl-base-1.0',
    models_dir='./models',
):
    """Build the InstantIR restoration pipeline and its previewer scheduler.

    Loads the SDXL base pipeline in fp16 and attaches the InstantIR adapter
    (image encoder ``facebook/dinov2-large``). It then prepares the
    previewers from ``models_dir``, swaps in a DDPM scheduler, and loads the
    aggregator weights. Finally it moves the whole pipeline to CUDA in fp16.

    Args:
        base_model: Hub id (or local path) of the SDXL base model. Used both
            for the pipeline weights and for the DDPM scheduler config.
        models_dir: Directory holding ``adapter.pt``, ``aggregator.pt`` and
            the previewer weights downloaded at startup.

    Returns:
        ``(pipe, lcm_scheduler)``: the ready pipeline, and the single-step
        LCM scheduler that inference passes as ``previewer_scheduler``.
    """
    pipe = InstantIRPipeline.from_pretrained(base_model, torch_dtype=torch.float16)

    # load adapter
    load_adapter_to_pipe(
        pipe,
        os.path.join(models_dir, 'adapter.pt'),
        image_encoder_path='facebook/dinov2-large',
    )

    # load previewer lora and schedulers
    pipe.prepare_previewers(models_dir)
    pipe.scheduler = DDPMScheduler.from_pretrained(base_model, subfolder="scheduler")
    # The previewer scheduler is built from the same config as the main DDPM one.
    lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

    # load aggregator weights.
    # map_location='cpu' avoids materialising the checkpoint on whatever
    # device it was saved from (and failing if that device is absent).
    # pipe.to() below moves the weights to the GPU anyway.
    # NOTE(review): torch.load unpickles the file. It comes from the
    # InstantX/InstantIR hub repo; consider weights_only=True once the
    # installed torch version is confirmed to support it.
    pretrained_state_dict = torch.load(
        os.path.join(models_dir, 'aggregator.pt'),
        map_location='cpu',
    )
    pipe.aggregator.load_state_dict(pretrained_state_dict)

    # send to GPU and fp16
    pipe.to(dtype=torch.float16)
    pipe.to('cuda')

    return pipe, lcm_scheduler


# Build the pipeline once at import time; every request reuses it.
pipe, lcm_scheduler = initialize_pipeline()

@spaces.GPU
def process_image(input_image):
    """Restore a single image with the InstantIR pipeline.

    Args:
        input_image: The uploaded image as a numpy array (the Gradio input
            component is declared with ``type="numpy"``), or ``None`` when
            nothing was provided.

    Returns:
        ``(input_image, restored)``: the original array and the restored
        image as a numpy array, in the before/after order the ImageSlider
        output expects.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please provide an image to restore.")

    # Convert to PIL Image. Normalise to RGB so e.g. RGBA or grayscale
    # arrays do not reach the pipeline with an unexpected channel count.
    pil_image = Image.fromarray(input_image).convert("RGB")

    # Process image.
    # With return_dict=False, diffusers-style pipelines return a tuple whose
    # first item is the *list* of generated images. The previous code
    # converted that whole list, which gives a 4-D (1, H, W, C) array instead
    # of an image. Take the single image instead. The isinstance guard keeps
    # this working if the pipeline hands back a bare image.
    generated = pipe(
        prompt='',
        image=pil_image,
        ip_adapter_image=[pil_image],
        negative_prompt='',
        guidance_scale=7.0,
        previewer_scheduler=lcm_scheduler,
        return_dict=False,
    )[0]
    restored_image = generated[0] if isinstance(generated, (list, tuple)) else generated

    # Convert result to numpy array (the output slider is type="numpy")
    result_array = np.array(restored_image)

    return (input_image, result_array)

# Header HTML rendered at the top of the demo page.
title = """<h1 align="center">InstantIR Image Restoration</h1>

<p><center>Restore and enhance your images</center></p>

<p><center>

<a href="https://huggingface.co/InstantX/InstantIR" target="_blank">[Model Page]</a>

</center></p>

"""

# UI: image upload and button on the left, before/after slider on the right.
with gr.Blocks() as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="numpy")
            process_btn = gr.Button(value="Restore Image", variant="primary")
        with gr.Column(scale=1):
            output_slider = ImageSlider(label="Before / After", type="numpy")

    # process_image returns (input, restored), which the slider shows as
    # before/after.
    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=output_slider
    )

    # Add examples. cache_examples=True makes Gradio precompute
    # process_image outputs for these files, so they must exist.
    # NOTE(review): paths are presumably resolved relative to the working
    # directory the app is started from; confirm.
    gr.Examples(
        examples=[
            "examples/image1.jpg",
            "examples/image2.jpg"
        ],
        inputs=input_image,
        outputs=output_slider,
        fn=process_image,
        cache_examples=True,
    )

# Start the (blocking) server only when run as a script, so importing this
# module does not launch it. Model loading above still happens at import.
if __name__ == "__main__":
    demo.launch(debug=True)