Spaces:
Running
Running
dragonSwing
committed on
Commit
·
637f41e
1
Parent(s):
699405e
update script
Browse files
- annotate_anything.py +22 -16
annotate_anything.py
CHANGED
@@ -33,8 +33,9 @@ def process(
|
|
33 |
box_threshold,
|
34 |
text_threshold,
|
35 |
iou_threshold,
|
36 |
-
device,
|
37 |
output_dir=None,
|
|
|
38 |
save_mask=False,
|
39 |
):
|
40 |
detections = None
|
@@ -84,7 +85,7 @@ def process(
|
|
84 |
)
|
85 |
|
86 |
# Save detection image
|
87 |
-
if output_dir:
|
88 |
# Draw boxes
|
89 |
box_annotator = sv.BoxAnnotator()
|
90 |
labels = [
|
@@ -123,7 +124,7 @@ def process(
|
|
123 |
)
|
124 |
|
125 |
# Save annotated image
|
126 |
-
if output_dir:
|
127 |
mask_annotator = sv.MaskAnnotator()
|
128 |
mask_image, res = show_anns_sv(detections)
|
129 |
annotated_image = mask_annotator.annotate(image, detections=detections)
|
@@ -197,12 +198,13 @@ def main(args: argparse.Namespace) -> None:
|
|
197 |
box_threshold = args.box_threshold
|
198 |
text_threshold = args.text_threshold
|
199 |
iou_threshold = args.iou_threshold
|
|
|
200 |
save_mask = args.save_mask
|
201 |
|
202 |
# load model
|
203 |
if task in ["auto", "detection"] and prompt == "":
|
204 |
print("Loading Tag2Text model...")
|
205 |
-
tag2text_type = args.
|
206 |
tag2text_checkpoint = os.path.join(
|
207 |
abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
|
208 |
)
|
@@ -225,7 +227,7 @@ def main(args: argparse.Namespace) -> None:
|
|
225 |
|
226 |
if task in ["auto", "detection"] or prompt != "":
|
227 |
print("Loading Grounding Dino model...")
|
228 |
-
dino_type = args.
|
229 |
dino_checkpoint = os.path.join(
|
230 |
abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
|
231 |
)
|
@@ -253,7 +255,7 @@ def main(args: argparse.Namespace) -> None:
|
|
253 |
|
254 |
if task in ["auto", "segment"]:
|
255 |
print("Loading SAM...")
|
256 |
-
sam_type = args.
|
257 |
sam_checkpoint = os.path.join(
|
258 |
abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
|
259 |
)
|
@@ -292,6 +294,7 @@ def main(args: argparse.Namespace) -> None:
|
|
292 |
iou_threshold=iou_threshold,
|
293 |
device=device,
|
294 |
output_dir=args.output,
|
|
|
295 |
save_mask=save_mask,
|
296 |
)
|
297 |
|
@@ -319,34 +322,31 @@ if __name__ == "__main__":
|
|
319 |
"-o",
|
320 |
type=str,
|
321 |
required=True,
|
322 |
-
help=
|
323 |
-
"Path to the directory where masks will be output. Output will be either a folder "
|
324 |
-
"of PNGs per image or a single json with COCO-style masks."
|
325 |
-
),
|
326 |
)
|
327 |
|
328 |
parser.add_argument(
|
329 |
-
"--sam",
|
330 |
type=str,
|
331 |
default=default_sam,
|
332 |
choices=sam_dict.keys(),
|
333 |
-
help="The type of SA model
|
334 |
)
|
335 |
|
336 |
parser.add_argument(
|
337 |
-
"--tag2text",
|
338 |
type=str,
|
339 |
default=default_tag2text,
|
340 |
choices=tag2text_dict.keys(),
|
341 |
-
help="The
|
342 |
)
|
343 |
|
344 |
parser.add_argument(
|
345 |
-
"--dino",
|
346 |
type=str,
|
347 |
default=default_dino,
|
348 |
choices=dino_dict.keys(),
|
349 |
-
help="The
|
350 |
)
|
351 |
|
352 |
parser.add_argument(
|
@@ -373,6 +373,12 @@ if __name__ == "__main__":
|
|
373 |
"--iou-threshold", type=float, default=0.5, help="iou threshold"
|
374 |
)
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
parser.add_argument(
|
377 |
"--save-mask",
|
378 |
action="store_true",
|
|
|
33 |
box_threshold,
|
34 |
text_threshold,
|
35 |
iou_threshold,
|
36 |
+
device="cuda",
|
37 |
output_dir=None,
|
38 |
+
save_ann=True,
|
39 |
save_mask=False,
|
40 |
):
|
41 |
detections = None
|
|
|
85 |
)
|
86 |
|
87 |
# Save detection image
|
88 |
+
if output_dir and save_ann:
|
89 |
# Draw boxes
|
90 |
box_annotator = sv.BoxAnnotator()
|
91 |
labels = [
|
|
|
124 |
)
|
125 |
|
126 |
# Save annotated image
|
127 |
+
if output_dir and save_ann:
|
128 |
mask_annotator = sv.MaskAnnotator()
|
129 |
mask_image, res = show_anns_sv(detections)
|
130 |
annotated_image = mask_annotator.annotate(image, detections=detections)
|
|
|
198 |
box_threshold = args.box_threshold
|
199 |
text_threshold = args.text_threshold
|
200 |
iou_threshold = args.iou_threshold
|
201 |
+
save_ann = not args.no_save_ann
|
202 |
save_mask = args.save_mask
|
203 |
|
204 |
# load model
|
205 |
if task in ["auto", "detection"] and prompt == "":
|
206 |
print("Loading Tag2Text model...")
|
207 |
+
tag2text_type = args.tag2text_type
|
208 |
tag2text_checkpoint = os.path.join(
|
209 |
abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
|
210 |
)
|
|
|
227 |
|
228 |
if task in ["auto", "detection"] or prompt != "":
|
229 |
print("Loading Grounding Dino model...")
|
230 |
+
dino_type = args.dino_type
|
231 |
dino_checkpoint = os.path.join(
|
232 |
abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
|
233 |
)
|
|
|
255 |
|
256 |
if task in ["auto", "segment"]:
|
257 |
print("Loading SAM...")
|
258 |
+
sam_type = args.sam_type
|
259 |
sam_checkpoint = os.path.join(
|
260 |
abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
|
261 |
)
|
|
|
294 |
iou_threshold=iou_threshold,
|
295 |
device=device,
|
296 |
output_dir=args.output,
|
297 |
+
save_ann=save_ann,
|
298 |
save_mask=save_mask,
|
299 |
)
|
300 |
|
|
|
322 |
"-o",
|
323 |
type=str,
|
324 |
required=True,
|
325 |
+
help="Path to the directory where masks will be output.",
|
|
|
|
|
|
|
326 |
)
|
327 |
|
328 |
parser.add_argument(
|
329 |
+
"--sam-type",
|
330 |
type=str,
|
331 |
default=default_sam,
|
332 |
choices=sam_dict.keys(),
|
333 |
+
help="The type of SA model use for segmentation.",
|
334 |
)
|
335 |
|
336 |
parser.add_argument(
|
337 |
+
"--tag2text-type",
|
338 |
type=str,
|
339 |
default=default_tag2text,
|
340 |
choices=tag2text_dict.keys(),
|
341 |
+
help="The type of Tag2Text model use for tags and caption generation.",
|
342 |
)
|
343 |
|
344 |
parser.add_argument(
|
345 |
+
"--dino-type",
|
346 |
type=str,
|
347 |
default=default_dino,
|
348 |
choices=dino_dict.keys(),
|
349 |
+
help="The type of Grounding Dino model use for promptable object detection.",
|
350 |
)
|
351 |
|
352 |
parser.add_argument(
|
|
|
373 |
"--iou-threshold", type=float, default=0.5, help="iou threshold"
|
374 |
)
|
375 |
|
376 |
+
parser.add_argument(
|
377 |
+
"--no-save-ann",
|
378 |
+
action="store_true",
|
379 |
+
default=False,
|
380 |
+
help="If False, save original image with blended masks and detection boxes.",
|
381 |
+
)
|
382 |
parser.add_argument(
|
383 |
"--save-mask",
|
384 |
action="store_true",
|