dragonSwing commited on
Commit
eeef127
·
1 Parent(s): 156bb47

Update script arguments

Browse files
Files changed (1) hide show
  1. app.py +279 -180
app.py CHANGED
@@ -1,96 +1,47 @@
 
1
  import json
2
  import os
3
- import subprocess
4
  import sys
5
  import tempfile
6
 
7
- import gradio as gr
8
  import numpy as np
9
  import supervision as sv
10
- import torch
 
11
  from PIL import Image
12
- from segment_anything import build_sam
13
  from segment_anything import SamAutomaticMaskGenerator
14
  from segment_anything import SamPredictor
15
- from supervision.detection.utils import mask_to_polygons
16
  from supervision.detection.utils import xywh_to_xyxy
17
-
18
- if os.environ.get("IS_MY_DEBUG") is None:
19
- result = subprocess.run(["pip", "install", "-e", "GroundingDINO"], check=True)
20
- print(f"pip install GroundingDINO = {result}")
21
 
22
  sys.path.append("tag2text")
23
- sys.path.append("GroundingDINO")
24
 
25
  from tag2text.models import tag2text
26
- from groundingdino.util.inference import Model as DinoModel
27
  from config import *
28
- from utils import download_file_hf, detect, segment, show_anns_sam, generate_tags
29
-
30
- if not os.path.exists(abs_weight_dir):
31
- os.makedirs(abs_weight_dir, exist_ok=True)
32
-
33
- sam_checkpoint = os.path.join(abs_weight_dir, sam_dict[default_sam]["checkpoint_file"])
34
- if not os.path.exists(sam_checkpoint):
35
- os.system(f"wget {sam_dict[default_sam]['checkpoint_url']} -O {sam_checkpoint}")
36
-
37
- tag2text_checkpoint = os.path.join(
38
- abs_weight_dir, tag2text_dict[default_tag2text]["checkpoint_file"]
39
- )
40
- if not os.path.exists(tag2text_checkpoint):
41
- os.system(
42
- f"wget {tag2text_dict[default_tag2text]['checkpoint_url']} -O {tag2text_checkpoint}"
43
- )
44
-
45
- dino_checkpoint = os.path.join(
46
- abs_weight_dir, dino_dict[default_dino]["checkpoint_file"]
47
- )
48
- dino_config_file = os.path.join(abs_weight_dir, dino_dict[default_dino]["config_file"])
49
- if not os.path.exists(dino_checkpoint):
50
- dino_repo_id = dino_dict[default_dino]["repo_id"]
51
- download_file_hf(
52
- repo_id=dino_repo_id,
53
- filename=dino_dict[default_dino]["config_file"],
54
- cache_dir=weight_dir,
55
- )
56
- download_file_hf(
57
- repo_id=dino_repo_id,
58
- filename=dino_dict[default_dino]["checkpoint_file"],
59
- cache_dir=weight_dir,
60
- )
61
-
62
- # load model
63
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
- tag2text_model = tag2text.tag2text_caption(
65
- pretrained=tag2text_checkpoint,
66
- image_size=384,
67
- vit="swin_b",
68
- delete_tag_index=delete_tag_index,
69
- )
70
- # threshold for tagging
71
- # we reduce the threshold to obtain more tags
72
- tag2text_model.threshold = 0.64
73
- tag2text_model.to(device)
74
- tag2text_model.eval()
75
-
76
-
77
- sam = build_sam(checkpoint=sam_checkpoint)
78
- sam.to(device=device)
79
- sam_predictor = SamPredictor(sam)
80
- sam_automask_generator = SamAutomaticMaskGenerator(sam)
81
-
82
- grounding_dino_model = DinoModel(
83
- model_config_path=dino_config_file,
84
- model_checkpoint_path=dino_checkpoint,
85
- device=device,
86
- )
87
-
88
-
89
- def process(image_path, task, prompt, box_threshold, text_threshold, iou_threshold):
90
- global tag2text_model, sam_predictor, sam_automask_generator, grounding_dino_model, device
91
- output_gallery = []
92
  detections = None
93
- metadata = {"image": {}, "annotations": []}
 
 
 
94
 
95
  try:
96
  # Load image
@@ -100,17 +51,18 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
100
 
101
  # Extract image metadata
102
  filename = os.path.basename(image_path)
 
103
  h, w = image.shape[:2]
104
  metadata["image"]["file_name"] = filename
105
  metadata["image"]["width"] = w
106
  metadata["image"]["height"] = h
107
 
108
  # Generate tags
109
- if task in ["auto", "detect"] and prompt == "":
110
  tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
111
  prompt = " . ".join(tags)
112
- print(f"Caption: {caption}")
113
- print(f"Tags: {tags}")
114
 
115
  # ToDo: Extract metadata
116
  metadata["image"]["caption"] = caption
@@ -118,7 +70,6 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
118
 
119
  if prompt:
120
  metadata["prompt"] = prompt
121
- print(f"Prompt: {prompt}")
122
 
123
  # Detect boxes
124
  if prompt != "":
@@ -131,18 +82,21 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
131
  iou_threshold=iou_threshold,
132
  post_process=True,
133
  )
134
- print(phrases)
135
-
136
- # Draw boxes
137
- box_annotator = sv.BoxAnnotator()
138
- labels = [
139
- f"{phrases[i]} {detections.confidence[i]:0.2f}"
140
- for i in range(len(phrases))
141
- ]
142
- image = box_annotator.annotate(
143
- scene=image, detections=detections, labels=labels
144
- )
145
- output_gallery.append(image)
 
 
 
146
 
147
  # Segmentation
148
  if task in ["auto", "segment"]:
@@ -167,18 +121,27 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
167
  detections = sv.Detections(
168
  xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
169
  )
170
- # opacity = 0.4
171
- # mask_image, _ = show_anns_sam(masks)
172
- # annotated_image = np.uint8(mask_image * opacity + image * (1 - opacity))
173
-
174
- mask_annotator = sv.MaskAnnotator()
175
- mask_image = np.zeros_like(image, dtype=np.uint8)
176
- mask_image = mask_annotator.annotate(
177
- mask_image, detections=detections, opacity=1
178
- )
179
- annotated_image = mask_annotator.annotate(image, detections=detections)
180
- output_gallery.append(mask_image)
181
- output_gallery.append(annotated_image)
 
 
 
 
 
 
 
 
 
182
 
183
  # ToDo: Extract metadata
184
  if detections:
@@ -201,86 +164,222 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
201
  metadata["annotations"].append(annotation)
202
  i += 1
203
 
204
- meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
205
- meta_file_path = meta_file.name
206
- with open(meta_file_path, "w") as fp:
207
- json.dump(metadata, fp)
208
-
209
- return output_gallery, meta_file_path
 
 
 
 
 
 
 
 
 
 
210
  except Exception as error:
211
- raise gr.Error(f"global exception: {error}")
212
-
213
-
214
- title = "Annotate Anything"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- with gr.Blocks(css="style.css", title=title) as demo:
217
- with gr.Row(elem_classes=["container"]):
218
- with gr.Column(scale=1):
219
- input_image = gr.Image(type="filepath", label="Input")
220
- task = gr.Dropdown(
221
- ["detect", "segment", "auto"], value="auto", label="task_type"
 
 
 
 
222
  )
223
- text_prompt = gr.Textbox(
224
- label="Detection Prompt",
225
- info="To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  )
227
- with gr.Accordion("Advanced parameters", open=False):
228
- box_threshold = gr.Slider(
229
- minimum=0,
230
- maximum=1,
231
- value=0.3,
232
- step=0.05,
233
- label="Box threshold",
234
- info="Hash size to use for image hashing",
235
- )
236
- text_threshold = gr.Slider(
237
- minimum=0,
238
- maximum=1,
239
- value=0.25,
240
- step=0.05,
241
- label="Text threshold",
242
- info="Number of history images used to find out duplicate image",
243
- )
244
- iou_threshold = gr.Slider(
245
- minimum=0,
246
- maximum=1,
247
- value=0.5,
248
- step=0.05,
249
- label="IOU threshold",
250
- info="Minimum similarity threshold (in percent) to consider 2 images to be similar",
251
- )
252
- run_button = gr.Button(label="Run")
253
-
254
- with gr.Column(scale=2):
255
- gallery = gr.Gallery(
256
- label="Generated images", show_label=False, elem_id="gallery"
257
- ).style(preview=True, grid=2, object_fit="scale-down")
258
- meta_file = gr.File(label="Metadata file")
259
-
260
- with gr.Row(elem_classes=["container"]):
261
- gr.Examples(
262
- [
263
- ["examples/dog.png", "auto", ""],
264
- ["examples/eiffel.png", "auto", ""],
265
- ["examples/eiffel.png", "segment", ""],
266
- ["examples/girl.png", "auto", "girl . face"],
267
- ["examples/horse.png", "detect", "horse"],
268
- ["examples/horses.jpg", "auto", "horse"],
269
- ["examples/traffic.jpg", "auto", ""],
270
- ],
271
- [input_image, task, text_prompt],
272
  )
273
- run_button.click(
274
- fn=process,
275
- inputs=[
276
- input_image,
277
- task,
278
- text_prompt,
279
- box_threshold,
280
- text_threshold,
281
- iou_threshold,
282
- ],
283
- outputs=[gallery, meta_file],
284
  )
285
 
286
- demo.queue(concurrency_count=2).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
  import json
3
  import os
 
4
  import sys
5
  import tempfile
6
 
 
7
  import numpy as np
8
  import supervision as sv
9
+ from groundingdino.util.inference import Model as DinoModel
10
+ from imutils import paths
11
  from PIL import Image
12
+ from segment_anything import sam_model_registry
13
  from segment_anything import SamAutomaticMaskGenerator
14
  from segment_anything import SamPredictor
 
15
  from supervision.detection.utils import xywh_to_xyxy
16
+ from tqdm import tqdm
 
 
 
17
 
18
  sys.path.append("tag2text")
 
19
 
20
  from tag2text.models import tag2text
 
21
  from config import *
22
+ from utils import detect, download_file_hf, segment, generate_tags, show_anns_sv
23
+
24
+
25
+ def process(
26
+ tag2text_model,
27
+ grounding_dino_model,
28
+ sam_predictor,
29
+ sam_automask_generator,
30
+ image_path,
31
+ task,
32
+ prompt,
33
+ box_threshold,
34
+ text_threshold,
35
+ iou_threshold,
36
+ device,
37
+ output_dir=None,
38
+ save_mask=False,
39
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  detections = None
41
+ metadata = {"image": {}, "annotations": [], "assets": {}}
42
+
43
+ if save_mask:
44
+ metadata["assets"]["intermediate_mask"] = []
45
 
46
  try:
47
  # Load image
 
51
 
52
  # Extract image metadata
53
  filename = os.path.basename(image_path)
54
+ basename = os.path.splitext(filename)[0]
55
  h, w = image.shape[:2]
56
  metadata["image"]["file_name"] = filename
57
  metadata["image"]["width"] = w
58
  metadata["image"]["height"] = h
59
 
60
  # Generate tags
61
+ if task in ["auto", "detection"] and prompt == "":
62
  tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
63
  prompt = " . ".join(tags)
64
+ # print(f"Caption: {caption}")
65
+ # print(f"Tags: {tags}")
66
 
67
  # ToDo: Extract metadata
68
  metadata["image"]["caption"] = caption
 
70
 
71
  if prompt:
72
  metadata["prompt"] = prompt
 
73
 
74
  # Detect boxes
75
  if prompt != "":
 
82
  iou_threshold=iou_threshold,
83
  post_process=True,
84
  )
85
+
86
+ # Save detection image
87
+ if output_dir:
88
+ # Draw boxes
89
+ box_annotator = sv.BoxAnnotator()
90
+ labels = [
91
+ f"{phrases[i]} {detections.confidence[i]:0.2f}"
92
+ for i in range(len(phrases))
93
+ ]
94
+ box_image = box_annotator.annotate(
95
+ scene=image, detections=detections, labels=labels
96
+ )
97
+ box_image_path = os.path.join(output_dir, basename + "_detect.png")
98
+ metadata["assets"]["detection"] = box_image_path
99
+ Image.fromarray(box_image).save(box_image_path)
100
 
101
  # Segmentation
102
  if task in ["auto", "segment"]:
 
121
  detections = sv.Detections(
122
  xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
123
  )
124
+
125
+ # Save annotated image
126
+ if output_dir:
127
+ mask_annotator = sv.MaskAnnotator()
128
+ mask_image, res = show_anns_sv(detections)
129
+ annotated_image = mask_annotator.annotate(image, detections=detections)
130
+
131
+ mask_image_path = os.path.join(output_dir, basename + "_mask.png")
132
+ metadata["assets"]["mask"] = mask_image_path
133
+ Image.fromarray(mask_image).save(mask_image_path)
134
+
135
+ # Save annotation encoding from https://github.com/LUSSeg/ImageNet-S
136
+ mask_enc_path = os.path.join(output_dir, basename + "_mask_enc.npy")
137
+ np.save(mask_enc_path, res)
138
+ metadata["assets"]["mask_enc"] = mask_enc_path
139
+
140
+ annotated_image_path = os.path.join(
141
+ output_dir, basename + "_annotate.png"
142
+ )
143
+ metadata["assets"]["annotate"] = annotated_image_path
144
+ Image.fromarray(annotated_image).save(annotated_image_path)
145
 
146
  # ToDo: Extract metadata
147
  if detections:
 
164
  metadata["annotations"].append(annotation)
165
  i += 1
166
 
167
+ if output_dir and save_mask:
168
+ mask_image_path = os.path.join(
169
+ output_dir, f"{basename}_mask_{id}.png"
170
+ )
171
+ metadata["assets"]["intermediate_mask"].append(mask_image_path)
172
+ Image.fromarray(mask * 255).save(mask_image_path)
173
+
174
+ if output_dir:
175
+ meta_file_path = os.path.join(output_dir, basename + "_meta.json")
176
+ with open(meta_file_path, "w") as fp:
177
+ json.dump(metadata, fp)
178
+ else:
179
+ meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
180
+ meta_file_path = meta_file.name
181
+
182
+ return meta_file_path
183
  except Exception as error:
184
+ raise ValueError(f"global exception: {error}")
185
+
186
+
187
+ def main(args: argparse.Namespace) -> None:
188
+ device = args.device
189
+ prompt = args.prompt
190
+ task = args.task
191
+
192
+ tag2text_model = None
193
+ grounding_dino_model = None
194
+ sam_predictor = None
195
+ sam_automask_generator = None
196
+
197
+ box_threshold = args.box_threshold
198
+ text_threshold = args.text_threshold
199
+ iou_threshold = args.iou_threshold
200
+ save_mask = args.save_mask
201
+
202
+ # load model
203
+ if task in ["auto", "detection"] and prompt == "":
204
+ print("Loading Tag2Text model...")
205
+ tag2text_type = args.tag2text_type
206
+ tag2text_checkpoint = os.path.join(
207
+ abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
208
+ )
209
+ if not os.path.exists(tag2text_checkpoint):
210
+ print(f"Downloading weights for Tag2Text {tag2text_type} model")
211
+ os.system(
212
+ f"wget {tag2text_dict[tag2text_type]['checkpoint_url']} -O {tag2text_checkpoint}"
213
+ )
214
+ tag2text_model = tag2text.tag2text_caption(
215
+ pretrained=tag2text_checkpoint,
216
+ image_size=384,
217
+ vit="swin_b",
218
+ delete_tag_index=delete_tag_index,
219
+ )
220
+ # threshold for tagging
221
+ # we reduce the threshold to obtain more tags
222
+ tag2text_model.threshold = 0.64
223
+ tag2text_model.to(device)
224
+ tag2text_model.eval()
225
+
226
+ if task in ["auto", "detection"] or prompt != "":
227
+ print("Loading Grounding Dino model...")
228
+ dino_type = args.dino_type
229
+ dino_checkpoint = os.path.join(
230
+ abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
231
+ )
232
+ dino_config_file = os.path.join(
233
+ abs_weight_dir, dino_dict[dino_type]["config_file"]
234
+ )
235
+ if not os.path.exists(dino_checkpoint):
236
+ print(f"Downloading weights for Grounding Dino {dino_type} model")
237
+ dino_repo_id = dino_dict[dino_type]["repo_id"]
238
+ download_file_hf(
239
+ repo_id=dino_repo_id,
240
+ filename=dino_dict[dino_type]["checkpoint_file"],
241
+ cache_dir=weight_dir,
242
+ )
243
+ download_file_hf(
244
+ repo_id=dino_repo_id,
245
+ filename=dino_dict[dino_type]["checkpoint_file"],
246
+ cache_dir=weight_dir,
247
+ )
248
+ grounding_dino_model = DinoModel(
249
+ model_config_path=dino_config_file,
250
+ model_checkpoint_path=dino_checkpoint,
251
+ device=device,
252
+ )
253
 
254
+ if task in ["auto", "segment"]:
255
+ print("Loading SAM...")
256
+ sam_type = args.sam_type
257
+ sam_checkpoint = os.path.join(
258
+ abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
259
+ )
260
+ if not os.path.exists(sam_checkpoint):
261
+ print(f"Downloading weights for SAM {sam_type}")
262
+ os.system(
263
+ f"wget {sam_dict[sam_type]['checkpoint_url']} -O {sam_checkpoint}"
264
  )
265
+ sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
266
+ sam.to(device=device)
267
+ sam_predictor = SamPredictor(sam)
268
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
269
+
270
+ if not os.path.exists(args.input):
271
+ raise ValueError("The input directory doesn't exist!")
272
+ elif not os.path.isdir(args.input):
273
+ image_paths = [args.input]
274
+ else:
275
+ image_paths = paths.list_images(args.input)
276
+
277
+ os.makedirs(args.output, exist_ok=True)
278
+
279
+ with tqdm(image_paths) as pbar:
280
+ for image_path in pbar:
281
+ pbar.set_postfix_str(f"Processing {image_path}")
282
+ process(
283
+ tag2text_model=tag2text_model,
284
+ grounding_dino_model=grounding_dino_model,
285
+ sam_predictor=sam_predictor,
286
+ sam_automask_generator=sam_automask_generator,
287
+ image_path=image_path,
288
+ task=task,
289
+ prompt=prompt,
290
+ box_threshold=box_threshold,
291
+ text_threshold=text_threshold,
292
+ iou_threshold=iou_threshold,
293
+ device=device,
294
+ output_dir=args.output,
295
+ save_mask=save_mask,
296
  )
297
+
298
+
299
+ if __name__ == "__main__":
300
+ if not os.path.exists(abs_weight_dir):
301
+ os.makedirs(abs_weight_dir, exist_ok=True)
302
+
303
+ parser = argparse.ArgumentParser(
304
+ description=(
305
+ "Runs automatic detection and mask generation on an input image or directory of images"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  )
 
 
 
 
 
 
 
 
 
 
 
307
  )
308
 
309
+ parser.add_argument(
310
+ "--input",
311
+ "-i",
312
+ type=str,
313
+ required=True,
314
+ help="Path to either a single input image or folder of images.",
315
+ )
316
+
317
+ parser.add_argument(
318
+ "--output",
319
+ "-o",
320
+ type=str,
321
+ required=True,
322
+ help=(
323
+ "Path to the directory where masks will be output."
324
+ ),
325
+ )
326
+
327
+ parser.add_argument(
328
+ "--sam-type",
329
+ type=str,
330
+ default=default_sam,
331
+ choices=sam_dict.keys(),
332
+ help="The type of SA model use for segmentation.",
333
+ )
334
+
335
+ parser.add_argument(
336
+ "--tag2text-type",
337
+ type=str,
338
+ default=default_tag2text,
339
+ choices=tag2text_dict.keys(),
340
+ help="The type of Tag2Text model use for tags and caption generation.",
341
+ )
342
+
343
+ parser.add_argument(
344
+ "--dino-type",
345
+ type=str,
346
+ default=default_dino,
347
+ choices=dino_dict.keys(),
348
+ help="The type of Grounding Dino model use for promptable object detection.",
349
+ )
350
+
351
+ parser.add_argument(
352
+ "--task",
353
+ help="Task to run",
354
+ default="auto",
355
+ choices=["auto", "detect", "segment"],
356
+ type=str,
357
+ )
358
+ parser.add_argument(
359
+ "--prompt",
360
+ help="Detection prompt",
361
+ default="",
362
+ type=str,
363
+ )
364
+
365
+ parser.add_argument(
366
+ "--box-threshold", type=float, default=0.25, help="box threshold"
367
+ )
368
+ parser.add_argument(
369
+ "--text-threshold", type=float, default=0.2, help="text threshold"
370
+ )
371
+ parser.add_argument(
372
+ "--iou-threshold", type=float, default=0.5, help="iou threshold"
373
+ )
374
+
375
+ parser.add_argument(
376
+ "--save-mask",
377
+ action="store_true",
378
+ default=False,
379
+ help="If True, save all intermidiate masks.",
380
+ )
381
+ parser.add_argument(
382
+ "--device", type=str, default="cuda", help="The device to run generation on."
383
+ )
384
+ args = parser.parse_args()
385
+ main(args)