File size: 4,002 Bytes
0324143
 
 
 
 
 
 
 
 
 
 
 
 
4a9d4cd
4f7099b
4a9d4cd
4f7099b
 
 
 
 
 
 
2fde56b
 
 
 
 
 
 
 
 
4a9d4cd
 
 
0324143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import spaces
import gradio as gr
from gradio_imageslider import ImageSlider
import numpy as np
from huggingface_hub import hf_hub_download
import torch
from PIL import Image
from diffusers import DDPMScheduler
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline
import os

import os
import sys

def _prepend_to_sys_path(path):
    """Insert *path* at the front of ``sys.path`` unless it is already listed."""
    if path not in sys.path:
        sys.path.insert(0, path)


# Make the app directory, its parent and the bundled ``diffusers`` directory
# importable. Each insert goes to the front, so when none of them were already
# present the resulting search order is packages_dir, parent_dir, current_dir,
# followed by the previous sys.path entries.
# NOTE(review): this runs after the project imports at the top of the file
# (``schedulers``, ``module``, ``pipelines``, ``diffusers``) have already
# executed, so it cannot affect them — move it above those imports if they are
# meant to rely on it.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
packages_dir = os.path.join(current_dir, "diffusers")
_prepend_to_sys_path(current_dir)
_prepend_to_sys_path(parent_dir)
_prepend_to_sys_path(packages_dir)

# Hub repo holding the InstantIR weights (under ``models/`` in that repo) and
# the files the pipeline later loads from ./models.
MODEL_REPO_ID = "InstantX/InstantIR"
MODEL_FILES = ("adapter.pt", "aggregator.pt", "previewer_lora_weights.bin")

os.makedirs('./models', exist_ok=True)
# Download model files only if they are not already present; files already on
# disk are reused instead of being force-downloaded again on every start-up.
# With local_dir=".", the repo path "models/<name>" lands at ./models/<name>.
for filename in MODEL_FILES:
    if not os.path.isfile(os.path.join('./models', filename)):
        hf_hub_download(repo_id=MODEL_REPO_ID, filename=f"models/{filename}", local_dir=".")

# Initialize the pipeline and models
def initialize_pipeline(
    base_model='stabilityai/stable-diffusion-xl-base-1.0',
    models_dir='./models',
    device='cuda',
):
    """Build the InstantIR restoration pipeline and its previewer scheduler.

    Args:
        base_model: Hub id (or local path) of the SDXL base model; both the
            pipeline weights and the DDPM scheduler config are loaded from it.
        models_dir: Directory holding ``adapter.pt``, ``aggregator.pt`` and the
            previewer LoRA weights passed to ``prepare_previewers``.
        device: Device the finished pipeline is moved to.

    Returns:
        A ``(pipe, lcm_scheduler)`` tuple: the pipeline cast to float16 and
        moved to *device*, and an ``LCMSingleStepScheduler`` built from the
        pipeline's DDPM scheduler config.
    """
    pipe = InstantIRPipeline.from_pretrained(base_model, torch_dtype=torch.float16)

    # Load the image-prompt adapter, using DINOv2-large as the image encoder.
    load_adapter_to_pipe(
        pipe,
        os.path.join(models_dir, 'adapter.pt'),
        image_encoder_or_path='facebook/dinov2-large',
    )

    # Load the previewer LoRA, then swap in a DDPM scheduler and derive the
    # single-step LCM scheduler used for previews from its config.
    pipe.prepare_previewers(models_dir)
    pipe.scheduler = DDPMScheduler.from_pretrained(base_model, subfolder="scheduler")
    lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

    # Load aggregator weights onto the CPU so loading does not depend on the
    # device the checkpoint was saved from; load_state_dict copies them into
    # the module's existing parameters.
    # NOTE(review): torch.load unpickles the file — acceptable only because it
    # comes from the trusted InstantX/InstantIR repo.
    pretrained_state_dict = torch.load(
        os.path.join(models_dir, 'aggregator.pt'), map_location='cpu'
    )
    pipe.aggregator.load_state_dict(pretrained_state_dict)

    # Cast the pipeline to fp16, then move it to the target device.
    pipe.to(dtype=torch.float16)
    pipe.to(device)

    return pipe, lcm_scheduler

# Build the pipeline once at import time; process_image reads these
# module-level objects on every call instead of reloading the models.
pipe, lcm_scheduler = initialize_pipeline()

@spaces.GPU
def process_image(input_image):
    """Restore *input_image* with the InstantIR pipeline.

    Args:
        input_image: Image array from the ``gr.Image(type="numpy")`` input
            component, or ``None`` when nothing was uploaded.

    Returns:
        ``(input_image, restored)``: the untouched input and the restored image
        as a numpy array, i.e. the before/after pair for the ``ImageSlider``.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please provide an image to restore.")

    # Normalise grayscale or RGBA uploads to a 3-channel RGB image before
    # handing it to the diffusion pipeline (a no-op copy for RGB input).
    pil_image = Image.fromarray(input_image).convert("RGB")

    output = pipe(
        prompt='',
        image=pil_image,
        ip_adapter_image=[pil_image],
        negative_prompt='',
        guidance_scale=7.0,
        previewer_scheduler=lcm_scheduler,
        return_dict=False,
    )[0]
    # diffusers pipelines return ``(images,)`` when return_dict=False, so
    # ``[0]`` is the list of generated images (assumes InstantIRPipeline follows
    # that convention). Unwrap the single image so the slider gets an
    # (H, W, C) array instead of a batched (1, H, W, C) one.
    restored_image = output[0] if isinstance(output, (list, tuple)) else output

    return (input_image, np.array(restored_image))

# HTML header rendered at the top of the Gradio page (passed to gr.HTML),
# with a link to the model's Hugging Face page.
title = """<h1 align="center">InstantIR Image Restoration</h1>

<p><center>Restore and enhance your images</center></p>

<p><center>

<a href="https://huggingface.co/InstantX/InstantIR" target="_blank">[Model Page]</a>

</center></p>

"""

# Build the Gradio UI: an image input and a "Restore Image" button in the left
# column, and a before/after ImageSlider in the right column.
with gr.Blocks() as demo:
    gr.HTML(title)
    
    # Components are created inside these Row/Column contexts, which determine
    # where they are placed on the page.
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="numpy")
            process_btn = gr.Button(value="Restore Image", variant="primary")
        with gr.Column(scale=1):
            output_slider = ImageSlider(label="Before / After", type="numpy")

    # Clicking the button runs process_image on the uploaded image; its
    # (input, restored) tuple populates the before/after slider.
    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=output_slider
    )

    # Add examples
    # With cache_examples=True, process_image is run on each example path and
    # its outputs are stored for reuse when an example is selected.
    # NOTE(review): this presumably executes the GPU pipeline for every example
    # when the app starts — consider lazy caching if start-up time matters.
    gr.Examples(
        examples=[
            "examples/image1.jpg",
            "examples/image2.jpg"
        ],
        inputs=input_image,
        outputs=output_slider,
        fn=process_image,
        cache_examples=True,
    )

# Start the Gradio server for the app defined above.
demo.launch(debug=True)