Spaces:
Running
Running
dragonSwing
committed on
Commit
·
eeef127
1
Parent(s):
156bb47
Update script arguments
Browse files
app.py
CHANGED
@@ -1,96 +1,47 @@
|
|
|
|
1 |
import json
|
2 |
import os
|
3 |
-
import subprocess
|
4 |
import sys
|
5 |
import tempfile
|
6 |
|
7 |
-
import gradio as gr
|
8 |
import numpy as np
|
9 |
import supervision as sv
|
10 |
-
import
|
|
|
11 |
from PIL import Image
|
12 |
-
from segment_anything import
|
13 |
from segment_anything import SamAutomaticMaskGenerator
|
14 |
from segment_anything import SamPredictor
|
15 |
-
from supervision.detection.utils import mask_to_polygons
|
16 |
from supervision.detection.utils import xywh_to_xyxy
|
17 |
-
|
18 |
-
if os.environ.get("IS_MY_DEBUG") is None:
|
19 |
-
result = subprocess.run(["pip", "install", "-e", "GroundingDINO"], check=True)
|
20 |
-
print(f"pip install GroundingDINO = {result}")
|
21 |
|
22 |
sys.path.append("tag2text")
|
23 |
-
sys.path.append("GroundingDINO")
|
24 |
|
25 |
from tag2text.models import tag2text
|
26 |
-
from groundingdino.util.inference import Model as DinoModel
|
27 |
from config import *
|
28 |
-
from utils import
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
abs_weight_dir, dino_dict[default_dino]["checkpoint_file"]
|
47 |
-
)
|
48 |
-
dino_config_file = os.path.join(abs_weight_dir, dino_dict[default_dino]["config_file"])
|
49 |
-
if not os.path.exists(dino_checkpoint):
|
50 |
-
dino_repo_id = dino_dict[default_dino]["repo_id"]
|
51 |
-
download_file_hf(
|
52 |
-
repo_id=dino_repo_id,
|
53 |
-
filename=dino_dict[default_dino]["config_file"],
|
54 |
-
cache_dir=weight_dir,
|
55 |
-
)
|
56 |
-
download_file_hf(
|
57 |
-
repo_id=dino_repo_id,
|
58 |
-
filename=dino_dict[default_dino]["checkpoint_file"],
|
59 |
-
cache_dir=weight_dir,
|
60 |
-
)
|
61 |
-
|
62 |
-
# load model
|
63 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
64 |
-
tag2text_model = tag2text.tag2text_caption(
|
65 |
-
pretrained=tag2text_checkpoint,
|
66 |
-
image_size=384,
|
67 |
-
vit="swin_b",
|
68 |
-
delete_tag_index=delete_tag_index,
|
69 |
-
)
|
70 |
-
# threshold for tagging
|
71 |
-
# we reduce the threshold to obtain more tags
|
72 |
-
tag2text_model.threshold = 0.64
|
73 |
-
tag2text_model.to(device)
|
74 |
-
tag2text_model.eval()
|
75 |
-
|
76 |
-
|
77 |
-
sam = build_sam(checkpoint=sam_checkpoint)
|
78 |
-
sam.to(device=device)
|
79 |
-
sam_predictor = SamPredictor(sam)
|
80 |
-
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
81 |
-
|
82 |
-
grounding_dino_model = DinoModel(
|
83 |
-
model_config_path=dino_config_file,
|
84 |
-
model_checkpoint_path=dino_checkpoint,
|
85 |
-
device=device,
|
86 |
-
)
|
87 |
-
|
88 |
-
|
89 |
-
def process(image_path, task, prompt, box_threshold, text_threshold, iou_threshold):
|
90 |
-
global tag2text_model, sam_predictor, sam_automask_generator, grounding_dino_model, device
|
91 |
-
output_gallery = []
|
92 |
detections = None
|
93 |
-
metadata = {"image": {}, "annotations": []}
|
|
|
|
|
|
|
94 |
|
95 |
try:
|
96 |
# Load image
|
@@ -100,17 +51,18 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
100 |
|
101 |
# Extract image metadata
|
102 |
filename = os.path.basename(image_path)
|
|
|
103 |
h, w = image.shape[:2]
|
104 |
metadata["image"]["file_name"] = filename
|
105 |
metadata["image"]["width"] = w
|
106 |
metadata["image"]["height"] = h
|
107 |
|
108 |
# Generate tags
|
109 |
-
if task in ["auto", "
|
110 |
tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
|
111 |
prompt = " . ".join(tags)
|
112 |
-
print(f"Caption: {caption}")
|
113 |
-
print(f"Tags: {tags}")
|
114 |
|
115 |
# ToDo: Extract metadata
|
116 |
metadata["image"]["caption"] = caption
|
@@ -118,7 +70,6 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
118 |
|
119 |
if prompt:
|
120 |
metadata["prompt"] = prompt
|
121 |
-
print(f"Prompt: {prompt}")
|
122 |
|
123 |
# Detect boxes
|
124 |
if prompt != "":
|
@@ -131,18 +82,21 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
131 |
iou_threshold=iou_threshold,
|
132 |
post_process=True,
|
133 |
)
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
|
|
|
|
146 |
|
147 |
# Segmentation
|
148 |
if task in ["auto", "segment"]:
|
@@ -167,18 +121,27 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
167 |
detections = sv.Detections(
|
168 |
xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
|
169 |
)
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
# ToDo: Extract metadata
|
184 |
if detections:
|
@@ -201,86 +164,222 @@ def process(image_path, task, prompt, box_threshold, text_threshold, iou_thresho
|
|
201 |
metadata["annotations"].append(annotation)
|
202 |
i += 1
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
except Exception as error:
|
211 |
-
raise
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
|
|
|
|
|
|
222 |
)
|
223 |
-
|
224 |
-
|
225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
)
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
text_threshold = gr.Slider(
|
237 |
-
minimum=0,
|
238 |
-
maximum=1,
|
239 |
-
value=0.25,
|
240 |
-
step=0.05,
|
241 |
-
label="Text threshold",
|
242 |
-
info="Number of history images used to find out duplicate image",
|
243 |
-
)
|
244 |
-
iou_threshold = gr.Slider(
|
245 |
-
minimum=0,
|
246 |
-
maximum=1,
|
247 |
-
value=0.5,
|
248 |
-
step=0.05,
|
249 |
-
label="IOU threshold",
|
250 |
-
info="Minimum similarity threshold (in percent) to consider 2 images to be similar",
|
251 |
-
)
|
252 |
-
run_button = gr.Button(label="Run")
|
253 |
-
|
254 |
-
with gr.Column(scale=2):
|
255 |
-
gallery = gr.Gallery(
|
256 |
-
label="Generated images", show_label=False, elem_id="gallery"
|
257 |
-
).style(preview=True, grid=2, object_fit="scale-down")
|
258 |
-
meta_file = gr.File(label="Metadata file")
|
259 |
-
|
260 |
-
with gr.Row(elem_classes=["container"]):
|
261 |
-
gr.Examples(
|
262 |
-
[
|
263 |
-
["examples/dog.png", "auto", ""],
|
264 |
-
["examples/eiffel.png", "auto", ""],
|
265 |
-
["examples/eiffel.png", "segment", ""],
|
266 |
-
["examples/girl.png", "auto", "girl . face"],
|
267 |
-
["examples/horse.png", "detect", "horse"],
|
268 |
-
["examples/horses.jpg", "auto", "horse"],
|
269 |
-
["examples/traffic.jpg", "auto", ""],
|
270 |
-
],
|
271 |
-
[input_image, task, text_prompt],
|
272 |
)
|
273 |
-
run_button.click(
|
274 |
-
fn=process,
|
275 |
-
inputs=[
|
276 |
-
input_image,
|
277 |
-
task,
|
278 |
-
text_prompt,
|
279 |
-
box_threshold,
|
280 |
-
text_threshold,
|
281 |
-
iou_threshold,
|
282 |
-
],
|
283 |
-
outputs=[gallery, meta_file],
|
284 |
)
|
285 |
|
286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
import json
|
3 |
import os
|
|
|
4 |
import sys
|
5 |
import tempfile
|
6 |
|
|
|
7 |
import numpy as np
|
8 |
import supervision as sv
|
9 |
+
from groundingdino.util.inference import Model as DinoModel
|
10 |
+
from imutils import paths
|
11 |
from PIL import Image
|
12 |
+
from segment_anything import sam_model_registry
|
13 |
from segment_anything import SamAutomaticMaskGenerator
|
14 |
from segment_anything import SamPredictor
|
|
|
15 |
from supervision.detection.utils import xywh_to_xyxy
|
16 |
+
from tqdm import tqdm
|
|
|
|
|
|
|
17 |
|
18 |
sys.path.append("tag2text")
|
|
|
19 |
|
20 |
from tag2text.models import tag2text
|
|
|
21 |
from config import *
|
22 |
+
from utils import detect, download_file_hf, segment, generate_tags, show_anns_sv
|
23 |
+
|
24 |
+
|
25 |
+
def process(
|
26 |
+
tag2text_model,
|
27 |
+
grounding_dino_model,
|
28 |
+
sam_predictor,
|
29 |
+
sam_automask_generator,
|
30 |
+
image_path,
|
31 |
+
task,
|
32 |
+
prompt,
|
33 |
+
box_threshold,
|
34 |
+
text_threshold,
|
35 |
+
iou_threshold,
|
36 |
+
device,
|
37 |
+
output_dir=None,
|
38 |
+
save_mask=False,
|
39 |
+
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
detections = None
|
41 |
+
metadata = {"image": {}, "annotations": [], "assets": {}}
|
42 |
+
|
43 |
+
if save_mask:
|
44 |
+
metadata["assets"]["intermediate_mask"] = []
|
45 |
|
46 |
try:
|
47 |
# Load image
|
|
|
51 |
|
52 |
# Extract image metadata
|
53 |
filename = os.path.basename(image_path)
|
54 |
+
basename = os.path.splitext(filename)[0]
|
55 |
h, w = image.shape[:2]
|
56 |
metadata["image"]["file_name"] = filename
|
57 |
metadata["image"]["width"] = w
|
58 |
metadata["image"]["height"] = h
|
59 |
|
60 |
# Generate tags
|
61 |
+
if task in ["auto", "detection"] and prompt == "":
|
62 |
tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
|
63 |
prompt = " . ".join(tags)
|
64 |
+
# print(f"Caption: {caption}")
|
65 |
+
# print(f"Tags: {tags}")
|
66 |
|
67 |
# ToDo: Extract metadata
|
68 |
metadata["image"]["caption"] = caption
|
|
|
70 |
|
71 |
if prompt:
|
72 |
metadata["prompt"] = prompt
|
|
|
73 |
|
74 |
# Detect boxes
|
75 |
if prompt != "":
|
|
|
82 |
iou_threshold=iou_threshold,
|
83 |
post_process=True,
|
84 |
)
|
85 |
+
|
86 |
+
# Save detection image
|
87 |
+
if output_dir:
|
88 |
+
# Draw boxes
|
89 |
+
box_annotator = sv.BoxAnnotator()
|
90 |
+
labels = [
|
91 |
+
f"{phrases[i]} {detections.confidence[i]:0.2f}"
|
92 |
+
for i in range(len(phrases))
|
93 |
+
]
|
94 |
+
box_image = box_annotator.annotate(
|
95 |
+
scene=image, detections=detections, labels=labels
|
96 |
+
)
|
97 |
+
box_image_path = os.path.join(output_dir, basename + "_detect.png")
|
98 |
+
metadata["assets"]["detection"] = box_image_path
|
99 |
+
Image.fromarray(box_image).save(box_image_path)
|
100 |
|
101 |
# Segmentation
|
102 |
if task in ["auto", "segment"]:
|
|
|
121 |
detections = sv.Detections(
|
122 |
xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
|
123 |
)
|
124 |
+
|
125 |
+
# Save annotated image
|
126 |
+
if output_dir:
|
127 |
+
mask_annotator = sv.MaskAnnotator()
|
128 |
+
mask_image, res = show_anns_sv(detections)
|
129 |
+
annotated_image = mask_annotator.annotate(image, detections=detections)
|
130 |
+
|
131 |
+
mask_image_path = os.path.join(output_dir, basename + "_mask.png")
|
132 |
+
metadata["assets"]["mask"] = mask_image_path
|
133 |
+
Image.fromarray(mask_image).save(mask_image_path)
|
134 |
+
|
135 |
+
# Save annotation encoding from https://github.com/LUSSeg/ImageNet-S
|
136 |
+
mask_enc_path = os.path.join(output_dir, basename + "_mask_enc.npy")
|
137 |
+
np.save(mask_enc_path, res)
|
138 |
+
metadata["assets"]["mask_enc"] = mask_enc_path
|
139 |
+
|
140 |
+
annotated_image_path = os.path.join(
|
141 |
+
output_dir, basename + "_annotate.png"
|
142 |
+
)
|
143 |
+
metadata["assets"]["annotate"] = annotated_image_path
|
144 |
+
Image.fromarray(annotated_image).save(annotated_image_path)
|
145 |
|
146 |
# ToDo: Extract metadata
|
147 |
if detections:
|
|
|
164 |
metadata["annotations"].append(annotation)
|
165 |
i += 1
|
166 |
|
167 |
+
if output_dir and save_mask:
|
168 |
+
mask_image_path = os.path.join(
|
169 |
+
output_dir, f"{basename}_mask_{id}.png"
|
170 |
+
)
|
171 |
+
metadata["assets"]["intermediate_mask"].append(mask_image_path)
|
172 |
+
Image.fromarray(mask * 255).save(mask_image_path)
|
173 |
+
|
174 |
+
if output_dir:
|
175 |
+
meta_file_path = os.path.join(output_dir, basename + "_meta.json")
|
176 |
+
with open(meta_file_path, "w") as fp:
|
177 |
+
json.dump(metadata, fp)
|
178 |
+
else:
|
179 |
+
meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
|
180 |
+
meta_file_path = meta_file.name
|
181 |
+
|
182 |
+
return meta_file_path
|
183 |
except Exception as error:
|
184 |
+
raise ValueError(f"global exception: {error}")
|
185 |
+
|
186 |
+
|
187 |
+
def _download_with_wget(url: str, destination: str) -> None:
    """Download ``url`` to ``destination`` with wget, raising if the transfer fails.

    The command is passed as an argument list (no shell), so characters in the
    URL or path are never interpreted by a shell.

    Raises:
        subprocess.CalledProcessError: wget exited with a non-zero status.
        OSError: wget could not be started (e.g. it is not installed).
    """
    import subprocess

    try:
        subprocess.run(["wget", url, "-O", destination], check=True)
    except (subprocess.CalledProcessError, OSError):
        # A failed transfer may leave an empty or partial file behind; remove it
        # so the "checkpoint already exists" test does not pass on the next run.
        if os.path.exists(destination):
            os.remove(destination)
        raise


def main(args: argparse.Namespace) -> None:
    """Load the models required by ``args.task`` and annotate every input image.

    Only the models that the selected task/prompt combination needs are loaded;
    missing weights are downloaded into ``abs_weight_dir`` first. Each image is
    then handed to ``process`` with ``args.output`` as the output directory.

    Args:
        args: Parsed command-line arguments. Reads ``device``, ``prompt``,
            ``task``, ``box_threshold``, ``text_threshold``, ``iou_threshold``,
            ``save_mask``, ``tag2text_type``, ``dino_type``, ``sam_type``,
            ``input`` and ``output``.

    Raises:
        ValueError: ``args.input`` does not exist.
        subprocess.CalledProcessError: a wget weight download failed.
    """
    device = args.device
    prompt = args.prompt
    task = args.task

    tag2text_model = None
    grounding_dino_model = None
    sam_predictor = None
    sam_automask_generator = None

    box_threshold = args.box_threshold
    text_threshold = args.text_threshold
    iou_threshold = args.iou_threshold
    save_mask = args.save_mask

    # The command line exposes the detection task as "detect"; "detection" is
    # still accepted for callers that pass the older spelling.
    wants_detection = task in ("auto", "detect", "detection")

    # Tag2Text is only needed to generate a prompt when none was supplied.
    if wants_detection and prompt == "":
        print("Loading Tag2Text model...")
        tag2text_type = args.tag2text_type
        tag2text_checkpoint = os.path.join(
            abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
        )
        if not os.path.exists(tag2text_checkpoint):
            print(f"Downloading weights for Tag2Text {tag2text_type} model")
            _download_with_wget(
                tag2text_dict[tag2text_type]["checkpoint_url"], tag2text_checkpoint
            )
        tag2text_model = tag2text.tag2text_caption(
            pretrained=tag2text_checkpoint,
            image_size=384,
            vit="swin_b",
            delete_tag_index=delete_tag_index,
        )
        # threshold for tagging
        # we reduce the threshold to obtain more tags
        tag2text_model.threshold = 0.64
        tag2text_model.to(device)
        tag2text_model.eval()

    # Grounding DINO is needed whenever boxes are detected from a prompt.
    if wants_detection or prompt != "":
        print("Loading Grounding Dino model...")
        dino_type = args.dino_type
        dino_checkpoint = os.path.join(
            abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
        )
        dino_config_file = os.path.join(
            abs_weight_dir, dino_dict[dino_type]["config_file"]
        )
        # The model needs both its config and its checkpoint, so fetch whichever
        # of the two is missing (previously the checkpoint was requested twice
        # and the config file was never downloaded).
        missing_dino_files = [
            key
            for key, local_path in (
                ("config_file", dino_config_file),
                ("checkpoint_file", dino_checkpoint),
            )
            if not os.path.exists(local_path)
        ]
        if missing_dino_files:
            print(f"Downloading weights for Grounding Dino {dino_type} model")
            dino_repo_id = dino_dict[dino_type]["repo_id"]
            for key in missing_dino_files:
                download_file_hf(
                    repo_id=dino_repo_id,
                    filename=dino_dict[dino_type][key],
                    cache_dir=weight_dir,
                )
        grounding_dino_model = DinoModel(
            model_config_path=dino_config_file,
            model_checkpoint_path=dino_checkpoint,
            device=device,
        )

    if task in ("auto", "segment"):
        print("Loading SAM...")
        sam_type = args.sam_type
        sam_checkpoint = os.path.join(
            abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
        )
        if not os.path.exists(sam_checkpoint):
            print(f"Downloading weights for SAM {sam_type}")
            _download_with_wget(sam_dict[sam_type]["checkpoint_url"], sam_checkpoint)
        sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)
        sam_automask_generator = SamAutomaticMaskGenerator(sam)

    if not os.path.exists(args.input):
        raise ValueError("The input directory doesn't exist!")
    elif not os.path.isdir(args.input):
        # A single image file was given instead of a directory.
        image_paths = [args.input]
    else:
        # Materialise the listing so tqdm knows the total and the processing
        # order is deterministic.
        image_paths = sorted(paths.list_images(args.input))

    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            pbar.set_postfix_str(f"Processing {image_path}")
            process(
                tag2text_model=tag2text_model,
                grounding_dino_model=grounding_dino_model,
                sam_predictor=sam_predictor,
                sam_automask_generator=sam_automask_generator,
                image_path=image_path,
                task=task,
                prompt=prompt,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                iou_threshold=iou_threshold,
                device=device,
                output_dir=args.output,
                save_mask=save_mask,
            )
|
297 |
+
|
298 |
+
|
299 |
+
if __name__ == "__main__":
    # Script entry point: make sure the weights directory exists, parse the
    # command line and hand the parsed arguments to main().
    os.makedirs(abs_weight_dir, exist_ok=True)

    parser = argparse.ArgumentParser(
        description=(
            "Runs automatic detection and mask generation on an input image or directory of images"
        )
    )

    # Input / output locations.
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )

    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help=(
            "Path to the directory where masks will be output."
        ),
    )

    # Model variants; the valid names are the keys of the dictionaries that
    # come from the config module.
    parser.add_argument(
        "--sam-type",
        type=str,
        default=default_sam,
        choices=sam_dict.keys(),
        help="The type of Segment Anything (SAM) model to use for segmentation.",
    )

    parser.add_argument(
        "--tag2text-type",
        type=str,
        default=default_tag2text,
        choices=tag2text_dict.keys(),
        help="The type of Tag2Text model to use for tag and caption generation.",
    )

    parser.add_argument(
        "--dino-type",
        type=str,
        default=default_dino,
        choices=dino_dict.keys(),
        help="The type of Grounding Dino model to use for promptable object detection.",
    )

    # What to run and, optionally, what to look for.
    parser.add_argument(
        "--task",
        help="Task to run",
        default="auto",
        choices=["auto", "detect", "segment"],
        type=str,
    )
    parser.add_argument(
        "--prompt",
        help="Detection prompt",
        default="",
        type=str,
    )

    # Thresholds forwarded unchanged to process().
    parser.add_argument(
        "--box-threshold", type=float, default=0.25, help="box threshold"
    )
    parser.add_argument(
        "--text-threshold", type=float, default=0.2, help="text threshold"
    )
    parser.add_argument(
        "--iou-threshold", type=float, default=0.5, help="iou threshold"
    )

    parser.add_argument(
        "--save-mask",
        action="store_true",
        default=False,
        help="If set, save all intermediate masks.",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run generation on."
    )
    args = parser.parse_args()
    main(args)
|