from typing import Optional

import fiftyone as fo
import torchvision
import typer
from groundingdino.util.inference import load_model, load_image, predict
from tqdm import tqdm

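
# Run Grounding DINO (text-prompted, zero-shot detection) over a directory of
# images and store the predictions in a FiftyOne dataset.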
def main(
    image_directory: str = 'test_grounding_dino',
    text_prompt: str = 'bus, car',
    box_threshold: float = 0.15,
    text_threshold: float = 0.10,
    export_dataset: bool = False,
    view_dataset: bool = False,
    export_annotated_images: bool = True,
    weights_path: str = "groundingdino_swint_ogc.pth",
    config_path: str = "../../GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    subsample: Optional[int] = None,
):
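    """Detect objects matching `text_prompt` in each image and store the
    predictions in a FiftyOne dataset, which can then be viewed in the App,
    exported as COCO, and/or rendered with the boxes drawn on the images.
    """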
    # Load Grounding DINO from the local config and checkpoint
    model = load_model(config_path, weights_path)

    # Create a FiftyOne dataset from every image in the directory
    dataset = fo.Dataset.from_images_dir(image_directory)

    # Optionally run on a random subset of the images
    if subsample is not None and subsample < len(dataset):
        dataset = dataset.take(subsample).clone()

    # Run the model on each image and convert its outputs to FiftyOne labels
    for sample in tqdm(dataset):
        image_source, image = load_image(sample.filepath)

        # predict() returns normalized [cx, cy, w, h] boxes, per-box
        # confidence scores, and the matched text phrases
        boxes, logits, phrases = predict(
            model=model,
            image=image,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )

        detections = []
        for box, logit, phrase in zip(boxes, logits, phrases):
            # FiftyOne expects relative [x, y, w, h] with a top-left origin
            rel_box = torchvision.ops.box_convert(box, 'cxcywh', 'xywh')
            detections.append(
                fo.Detection(
                    label=phrase,
                    bounding_box=rel_box.tolist(),
                    confidence=float(logit),
                )
            )

        sample["detections"] = fo.Detections(detections=detections)
        sample.save()

    # Optionally launch the FiftyOne App to browse the predictions
    if view_dataset:
        session = fo.launch_app(dataset)
        session.wait()

    # Optionally export the images and labels in COCO detection format
    if export_dataset:
        dataset.export(
            'coco_dataset',
            dataset_type=fo.types.COCODetectionDataset,
        )

    # Optionally write copies of the images with the boxes drawn on them
    if export_annotated_images:
        dataset.draw_labels(
            'images_with_bounding_boxes',
            label_fields=['detections'],
        )


if __name__ == '__main__':
    typer.run(main)
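
# Example invocation (illustrative; the script filename is assumed, and typer
# exposes each parameter above as a --kebab-case option):
#   python grounding_dino_fiftyone.py --image-directory ./images \
#       --text-prompt "bus, car" --subsample 50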