JOY-Huang commited on
Commit
028ae97
·
1 Parent(s): fe2ea78

Update space

Browse files
Files changed (1) hide show
  1. app.py +274 -111
app.py CHANGED
@@ -1,43 +1,209 @@
1
- import gradio as gr
2
- import numpy as np
3
- import random
4
- #import spaces #[uncomment to use ZeroGPU]
5
- from diffusers import DiffusionPipeline
6
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_repo_id = "stabilityai/sdxl-turbo" #Replace to the model you would like to use
 
 
10
 
11
  if torch.cuda.is_available():
12
  torch_dtype = torch.float16
13
  else:
14
  torch_dtype = torch.float32
15
 
16
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
17
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  MAX_IMAGE_SIZE = 1024
21
 
22
- #@spaces.GPU #[uncomment to use ZeroGPU]
23
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
24
 
25
- if randomize_seed:
26
- seed = random.randint(0, MAX_SEED)
27
-
28
- generator = torch.Generator().manual_seed(seed)
29
-
30
- image = pipe(
31
- prompt = prompt,
32
- negative_prompt = negative_prompt,
33
- guidance_scale = guidance_scale,
34
- num_inference_steps = num_inference_steps,
35
- width = width,
36
- height = height,
37
- generator = generator
38
- ).images[0]
39
-
40
- return image, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  examples = [
43
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
@@ -53,90 +219,87 @@ css="""
53
  """
54
 
55
  with gr.Blocks(css=css) as demo:
56
-
57
- with gr.Column(elem_id="col-container"):
58
- gr.Markdown(f"""
59
- # Text-to-Image Gradio Template
60
- """)
61
-
62
- with gr.Row():
63
-
64
- prompt = gr.Text(
65
- label="Prompt",
66
- show_label=False,
67
- max_lines=1,
68
- placeholder="Enter your prompt",
69
- container=False,
70
- )
71
-
72
- run_button = gr.Button("Run", scale=0)
73
-
74
- result = gr.Image(label="Result", show_label=False)
75
-
76
- with gr.Accordion("Advanced Settings", open=False):
77
-
78
- negative_prompt = gr.Text(
79
- label="Negative prompt",
80
- max_lines=1,
81
- placeholder="Enter a negative prompt",
82
- visible=False,
83
- )
84
-
85
- seed = gr.Slider(
86
- label="Seed",
87
- minimum=0,
88
- maximum=MAX_SEED,
89
- step=1,
90
- value=0,
91
- )
92
-
93
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
94
-
95
- with gr.Row():
96
-
97
- width = gr.Slider(
98
- label="Width",
99
- minimum=256,
100
- maximum=MAX_IMAGE_SIZE,
101
- step=32,
102
- value=1024, #Replace with defaults that work for your model
103
- )
104
-
105
- height = gr.Slider(
106
- label="Height",
107
- minimum=256,
108
- maximum=MAX_IMAGE_SIZE,
109
- step=32,
110
- value=1024, #Replace with defaults that work for your model
111
- )
112
-
113
  with gr.Row():
114
-
115
- guidance_scale = gr.Slider(
116
- label="Guidance scale",
117
- minimum=0.0,
118
- maximum=10.0,
119
- step=0.1,
120
- value=0.0, #Replace with defaults that work for your model
121
- )
122
-
123
- num_inference_steps = gr.Slider(
124
- label="Number of inference steps",
125
- minimum=1,
126
- maximum=50,
127
- step=1,
128
- value=2, #Replace with defaults that work for your model
129
- )
130
-
131
- gr.Examples(
132
- examples = examples,
133
- inputs = [prompt]
134
- )
135
- gr.on(
136
- triggers=[run_button.click, prompt.submit],
137
- fn = infer,
138
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
139
- outputs = [result, seed]
140
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- demo.queue().launch()
 
 
 
 
 
 
1
  import torch
2
+ import random
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+
8
+ from diffusers import (
9
+ DDPMScheduler,
10
+ StableDiffusionXLPipeline
11
+ )
12
+ from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
13
+ from diffusers.utils import convert_unet_state_dict_to_peft
14
+ from peft import LoraConfig, set_peft_model_state_dict
15
+ from transformers import (
16
+ AutoImageProcessor, AutoModel
17
+ )
18
+
19
+ from module.ip_adapter.utils import init_ip_adapter_in_unet
20
+ from module.ip_adapter.resampler import Resampler
21
+ from module.aggregator import Aggregator
22
+ from pipelines.sdxl_instantir import InstantIRPipeline, LCM_LORA_MODULES, PREVIEWER_LORA_MODULES
23
+
24
+
25
+ transform = transforms.Compose([
26
+ transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
27
+ transforms.CenterCrop(1024),
28
+ ])
29
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ sdxl_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
32
+ instantir_repo_id = "instantx/instantir"
33
+ dinov2_repo_id = "facebook/dinov2-large"
34
 
35
  if torch.cuda.is_available():
36
  torch_dtype = torch.float16
37
  else:
38
  torch_dtype = torch.float32
39
 
40
+ print("Loading vision encoder...")
41
+ image_encoder = AutoModel.from_pretrained(dinov2_repo_id, torch_dtype=torch_dtype)
42
+ image_processor = AutoImageProcessor.from_pretrained(dinov2_repo_id)
43
+
44
+ print("Loading SDXL...")
45
+ pipe = StableDiffusionXLPipeline.from_pretrained(
46
+ sdxl_repo_id,
47
+ torch_dtype=torch.float16,
48
+ )
49
+ unet = pipe.unet
50
+
51
+ print("Initializing Aggregator...")
52
+ aggregator = Aggregator.from_unet(unet, load_weights_from_unet=False)
53
+
54
+ print("Loading LQ-Adapter...")
55
+ image_proj_model = Resampler(
56
+ dim=1280,
57
+ depth=4,
58
+ dim_head=64,
59
+ heads=20,
60
+ num_queries=64,
61
+ embedding_dim=image_encoder.config.hidden_size,
62
+ output_dim=unet.config.cross_attention_dim,
63
+ ff_mult=4
64
+ )
65
+ init_ip_adapter_in_unet(
66
+ unet,
67
+ image_proj_model,
68
+ "InstantX/InstantIR/adapter.pt",
69
+ adapter_tokens=64,
70
+ )
71
+ print("Initializing InstantIR...")
72
+ pipe = InstantIRPipeline(
73
+ pipe.vae, pipe.text_encoder, pipe.text_encoder_2, pipe.tokenizer, pipe.tokenizer_2,
74
+ unet, aggregator, pipe.scheduler, feature_extractor=image_processor, image_encoder=image_encoder,
75
+ )
76
+
77
+ # Add Previewer LoRA.
78
+ lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(
79
+ "InstantX/InstantIR/previewer_lora_weights.bin",
80
+ # weight_name="previewer_lora_weights.bin",
81
+
82
+ )
83
+ unet_state_dict = {
84
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
85
+ }
86
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
87
+ lora_state_dict = dict()
88
+ for k, v in unet_state_dict.items():
89
+ if "ip" in k:
90
+ k = k.replace("attn2", "attn2.processor")
91
+ lora_state_dict[k] = v
92
+ else:
93
+ lora_state_dict[k] = v
94
+ if alpha_dict:
95
+ lora_alpha = next(iter(alpha_dict.values()))
96
+ else:
97
+ lora_alpha = 1
98
+ print(f"use lora alpha {lora_alpha}")
99
+ lora_config = LoraConfig(
100
+ r=64,
101
+ target_modules=PREVIEWER_LORA_MODULES,
102
+ lora_alpha=lora_alpha,
103
+ lora_dropout=0.0,
104
+ )
105
+
106
+ # Add LCM LoRA.
107
+ lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(
108
+ "latent-consistency/lcm-lora-sdxl"
109
+ )
110
+ unet_state_dict = {
111
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
112
+ }
113
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
114
+ if alpha_dict:
115
+ lora_alpha = next(iter(alpha_dict.values()))
116
+ else:
117
+ lora_alpha = 1
118
+ print(f"use lora alpha {lora_alpha}")
119
+ lora_config = LoraConfig(
120
+ r=64,
121
+ target_modules=LCM_LORA_MODULES,
122
+ lora_alpha=lora_alpha,
123
+ lora_dropout=0.0,
124
+ )
125
+
126
+ unet.add_adapter(lora_config, "lcm")
127
+ incompatible_keys = set_peft_model_state_dict(unet, unet_state_dict, adapter_name="lcm")
128
+ if incompatible_keys is not None:
129
+ # check only for unexpected keys
130
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
131
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
132
+ if unexpected_keys:
133
+ raise ValueError(
134
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
135
+ f" {unexpected_keys}. "
136
+ )
137
+
138
+ unet.disable_adapters()
139
+ pipe.scheduler = DDPMScheduler.from_pretrained(
140
+ sdxl_repo_id,
141
+ subfolder="scheduler"
142
+ )
143
+ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
144
+ # Load weights.
145
+ print("Loading checkpoint...")
146
+ aggregator_state_dict = torch.load(
147
+ "InstantX/InstantIR/aggregator.pt",
148
+ map_location="cpu"
149
+ )
150
+ aggregator.load_state_dict(aggregator_state_dict, strict=True)
151
+ aggregator.to(dtype=torch.float16)
152
+ unet.to(dtype=torch.float16)
153
+ pipe=pipe.to(device)
154
 
155
  MAX_SEED = np.iinfo(np.int32).max
156
  MAX_IMAGE_SIZE = 1024
157
 
158
+ def unpack_pipe_out(preview_row, index):
159
+ return preview_row[index][0]
160
 
161
+ def dynamic_preview_slider(sampling_steps):
162
+ print(sampling_steps)
163
+ return gr.Slider(label="Restoration Previews", value=sampling_steps-1, minimum=0, maximum=sampling_steps-1, step=1)
164
+
165
+ def dynamic_guidance_slider(sampling_steps):
166
+ return gr.Slider(label="Start Free Rendering", value=sampling_steps, minimum=0, maximum=sampling_steps, step=1)
167
+
168
+ def show_final_preview(preview_row):
169
+ return preview_row[-1][0]
170
+
171
+ # @spaces.GPU #[uncomment to use ZeroGPU]
172
+ def instantir_restore(lq, prompt="", steps=30, cfg_scale=7.0, guidance_end=1.0, creative_restoration=False, seed=3407):
173
+ if creative_restoration:
174
+ if "lcm" not in pipe.unet.active_adapters():
175
+ pipe.unet.set_adapter('lcm')
176
+ else:
177
+ if "previewer" not in pipe.unet.active_adapters():
178
+ pipe.unet.set_adapter('previewer')
179
+
180
+ if isinstance(guidance_end, int):
181
+ guidance_end = guidance_end / steps
182
+ with torch.no_grad(): lq = [transform(lq)]
183
+ generator = torch.Generator(device=device).manual_seed(seed)
184
+
185
+ out = pipe(
186
+ prompt=[prompt]*len(lq),
187
+ image=lq,
188
+ ip_adapter_image=[lq],
189
+ num_inference_steps=steps,
190
+ generator=generator,
191
+ controlnet_conditioning_scale=1.0,
192
+ # negative_original_size=(256,256),
193
+ # negative_target_size=(1024,1024),
194
+ negative_prompt=[""]*len(lq),
195
+ guidance_scale=cfg_scale,
196
+ control_guidance_end=guidance_end,
197
+ # control_guidance_start=0.5,
198
+ previewer_scheduler=lcm_scheduler,
199
+ return_dict=False,
200
+ save_preview_row=True,
201
+ # reference_latent = reference_latents,
202
+ # output_type='pt'
203
+ )
204
+ for i, preview_img in enumerate(out[1]):
205
+ preview_img.append(f"preview_{i}")
206
+ return out[0][0], out[1]
207
 
208
  examples = [
209
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
 
219
  """
220
 
221
  with gr.Blocks(css=css) as demo:
222
+ gr.Markdown(
223
+ """
224
+ # InstantIR: Blind Image Restoration with Instant Generative Reference.
225
+
226
+ ### **Official 🤗 Gradio demo of [InstantIR](https://arxiv.org/abs/2410.06551).**
227
+ ### **InstantIR can not only help you restore your broken image, but also capable of imaginative re-creation following your text prompts. See advance usage for more details!**
228
+ ## Basic usage: revitalize your image
229
+ 1. Upload an image you want to restore;
230
+ 2. Optionally, tune the `Steps` `CFG Scale` parameters. Typically higher steps lead to better results, but less than 50 is recommended for efficiency;
231
+ 3. Click `InstantIR magic!`.
232
+ """)
233
+ with gr.Row():
234
+ lq_img = gr.Image(label="Low-quality image", type="pil")
235
+ with gr.Column(elem_id="col-container"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  with gr.Row():
237
+ steps = gr.Number(label="Steps", value=20, step=1)
238
+ cfg_scale = gr.Number(label="CFG Scale", value=7.0, step=0.1)
239
+ seed = gr.Number(label="Seed", value=42, step=1)
240
+ # guidance_start = gr.Slider(label="Guidance Start", value=1.0, minimum=0.0, maximum=1.0, step=0.05)
241
+ guidance_end = gr.Slider(label="Start Free Rendering", value=20, minimum=0, maximum=20, step=1)
242
+ prompt = gr.Textbox(
243
+ label="Restoration prompts (Optional)", show_label=False,
244
+ placeholder="Restoration prompts (Optional)", value='',
245
+ # container=False,
246
+ )
247
+ mode = gr.Checkbox(label="Creative Restoration", value=False)
248
+ # with gr.Accordion("Advanced Settings", open=False):
249
+ with gr.Row():
250
+ with gr.Row():
251
+ restore_btn = gr.Button("InstantIR magic!")
252
+ clear_btn = gr.ClearButton()
253
+ index = gr.Slider(label="Restoration Previews", value=19, minimum=0, maximum=19, step=1)
254
+ with gr.Row():
255
+ output = gr.Image(label="InstantIR restored", type="pil")
256
+ preview = gr.Image(label="Preview", type="pil")
257
+ # gr.Examples(
258
+ # examples = examples,
259
+ # inputs = [prompt]
260
+ # )
261
+ # gr.on(
262
+ # triggers=[restore_btn.click, prompt.submit],
263
+ # fn = infer,
264
+ # inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
265
+ # outputs = [result, seed]
266
+ # )
267
+ pipe_out = gr.Gallery(visible=False)
268
+ clear_btn.add([lq_img, output, preview])
269
+ restore_btn.click(instantir_restore, inputs=[lq_img, prompt, steps, cfg_scale, guidance_end, mode, seed], outputs=[output, pipe_out], api_name="InstantIR")
270
+ steps.change(dynamic_guidance_slider, inputs=steps, outputs=guidance_end)
271
+ output.change(dynamic_preview_slider, inputs=steps, outputs=index)
272
+ index.release(unpack_pipe_out, inputs=[pipe_out, index], outputs=preview)
273
+ output.change(show_final_preview, inputs=pipe_out, outputs=preview)
274
+ gr.Markdown(
275
+ """
276
+ ## Advance usage:
277
+ ### Browse restoration variants:
278
+ 1. After InstantIR processing, drag the `Restoration Previews` slider to explore other in-progress versions;
279
+ 2. If you like one of them, set the `Start Free Rendering` slider to the same value to get a more refined result.
280
+ ### Creative restoration:
281
+ 1. Check the `Creative Restoration` checkbox;
282
+ 2. Input your text prompts in the `Restoration prompts` textbox;
283
+ 3. Set `Start Free Rendering` slider to a medium value (around half of the `steps`) to provide adequate room for InstantIR creation.
284
+
285
+ ## Examples
286
+ Here are some examplar usage of InstantIR:
287
+ """)
288
+ # examples = gr.Gallery(label="Examples")
289
+
290
+ gr.Markdown(
291
+ """
292
+ ## Citation
293
+ If InstantIR is helpful to your work, please cite our paper via:
294
+
295
+ ```
296
+ @article{huang2024instantir,
297
+ title={InstantIR: Blind Image Restoration with Instant Generative Reference},
298
+ author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse},
299
+ journal={arXiv preprint arXiv:2410.06551},
300
+ year={2024}
301
+ }
302
+ ```
303
+ """)
304
 
305
+ demo.queue().launch(debug=True)