silentchen commited on
Commit
3ab28ab
·
1 Parent(s): eea38fb

first commit

Browse files
DejaVuSansMono.ttf ADDED
Binary file (341 kB). View file
 
app.py ADDED
@@ -0,0 +1,748 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from omegaconf import OmegaConf
4
+ # from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt
5
+
6
+ import json
7
+ import numpy as np
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ from functools import partial
10
+ from collections import Counter
11
+ import math
12
+ import gc
13
+
14
+ from gradio import processing_utils
15
+ from typing import Optional
16
+
17
+ import warnings
18
+
19
+ from datetime import datetime
20
+
21
+ from huggingface_hub import hf_hub_download
22
+
23
+ hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
24
+
25
+ import sys
26
+
27
+ sys.tracebacklimit = 0
28
+
29
+
30
+ def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
31
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
32
+ return torch.load(cache_file, map_location='cpu')
33
+
34
+
35
+ def load_ckpt_config_from_hf(modality):
36
+ ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='model')
37
+ config = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='config')
38
+ return ckpt, config
39
+
40
+
41
+ def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
42
+ pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
43
+ config = OmegaConf.create(config["_content"]) # config used in training
44
+ config.alpha_scale = 1.0
45
+ config.model['params']['is_inpaint'] = is_inpaint
46
+ config.model['params']['is_style'] = is_style
47
+
48
+ if common_instances is None:
49
+ common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model')
50
+ common_instances = load_common_ckpt(config, common_ckpt)
51
+
52
+ loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen, common_instances)
53
+
54
+ return loaded_model_list, common_instances
55
+
56
+
57
+ class Instance:
58
+ def __init__(self, capacity=2):
59
+ self.model_type = 'base'
60
+ self.loaded_model_list = {}
61
+ self.counter = Counter()
62
+ self.global_counter = Counter()
63
+ self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
64
+ 'gligen-generation-text-box',
65
+ is_inpaint=False, is_style=False, common_instances=None
66
+ )
67
+ self.capacity = capacity
68
+
69
+ def _log(self, model_type, batch_size, instruction, phrase_list):
70
+ self.counter[model_type] += 1
71
+ self.global_counter[model_type] += 1
72
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
73
+ print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
74
+ current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
75
+ ))
76
+
77
+ def get_model(self, model_type, batch_size, instruction, phrase_list):
78
+ if model_type in self.loaded_model_list:
79
+ self._log(model_type, batch_size, instruction, phrase_list)
80
+ return self.loaded_model_list[model_type]
81
+
82
+ if self.capacity == len(self.loaded_model_list):
83
+ least_used_type = self.counter.most_common()[-1][0]
84
+ del self.loaded_model_list[least_used_type]
85
+ del self.counter[least_used_type]
86
+ gc.collect()
87
+ torch.cuda.empty_cache()
88
+
89
+ self.loaded_model_list[model_type] = self._get_model(model_type)
90
+ self._log(model_type, batch_size, instruction, phrase_list)
91
+ return self.loaded_model_list[model_type]
92
+
93
+ def _get_model(self, model_type):
94
+ if model_type == 'base':
95
+ return ckpt_load_helper(
96
+ 'gligen-generation-text-box',
97
+ is_inpaint=False, is_style=False, common_instances=self.common_instances
98
+ )[0]
99
+ elif model_type == 'inpaint':
100
+ return ckpt_load_helper(
101
+ 'gligen-inpainting-text-box',
102
+ is_inpaint=True, is_style=False, common_instances=self.common_instances
103
+ )[0]
104
+ elif model_type == 'style':
105
+ return ckpt_load_helper(
106
+ 'gligen-generation-text-image-box',
107
+ is_inpaint=False, is_style=True, common_instances=self.common_instances
108
+ )[0]
109
+
110
+ assert False
111
+
112
+
113
+ # instance = Instance()
114
+
115
+
116
+ def load_clip_model():
117
+ from transformers import CLIPProcessor, CLIPModel
118
+ version = "openai/clip-vit-large-patch14"
119
+ model = CLIPModel.from_pretrained(version).cuda()
120
+ processor = CLIPProcessor.from_pretrained(version)
121
+
122
+ return {
123
+ 'version': version,
124
+ 'model': model,
125
+ 'processor': processor,
126
+ }
127
+
128
+
129
+ # clip_model = load_clip_model()
130
+
131
+
132
+ class ImageMask(gr.components.Image):
133
+ """
134
+ Sets: source="canvas", tool="sketch"
135
+ """
136
+
137
+ is_template = True
138
+
139
+ def __init__(self, **kwargs):
140
+ super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
141
+
142
+ def preprocess(self, x):
143
+ if x is None:
144
+ return x
145
+ if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
146
+ decode_image = processing_utils.decode_base64_to_image(x)
147
+ width, height = decode_image.size
148
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
149
+ mask[..., -1] = 255
150
+ mask = self.postprocess(mask)
151
+ x = {'image': x, 'mask': mask}
152
+ return super().preprocess(x)
153
+
154
+
155
+ class Blocks(gr.Blocks):
156
+
157
+ def __init__(
158
+ self,
159
+ theme: str = "default",
160
+ analytics_enabled: Optional[bool] = None,
161
+ mode: str = "blocks",
162
+ title: str = "Gradio",
163
+ css: Optional[str] = None,
164
+ **kwargs,
165
+ ):
166
+ self.extra_configs = {
167
+ 'thumbnail': kwargs.pop('thumbnail', ''),
168
+ 'url': kwargs.pop('url', 'https://gradio.app/'),
169
+ 'creator': kwargs.pop('creator', '@teamGradio'),
170
+ }
171
+
172
+ super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
173
+ warnings.filterwarnings("ignore")
174
+
175
+ def get_config_file(self):
176
+ config = super(Blocks, self).get_config_file()
177
+
178
+ for k, v in self.extra_configs.items():
179
+ config[k] = v
180
+
181
+ return config
182
+
183
+
184
+ '''
185
+ inference model
186
+ '''
187
+
188
+
189
+ @torch.no_grad()
190
+ def inference(task, language_instruction, grounding_instruction, inpainting_boxes_nodrop, image,
191
+ alpha_sample, guidance_scale, batch_size,
192
+ fix_seed, rand_seed, actual_mask, style_image,
193
+ *args, **kwargs):
194
+ grounding_instruction = json.loads(grounding_instruction)
195
+ phrase_list, location_list = [], []
196
+ for k, v in grounding_instruction.items():
197
+ phrase_list.append(k)
198
+ location_list.append(v)
199
+
200
+ placeholder_image = Image.open('images/teddy.jpg').convert("RGB")
201
+ image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled
202
+
203
+ batch_size = int(batch_size)
204
+ if not 1 <= batch_size <= 4:
205
+ batch_size = 2
206
+
207
+ if style_image == None:
208
+ has_text_mask = 1
209
+ has_image_mask = 0 # then we hack above 'image_list'
210
+ else:
211
+ valid_phrase_len = len(phrase_list)
212
+
213
+ phrase_list += ['placeholder']
214
+ has_text_mask = [1] * valid_phrase_len + [0]
215
+
216
+ image_list = [placeholder_image] * valid_phrase_len + [style_image]
217
+ has_image_mask = [0] * valid_phrase_len + [1]
218
+
219
+ location_list += [[0.0, 0.0, 1, 0.01]] # style image grounding location
220
+
221
+ if task == 'Grounded Inpainting':
222
+ alpha_sample = 1.0
223
+
224
+ instruction = dict(
225
+ prompt=language_instruction,
226
+ phrases=phrase_list,
227
+ images=image_list,
228
+ locations=location_list,
229
+ alpha_type=[alpha_sample, 0, 1.0 - alpha_sample],
230
+ has_text_mask=has_text_mask,
231
+ has_image_mask=has_image_mask,
232
+ save_folder_name=language_instruction,
233
+ guidance_scale=guidance_scale,
234
+ batch_size=batch_size,
235
+ fix_seed=bool(fix_seed),
236
+ rand_seed=int(rand_seed),
237
+ actual_mask=actual_mask,
238
+ inpainting_boxes_nodrop=inpainting_boxes_nodrop,
239
+ )
240
+
241
+ get_model = partial(instance.get_model,
242
+ batch_size=batch_size,
243
+ instruction=language_instruction,
244
+ phrase_list=phrase_list)
245
+
246
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
247
+ if task == 'Grounded Generation':
248
+ if style_image == None:
249
+ return grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
250
+ else:
251
+ return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
252
+ elif task == 'Grounded Inpainting':
253
+ assert image is not None
254
+ instruction['input_image'] = image.convert("RGB")
255
+ return grounded_generation_box(get_model('inpaint'), instruction, *args, **kwargs)
256
+
257
+
258
+ def draw_box(boxes=[], texts=[], img=None):
259
+ if len(boxes) == 0 and img is None:
260
+ return None
261
+
262
+ if img is None:
263
+ img = Image.new('RGB', (512, 512), (255, 255, 255))
264
+ colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
265
+ draw = ImageDraw.Draw(img)
266
+ font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
267
+ print(boxes)
268
+ for bid, box in enumerate(boxes):
269
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
270
+ anno_text = texts[bid]
271
+ draw.rectangle(
272
+ [box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]],
273
+ outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
274
+ draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)], anno_text, font=font,
275
+ fill=(255, 255, 255))
276
+ return img
277
+
278
+
279
+ def get_concat(ims):
280
+ if len(ims) == 1:
281
+ n_col = 1
282
+ else:
283
+ n_col = 2
284
+ n_row = math.ceil(len(ims) / 2)
285
+ dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
286
+ for i, im in enumerate(ims):
287
+ row_id = i // n_col
288
+ col_id = i % n_col
289
+ dst.paste(im, (im.width * col_id, im.height * row_id))
290
+ return dst
291
+
292
+
293
+ def auto_append_grounding(language_instruction, grounding_texts):
294
+ for grounding_text in grounding_texts:
295
+ if grounding_text not in language_instruction and grounding_text != 'auto':
296
+ language_instruction += "; " + grounding_text
297
+ return language_instruction
298
+
299
+
300
+ def generate(task, language_instruction, grounding_texts, sketch_pad,
301
+ alpha_sample, guidance_scale, batch_size,
302
+ fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
303
+ state):
304
+ if 'boxes' not in state:
305
+ state['boxes'] = []
306
+
307
+ boxes = state['boxes']
308
+ grounding_texts = [x.strip() for x in grounding_texts.split(';')]
309
+ # assert len(boxes) == len(grounding_texts)
310
+ if len(boxes) != len(grounding_texts):
311
+ if len(boxes) < len(grounding_texts):
312
+ raise ValueError("""The number of boxes should be equal to the number of grounding objects.
313
+ Number of boxes drawn: {}, number of grounding tokens: {}.
314
+ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
315
+ grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
316
+
317
+ boxes = (np.asarray(boxes) / 512).tolist()
318
+ grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
319
+
320
+ image = None
321
+ actual_mask = None
322
+ if task == 'Grounded Inpainting':
323
+ image = state.get('original_image', sketch_pad['image']).copy()
324
+ image = center_crop(image)
325
+ image = Image.fromarray(image)
326
+
327
+ if use_actual_mask:
328
+ actual_mask = sketch_pad['mask'].copy()
329
+ if actual_mask.ndim == 3:
330
+ actual_mask = actual_mask[..., 0]
331
+ actual_mask = center_crop(actual_mask, tgt_size=(64, 64))
332
+ actual_mask = torch.from_numpy(actual_mask == 0).float()
333
+
334
+ if state.get('inpaint_hw', None):
335
+ boxes = np.asarray(boxes) * 0.9 + 0.05
336
+ boxes = boxes.tolist()
337
+ grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes) if obj != 'auto'})
338
+
339
+ if append_grounding:
340
+ language_instruction = auto_append_grounding(language_instruction, grounding_texts)
341
+
342
+ gen_images, gen_overlays = inference(
343
+ task, language_instruction, grounding_instruction, boxes, image,
344
+ alpha_sample, guidance_scale, batch_size,
345
+ fix_seed, rand_seed, actual_mask, style_cond_image, clip_model=clip_model,
346
+ )
347
+
348
+ for idx, gen_image in enumerate(gen_images):
349
+
350
+ if task == 'Grounded Inpainting' and state.get('inpaint_hw', None):
351
+ hw = min(*state['original_image'].shape[:2])
352
+ gen_image = sized_center_fill(state['original_image'].copy(), np.array(gen_image.resize((hw, hw))), hw, hw)
353
+ gen_image = Image.fromarray(gen_image)
354
+
355
+ gen_images[idx] = gen_image
356
+
357
+ blank_samples = batch_size % 2 if batch_size > 1 else 0
358
+ gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
359
+ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
360
+ + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
361
+
362
+ return gen_images + [state]
363
+
364
+
365
+ def binarize(x):
366
+ return (x != 0).astype('uint8') * 255
367
+
368
+
369
+ def sized_center_crop(img, cropx, cropy):
370
+ y, x = img.shape[:2]
371
+ startx = x // 2 - (cropx // 2)
372
+ starty = y // 2 - (cropy // 2)
373
+ return img[starty:starty + cropy, startx:startx + cropx]
374
+
375
+
376
+ def sized_center_fill(img, fill, cropx, cropy):
377
+ y, x = img.shape[:2]
378
+ startx = x // 2 - (cropx // 2)
379
+ starty = y // 2 - (cropy // 2)
380
+ img[starty:starty + cropy, startx:startx + cropx] = fill
381
+ return img
382
+
383
+
384
+ def sized_center_mask(img, cropx, cropy):
385
+ y, x = img.shape[:2]
386
+ startx = x // 2 - (cropx // 2)
387
+ starty = y // 2 - (cropy // 2)
388
+ center_region = img[starty:starty + cropy, startx:startx + cropx].copy()
389
+ img = (img * 0.2).astype('uint8')
390
+ img[starty:starty + cropy, startx:startx + cropx] = center_region
391
+ return img
392
+
393
+
394
+ def center_crop(img, HW=None, tgt_size=(512, 512)):
395
+ if HW is None:
396
+ H, W = img.shape[:2]
397
+ HW = min(H, W)
398
+ img = sized_center_crop(img, HW, HW)
399
+ img = Image.fromarray(img)
400
+ img = img.resize(tgt_size)
401
+ return np.array(img)
402
+
403
+
404
+ def draw(task, input, grounding_texts, new_image_trigger, state):
405
+ if type(input) == dict:
406
+ image = input['image']
407
+ mask = input['mask']
408
+ else:
409
+ mask = input
410
+
411
+ if mask.ndim == 3:
412
+ mask = mask[..., 0]
413
+
414
+ image_scale = 1.0
415
+
416
+ # resize trigger
417
+ if task == "Grounded Inpainting":
418
+ mask_cond = mask.sum() == 0
419
+ # size_cond = mask.shape != (512, 512)
420
+ if mask_cond and 'original_image' not in state:
421
+ image = Image.fromarray(image)
422
+ width, height = image.size
423
+ scale = 600 / min(width, height)
424
+ image = image.resize((int(width * scale), int(height * scale)))
425
+ state['original_image'] = np.array(image).copy()
426
+ image_scale = float(height / width)
427
+ return [None, new_image_trigger + 1, image_scale, state]
428
+ else:
429
+ original_image = state['original_image']
430
+ H, W = original_image.shape[:2]
431
+ image_scale = float(H / W)
432
+
433
+ mask = binarize(mask)
434
+ if mask.shape != (512, 512):
435
+ # assert False, "should not receive any non- 512x512 masks."
436
+ if 'original_image' in state and state['original_image'].shape[:2] == mask.shape:
437
+ mask = center_crop(mask, state['inpaint_hw'])
438
+ image = center_crop(state['original_image'], state['inpaint_hw'])
439
+ else:
440
+ mask = np.zeros((512, 512), dtype=np.uint8)
441
+ # mask = center_crop(mask)
442
+ mask = binarize(mask)
443
+
444
+ if type(mask) != np.ndarray:
445
+ mask = np.array(mask)
446
+
447
+ if mask.sum() == 0 and task != "Grounded Inpainting":
448
+ state = {}
449
+
450
+ if task != 'Grounded Inpainting':
451
+ image = None
452
+ else:
453
+ image = Image.fromarray(image)
454
+
455
+ if 'boxes' not in state:
456
+ state['boxes'] = []
457
+
458
+ if 'masks' not in state or len(state['masks']) == 0:
459
+ state['masks'] = []
460
+ last_mask = np.zeros_like(mask)
461
+ else:
462
+ last_mask = state['masks'][-1]
463
+
464
+ if type(mask) == np.ndarray and mask.size > 1:
465
+ diff_mask = mask - last_mask
466
+ else:
467
+ diff_mask = np.zeros([])
468
+
469
+ if diff_mask.sum() > 0:
470
+ x1x2 = np.where(diff_mask.max(0) != 0)[0]
471
+ y1y2 = np.where(diff_mask.max(1) != 0)[0]
472
+ y1, y2 = y1y2.min(), y1y2.max()
473
+ x1, x2 = x1x2.min(), x1x2.max()
474
+
475
+ if (x2 - x1 > 5) and (y2 - y1 > 5):
476
+ state['masks'].append(mask.copy())
477
+ state['boxes'].append((x1, y1, x2, y2))
478
+
479
+ grounding_texts = [x.strip() for x in grounding_texts.split(';')]
480
+ grounding_texts = [x for x in grounding_texts if len(x) > 0]
481
+ if len(grounding_texts) < len(state['boxes']):
482
+ grounding_texts += [f'Obj. {bid + 1}' for bid in range(len(grounding_texts), len(state['boxes']))]
483
+ print("state", state)
484
+ box_image = draw_box(state['boxes'], grounding_texts, image)
485
+
486
+ if box_image is not None and state.get('inpaint_hw', None):
487
+ inpaint_hw = state['inpaint_hw']
488
+ box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
489
+ original_image = state['original_image'].copy()
490
+ box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)
491
+ print(box_image, new_image_trigger, image_scale, state)
492
+ return [box_image, new_image_trigger, image_scale, state]
493
+
494
+
495
+ def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
496
+ if task != 'Grounded Inpainting':
497
+ sketch_pad_trigger = sketch_pad_trigger + 1
498
+ blank_samples = batch_size % 2 if batch_size > 1 else 0
499
+ out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \
500
+ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
501
+ + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
502
+ state = {}
503
+ return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]
504
+
505
+
506
+ css = """
507
+ #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
508
+ {
509
+ height: var(--height) !important;
510
+ max-height: var(--height) !important;
511
+ min-height: var(--height) !important;
512
+ }
513
+ #paper-info a {
514
+ color:#008AD7;
515
+ text-decoration: none;
516
+ }
517
+ #paper-info a:hover {
518
+ cursor: pointer;
519
+ text-decoration: none;
520
+ }
521
+ """
522
+
523
+ rescale_js = """
524
+ function(x) {
525
+ const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
526
+ let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
527
+ const image_width = root.querySelector('#img2img_image').clientWidth;
528
+ const target_height = parseInt(image_width * image_scale);
529
+ document.body.style.setProperty('--height', `${target_height}px`);
530
+ root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
531
+ root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
532
+ return x;
533
+ }
534
+ """
535
+
536
+ with Blocks(
537
+ css=css,
538
+ analytics_enabled=False,
539
+ title="GLIGen demo",
540
+ ) as main:
541
+ description = """<p style="text-align: center; font-weight: bold;">
542
+ <span style="font-size: 28px">Layout Guidance</span>
543
+ <br>
544
+ <span style="font-size: 18px" id="paper-info">
545
+ [<a href="https://gligen.github.io" target="_blank">Project Page</a>]
546
+ [<a href="https://arxiv.org/abs/2301.07093" target="_blank">Paper</a>]
547
+ [<a href="https://github.com/gligen/GLIGEN" target="_blank">GitHub</a>]
548
+ [<a href="https://huggingface.co/spaces/gligen/demo_legacy" target="_blank">Mirror</a>]
549
+ </span>
550
+ </p>
551
+ """
552
+ gr.HTML(description)
553
+
554
+ with gr.Row():
555
+ with gr.Column(scale=4):
556
+ sketch_pad_trigger = gr.Number(value=0, visible=False)
557
+ sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
558
+ init_white_trigger = gr.Number(value=0, visible=False)
559
+ image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
560
+ new_image_trigger = gr.Number(value=0, visible=False)
561
+
562
+ # task = gr.Radio(
563
+ # choices=["Grounded Generation", 'Grounded Inpainting'],
564
+ # type="value",
565
+ # value="Grounded Generation",
566
+ # label="Task",
567
+ # )
568
+ language_instruction = gr.Textbox(
569
+ label="Text Caption",
570
+ )
571
+ grounding_instruction = gr.Textbox(
572
+ label="Grounding instruction (Separated by semicolon)",
573
+ )
574
+ with gr.Row():
575
+ sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
576
+ out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
577
+ with gr.Row():
578
+ clear_btn = gr.Button(value='Clear')
579
+ gen_btn = gr.Button(value='Generate')
580
+ with gr.Accordion("Advanced Options", open=False):
581
+ with gr.Column():
582
+ Loss_scale = gr.Slider(minimum=0, maximum=500, step=5, value=30,
583
+ label="Loss Scale Factor")
584
+ guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
585
+ batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=2, label="Number of Samples")
586
+ max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
587
+ loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss Threshold")
588
+ max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Backward Guidance")
589
+
590
+ # append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
591
+ # use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
592
+ with gr.Row():
593
+ fix_seed = gr.Checkbox(value=True, label="Fixed seed")
594
+ rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed")
595
+
596
+ with gr.Column(scale=4):
597
+ gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
598
+ with gr.Row():
599
+ out_gen_1 = gr.Image(type="pil", visible=True, show_label=False, label="Generated Image")
600
+ out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
601
+ with gr.Row():
602
+ out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
603
+ out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)
604
+
605
+ state = gr.State({})
606
+
607
+
608
+ class Controller:
609
+ def __init__(self):
610
+ self.calls = 0
611
+ self.tracks = 0
612
+ self.resizes = 0
613
+ self.scales = 0
614
+
615
+ def init_white(self, init_white_trigger):
616
+ self.calls += 1
617
+ return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1
618
+
619
+ def change_n_samples(self, n_samples):
620
+ blank_samples = n_samples % 2 if n_samples > 1 else 0
621
+ return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
622
+ + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
623
+
624
+ def resize_centercrop(self, state):
625
+ self.resizes += 1
626
+ image = state['original_image'].copy()
627
+ inpaint_hw = int(0.9 * min(*image.shape[:2]))
628
+ state['inpaint_hw'] = inpaint_hw
629
+ image_cc = center_crop(image, inpaint_hw)
630
+ # print(f'resize triggered {self.resizes}', image.shape, '->', image_cc.shape)
631
+ return image_cc, state
632
+
633
+ def resize_masked(self, state):
634
+ self.resizes += 1
635
+ image = state['original_image'].copy()
636
+ inpaint_hw = int(0.9 * min(*image.shape[:2]))
637
+ state['inpaint_hw'] = inpaint_hw
638
+ image_mask = sized_center_mask(image, inpaint_hw, inpaint_hw)
639
+ state['masked_image'] = image_mask.copy()
640
+ # print(f'mask triggered {self.resizes}')
641
+ return image_mask, state
642
+
643
+ def switch_task_hide_cond(self, task):
644
+ cond = False
645
+ if task == "Grounded Generation":
646
+ cond = True
647
+
648
+ return gr.Checkbox.update(visible=cond, value=False), gr.Image.update(value=None,
649
+ visible=False), gr.Slider.update(
650
+ visible=cond), gr.Checkbox.update(visible=(not cond), value=False)
651
+
652
+
653
+ controller = Controller()
654
+ main.load(
655
+ lambda x: x + 1,
656
+ inputs=sketch_pad_trigger,
657
+ outputs=sketch_pad_trigger,
658
+ queue=False)
659
+ sketch_pad.edit(
660
+ draw,
661
+ inputs=[sketch_pad, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
662
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
663
+ queue=False,
664
+ )
665
+ grounding_instruction.change(
666
+ draw,
667
+ inputs=[sketch_pad, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
668
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
669
+ queue=False,
670
+ )
671
+ clear_btn.click(
672
+ clear,
673
+ inputs=[sketch_pad_trigger, sketch_pad_trigger, batch_size, state],
674
+ outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3,
675
+ out_gen_4, state],
676
+ queue=False)
677
+ # task.change(
678
+ # partial(clear, switch_task=True),
679
+ # inputs=[task, sketch_pad_trigger, batch_size, state],
680
+ # outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3,
681
+ # out_gen_4, state],
682
+ # queue=False)
683
+ sketch_pad_trigger.change(
684
+ controller.init_white,
685
+ inputs=[init_white_trigger],
686
+ outputs=[sketch_pad, image_scale, init_white_trigger],
687
+ queue=False)
688
+ sketch_pad_resize_trigger.change(
689
+ controller.resize_masked,
690
+ inputs=[state],
691
+ outputs=[sketch_pad, state],
692
+ queue=False)
693
+ batch_size.change(
694
+ controller.change_n_samples,
695
+ inputs=[batch_size],
696
+ outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4],
697
+ queue=False)
698
+
699
+ batch_size.change(
700
+ controller.change_n_samples,
701
+ inputs=[batch_size],
702
+ outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4],
703
+ queue=False)
704
+
705
+ gen_btn.click(
706
+ generate,
707
+ inputs=[
708
+ language_instruction, language_instruction, grounding_instruction, sketch_pad,
709
+ loss_threshold, guidance_scale, batch_size,
710
+ fix_seed, rand_seed,
711
+ max_step,
712
+ Loss_scale, max_iter,
713
+ state,
714
+ ],
715
+ outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
716
+ queue=True
717
+ )
718
+ sketch_pad_resize_trigger.change(
719
+ None,
720
+ None,
721
+ sketch_pad_resize_trigger,
722
+ _js=rescale_js,
723
+ queue=False)
724
+ init_white_trigger.change(
725
+ None,
726
+ None,
727
+ init_white_trigger,
728
+ _js=rescale_js,
729
+ queue=False)
730
+
731
+ with gr.Column():
732
+ gr.Examples(
733
+ examples=[
734
+ [
735
+ "images/input.png",
736
+ "A hello kitty toy is playing with a purple ball.",
737
+ "hello kitty;ball",
738
+ "images/hello_kitty_results.png"
739
+ ],
740
+ ],
741
+ inputs=[sketch_pad, language_instruction, grounding_instruction, out_gen_1],
742
+ outputs=None,
743
+ fn=None,
744
+ cache_examples=False,
745
+ )
746
+
747
+ main.queue(concurrency_count=1, api_open=False)
748
+ main.launch(share=False, show_api=False, show_error=True)
images/hello_kitty_results.png ADDED
images/input.png ADDED
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.13.1
2
+ torchvision==0.14.1
3
+ xformers==0.0.16
4
+ omegaconf==2.1.1
5
+ albumentations==1.3.0
6
+ opencv-python
7
+ imageio==2.9.0
8
+ imageio-ffmpeg==0.4.2
9
+ pytorch-lightning==1.4.2
10
+ test-tube>=0.7.5
11
+ streamlit==1.17.0
12
+ einops==0.3.0
13
+ git+https://github.com/openai/CLIP.git
14
+ protobuf~=3.20.1
15
+ torchmetrics==0.6.0
16
+ transformers==4.19.2
17
+ kornia==0.6.0
18
+ gradio==3.19.1