import argparse
import time

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

from groundingdino.models import build_model
import groundingdino.datasets.transforms as T
from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import clean_state_dict, collate_fn
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.vl_utils import build_captions_and_token_span, create_positive_map_from_span
from groundingdino.datasets.cocogrounding_eval import CocoGroundingEvaluator


def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    model.eval()
    return model


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms

    def __getitem__(self, idx):
        img, target = super().__getitem__(idx)  # target: list of annotation dicts
        w, h = img.size

        boxes = [obj["bbox"] for obj in target]
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]  # xywh -> xyxy
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        # filter out invalid (degenerate) boxes
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]

        target_new = {}
        image_id = self.ids[idx]
        target_new["image_id"] = image_id
        target_new["boxes"] = boxes
        target_new["orig_size"] = torch.as_tensor([int(h), int(w)])

        if self._transforms is not None:
            img, target = self._transforms(img, target_new)
        return img, target


class PostProcessCocoGrounding(nn.Module):
    """Convert the model's output into the format expected by the COCO API."""

    def __init__(self, num_select=300, coco_api=None, tokenlizer=None) -> None:
        super().__init__()
        self.num_select = num_select

        assert coco_api is not None
        category_dict = coco_api.dataset['categories']
        cat_list = [item['name'] for item in category_dict]
        captions, cat2tokenspan = build_captions_and_token_span(cat_list, True)
        tokenspanlist = [cat2tokenspan[cat] for cat in cat_list]
        positive_map = create_positive_map_from_span(
            tokenlizer(captions), tokenspanlist)  # (80, 256), normalized
        # COCO's 80 contiguous label indices -> the 91 original category ids
        id_map = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10,
                  10: 11, 11: 13, 12: 14, 13: 15, 14: 16, 15: 17, 16: 18, 17: 19,
                  18: 20, 19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 27, 25: 28,
                  26: 31, 27: 32, 28: 33, 29: 34, 30: 35, 31: 36, 32: 37, 33: 38,
                  34: 39, 35: 40, 36: 41, 37: 42, 38: 43, 39: 44, 40: 46, 41: 47,
                  42: 48, 43: 49, 44: 50, 45: 51, 46: 52, 47: 53, 48: 54, 49: 55,
                  50: 56, 51: 57, 52: 58, 53: 59, 54: 60, 55: 61, 56: 62, 57: 63,
                  58: 64, 59: 65, 60: 67, 61: 70, 62: 72, 63: 73, 64: 74, 65: 75,
                  66: 76, 67: 77, 68: 78, 69: 79, 70: 80, 71: 81, 72: 82, 73: 84,
                  74: 85, 75: 86, 76: 87, 77: 88, 78: 89, 79: 90}

        # build a mapping from label_id to pos_map
        new_pos_map = torch.zeros((91, 256))
        for k, v in id_map.items():
            new_pos_map[v] = positive_map[k]
        self.positive_map = new_pos_map

    @torch.no_grad()
    def forward(self, outputs, target_sizes, not_to_xyxy=False):
        """Perform the computation.

        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size, 2] containing the size
                of each image in the batch. For evaluation, this must be the
                original image size (before any data augmentation). For
                visualization, this should be the image size after data
                augmentation, but before padding.
        """
        num_select = self.num_select
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        # map token-level probabilities to category-level probabilities
        prob_to_token = out_logits.sigmoid()  # (bs, num_queries, 256)
        pos_maps = self.positive_map.to(prob_to_token.device)
        # (bs, num_queries, 256) @ (91, 256).T -> (bs, num_queries, 91)
        prob_to_label = prob_to_token @ pos_maps.T

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        # top-k over the flattened (query, class) pairs
        prob = prob_to_label
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1), num_select, dim=1)
        scores = topk_values
        topk_boxes = topk_indexes // prob.shape[2]  # query index
        labels = topk_indexes % prob.shape[2]       # category id

        if not_to_xyxy:
            boxes = out_bbox
        else:
            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)

        boxes = torch.gather(
            boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [{'scores': s, 'labels': l, 'boxes': b}
                   for s, l, b in zip(scores, labels, boxes)]
        return results


def main(args):
    # config
    cfg = SLConfig.fromfile(args.config_file)

    # build model
    model = load_model(args.config_file, args.checkpoint_path)
    model = model.to(args.device)
    model = model.eval()

    # build dataloader
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    dataset = CocoDetection(
        args.image_dir, args.anno_path, transforms=transform)
    data_loader = DataLoader(
        dataset, batch_size=1, shuffle=False,
        num_workers=args.num_workers, collate_fn=collate_fn)

    # build post processor
    tokenlizer = get_tokenlizer.get_tokenlizer(cfg.text_encoder_type)
    postprocessor = PostProcessCocoGrounding(
        num_select=args.num_select, coco_api=dataset.coco, tokenlizer=tokenlizer)

    # build evaluator
    evaluator = CocoGroundingEvaluator(
        dataset.coco, iou_types=("bbox",), useCats=True)

    # build captions: all category names joined into one text prompt
    category_dict = dataset.coco.dataset['categories']
    cat_list = [item['name'] for item in category_dict]
    caption = " . ".join(cat_list) + ' .'
print("Input text prompt:", caption) # run inference start = time.time() for i, (images, targets) in enumerate(data_loader): # get images and captions images = images.tensors.to(args.device) bs = images.shape[0] input_captions = [caption] * bs # feed to the model outputs = model(images, captions=input_captions) orig_target_sizes = torch.stack( [t["orig_size"] for t in targets], dim=0).to(images.device) results = postprocessor(outputs, orig_target_sizes) cocogrounding_res = { target["image_id"]: output for target, output in zip(targets, results)} evaluator.update(cocogrounding_res) if (i+1) % 30 == 0: used_time = time.time() - start eta = len(data_loader) / (i+1e-5) * used_time - used_time print( f"processed {i}/{len(data_loader)} images. time: {used_time:.2f}s, ETA: {eta:.2f}s") evaluator.synchronize_between_processes() evaluator.accumulate() evaluator.summarize() print("Final results:", evaluator.coco_eval["bbox"].stats.tolist()) if __name__ == "__main__": parser = argparse.ArgumentParser( "Grounding DINO eval on COCO", add_help=True) # load model parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file") parser.add_argument( "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file" ) parser.add_argument("--device", type=str, default="cuda", help="running device (default: cuda)") # post processing parser.add_argument("--num_select", type=int, default=300, help="number of topk to select") # coco info parser.add_argument("--anno_path", type=str, required=True, help="coco root") parser.add_argument("--image_dir", type=str, required=True, help="coco image dir") parser.add_argument("--num_workers", type=int, default=4, help="number of workers for dataloader") args = parser.parse_args() main(args)