dragonSwing commited on
Commit
637f41e
·
1 Parent(s): 699405e

update script

Browse files
Files changed (1) hide show
  1. annotate_anything.py +22 -16
annotate_anything.py CHANGED
@@ -33,8 +33,9 @@ def process(
33
  box_threshold,
34
  text_threshold,
35
  iou_threshold,
36
- device,
37
  output_dir=None,
 
38
  save_mask=False,
39
  ):
40
  detections = None
@@ -84,7 +85,7 @@ def process(
84
  )
85
 
86
  # Save detection image
87
- if output_dir:
88
  # Draw boxes
89
  box_annotator = sv.BoxAnnotator()
90
  labels = [
@@ -123,7 +124,7 @@ def process(
123
  )
124
 
125
  # Save annotated image
126
- if output_dir:
127
  mask_annotator = sv.MaskAnnotator()
128
  mask_image, res = show_anns_sv(detections)
129
  annotated_image = mask_annotator.annotate(image, detections=detections)
@@ -197,12 +198,13 @@ def main(args: argparse.Namespace) -> None:
197
  box_threshold = args.box_threshold
198
  text_threshold = args.text_threshold
199
  iou_threshold = args.iou_threshold
 
200
  save_mask = args.save_mask
201
 
202
  # load model
203
  if task in ["auto", "detection"] and prompt == "":
204
  print("Loading Tag2Text model...")
205
- tag2text_type = args.tag2text
206
  tag2text_checkpoint = os.path.join(
207
  abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
208
  )
@@ -225,7 +227,7 @@ def main(args: argparse.Namespace) -> None:
225
 
226
  if task in ["auto", "detection"] or prompt != "":
227
  print("Loading Grounding Dino model...")
228
- dino_type = args.dino
229
  dino_checkpoint = os.path.join(
230
  abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
231
  )
@@ -253,7 +255,7 @@ def main(args: argparse.Namespace) -> None:
253
 
254
  if task in ["auto", "segment"]:
255
  print("Loading SAM...")
256
- sam_type = args.sam
257
  sam_checkpoint = os.path.join(
258
  abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
259
  )
@@ -292,6 +294,7 @@ def main(args: argparse.Namespace) -> None:
292
  iou_threshold=iou_threshold,
293
  device=device,
294
  output_dir=args.output,
 
295
  save_mask=save_mask,
296
  )
297
 
@@ -319,34 +322,31 @@ if __name__ == "__main__":
319
  "-o",
320
  type=str,
321
  required=True,
322
- help=(
323
- "Path to the directory where masks will be output. Output will be either a folder "
324
- "of PNGs per image or a single json with COCO-style masks."
325
- ),
326
  )
327
 
328
  parser.add_argument(
329
- "--sam",
330
  type=str,
331
  default=default_sam,
332
  choices=sam_dict.keys(),
333
- help="The type of SA model to load",
334
  )
335
 
336
  parser.add_argument(
337
- "--tag2text",
338
  type=str,
339
  default=default_tag2text,
340
  choices=tag2text_dict.keys(),
341
- help="The path to the Tag2Text checkpoint to use for tags and caption generation.",
342
  )
343
 
344
  parser.add_argument(
345
- "--dino",
346
  type=str,
347
  default=default_dino,
348
  choices=dino_dict.keys(),
349
- help="The config file of Grounding Dino model to load",
350
  )
351
 
352
  parser.add_argument(
@@ -373,6 +373,12 @@ if __name__ == "__main__":
373
  "--iou-threshold", type=float, default=0.5, help="iou threshold"
374
  )
375
 
 
 
 
 
 
 
376
  parser.add_argument(
377
  "--save-mask",
378
  action="store_true",
 
33
  box_threshold,
34
  text_threshold,
35
  iou_threshold,
36
+ device="cuda",
37
  output_dir=None,
38
+ save_ann=True,
39
  save_mask=False,
40
  ):
41
  detections = None
 
85
  )
86
 
87
  # Save detection image
88
+ if output_dir and save_ann:
89
  # Draw boxes
90
  box_annotator = sv.BoxAnnotator()
91
  labels = [
 
124
  )
125
 
126
  # Save annotated image
127
+ if output_dir and save_ann:
128
  mask_annotator = sv.MaskAnnotator()
129
  mask_image, res = show_anns_sv(detections)
130
  annotated_image = mask_annotator.annotate(image, detections=detections)
 
198
  box_threshold = args.box_threshold
199
  text_threshold = args.text_threshold
200
  iou_threshold = args.iou_threshold
201
+ save_ann = not args.no_save_ann
202
  save_mask = args.save_mask
203
 
204
  # load model
205
  if task in ["auto", "detection"] and prompt == "":
206
  print("Loading Tag2Text model...")
207
+ tag2text_type = args.tag2text_type
208
  tag2text_checkpoint = os.path.join(
209
  abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
210
  )
 
227
 
228
  if task in ["auto", "detection"] or prompt != "":
229
  print("Loading Grounding Dino model...")
230
+ dino_type = args.dino_type
231
  dino_checkpoint = os.path.join(
232
  abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
233
  )
 
255
 
256
  if task in ["auto", "segment"]:
257
  print("Loading SAM...")
258
+ sam_type = args.sam_type
259
  sam_checkpoint = os.path.join(
260
  abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
261
  )
 
294
  iou_threshold=iou_threshold,
295
  device=device,
296
  output_dir=args.output,
297
+ save_ann=save_ann,
298
  save_mask=save_mask,
299
  )
300
 
 
322
  "-o",
323
  type=str,
324
  required=True,
325
+ help="Path to the directory where masks will be output.",
 
 
 
326
  )
327
 
328
  parser.add_argument(
329
+ "--sam-type",
330
  type=str,
331
  default=default_sam,
332
  choices=sam_dict.keys(),
333
+ help="The type of SA model use for segmentation.",
334
  )
335
 
336
  parser.add_argument(
337
+ "--tag2text-type",
338
  type=str,
339
  default=default_tag2text,
340
  choices=tag2text_dict.keys(),
341
+ help="The type of Tag2Text model use for tags and caption generation.",
342
  )
343
 
344
  parser.add_argument(
345
+ "--dino-type",
346
  type=str,
347
  default=default_dino,
348
  choices=dino_dict.keys(),
349
+ help="The type of Grounding Dino model use for promptable object detection.",
350
  )
351
 
352
  parser.add_argument(
 
373
  "--iou-threshold", type=float, default=0.5, help="iou threshold"
374
  )
375
 
376
+ parser.add_argument(
377
+ "--no-save-ann",
378
+ action="store_true",
379
+ default=False,
380
+ help="If False, save original image with blended masks and detection boxes.",
381
+ )
382
  parser.add_argument(
383
  "--save-mask",
384
  action="store_true",