import os
import re
import json
import glob
import torch
import datetime
import argparse
import torch.nn.functional as F
import numpy as np
import pycocotools.mask as mask_util


def create_image_info(image_id, file_name, image_size,
                      date_captured=datetime.datetime.utcnow().isoformat(' '),
                      license_id=1, coco_url="", flickr_url=""):
    """Return image_info in COCO style

    Args:
        image_id: the image ID
        file_name: the file name of each image
        image_size: image size in the format of (width, height)
        date_captured: the date this image info is created
        license_id: license of this image
        coco_url: url to COCO images if there is any
        flickr_url: url to flickr if there is any
    """
    image_info = {
        "id": image_id,
        "file_name": file_name,
        "width": image_size[0],
        "height": image_size[1],
        "date_captured": date_captured,
        "license": license_id,
        "coco_url": coco_url,
        "flickr_url": flickr_url
    }
    return image_info


def create_annotation_info(annotation_id, image_id, category_info, binary_mask,
                           image_size=None, bounding_box=None):
    """Return annotation info in COCO style

    Args:
        annotation_id: the annotation ID
        image_id: the image ID
        category_info: the information on categories
        binary_mask: a 2D binary numpy array where '1's represent the object
        image_size: image size in the format of (width, height)
        bounding_box: the bounding box for the detection task. If bounding_box
            is not provided, one is generated from the binary mask.
    """
    # Re-binarize the mask around its mid-point value.
    upper = np.max(binary_mask)
    lower = np.min(binary_mask)
    thresh = upper / 2.0
    binary_mask[binary_mask > thresh] = upper
    binary_mask[binary_mask <= thresh] = lower

    if image_size is not None:
        # NOTE: resize_binary_mask is not defined in this file; it must be
        # provided externally (e.g. a pycococreatortools-style helper) if
        # image_size is passed. This script always calls with image_size=None.
        binary_mask = resize_binary_mask(binary_mask.astype(np.uint8), image_size)

    binary_mask_encoded = mask_util.encode(np.asfortranarray(binary_mask.astype(np.uint8)))

    area = mask_util.area(binary_mask_encoded)
    if area < 1:
        return None

    if bounding_box is None:
        bounding_box = mask_util.toBbox(binary_mask_encoded)

    # Store the segmentation as a compressed RLE with ASCII counts so it is JSON-serializable.
    rle = mask_util.encode(np.array(binary_mask[..., None], order="F", dtype="uint8"))[0]
    rle['counts'] = rle['counts'].decode('ascii')
    segmentation = rle

    annotation_info = {
        "id": annotation_id,
        "image_id": image_id,
        "category_id": category_info["id"],
        "iscrowd": 0,
        "area": area.tolist(),
        "bbox": bounding_box.tolist(),
        "segmentation": segmentation,
        "width": binary_mask.shape[1],
        "height": binary_mask.shape[0],
    }

    return annotation_info


# necessary info used for coco style annotations
INFO = {
    "description": "ImageNet-1K: pseudo-masks with MaskCut",
    "url": "https://github.com/facebookresearch/CutLER",
    "version": "1.0",
    "year": 2023,
    "contributor": "Xudong Wang",
    "date_created": datetime.datetime.utcnow().isoformat(' ')
}

LICENSES = [
    {
        "id": 1,
        "name": "Apache License",
        "url": "https://github.com/facebookresearch/CutLER/blob/main/LICENSE"
    }
]
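
# A minimal sanity-check sketch (not part of the original pipeline): it encodes a
# toy binary mask through create_annotation_info above and verifies that the stored
# RLE decodes back to the same mask. The name _demo_rle_roundtrip is illustrative
# only and is never called by this script.
def _demo_rle_roundtrip():
    toy_mask = np.zeros((8, 8), dtype=np.uint8)
    toy_mask[2:6, 3:7] = 1  # a 4x4 foreground square
    ann = create_annotation_info(1, 1, {"id": 1}, toy_mask.copy())
    # area counts foreground pixels (16); bbox is [x, y, w, h] = [3, 2, 4, 4]
    print("area:", ann["area"], "bbox:", ann["bbox"])
    # the ASCII RLE stored in the annotation decodes back to the original mask
    rle = dict(ann["segmentation"])
    rle["counts"] = rle["counts"].encode("ascii")
    assert np.array_equal(mask_util.decode(rle), toy_mask)
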
# only one class, i.e. foreground
CATEGORIES = [
    {
        'id': 1,
        'name': 'fg',
        'supercategory': 'fg',
    },
]

# natural-sort helpers for file names (e.g. "img2" sorts before "img10")
convert = lambda text: int(text) if text.isdigit() else text.lower()
natural_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]

output = {
    "info": INFO,
    "licenses": LICENSES,
    "categories": CATEGORIES,
    "images": [],
    "annotations": []
}

category_info = {"is_crowd": 0, "id": 1}

if __name__ == "__main__":
    parser = argparse.ArgumentParser('Merge pytorch results file into json')
    parser.add_argument('--base-dir', type=str, default='annotations/',
                        help='Dir to the generated annotation .pt files with CWM')
    parser.add_argument('--save-path', type=str, default="coco_train_fixsize480_N3.json",
                        help='Path to save the merged annotation file')
    args = parser.parse_args()

    file_list = glob.glob(os.path.join(args.base_dir, '*', '*'))
    ann_file = '/ccn2/u/honglinc/datasets/coco/annotations/instances_train2017.json'
    with open(ann_file, 'r') as file:
        gt_json = json.load(file)

    # look up image sizes from the ground-truth COCO annotations by file name
    size_by_name = {img['file_name']: (img['width'], img['height']) for img in gt_json['images']}

    image_id, segmentation_id = 1, 1
    image_names = []

    for file_name in file_list:
        print('processing file name', file_name)
        data = torch.load(file_name)

        for img_name, mask_list in data.items():
            width, height = size_by_name[img_name]

            is_new_image = img_name not in image_names
            if is_new_image:
                image_info = create_image_info(image_id, img_name, (width, height))
                output["images"].append(image_info)
                image_names.append(img_name)

            for mask in mask_list:
                # create coco-style annotation info
                if mask.sum() == 0:
                    continue
                # upsample the predicted mask (expected shape [1, 1, h, w]) to the
                # original image resolution and threshold it back to binary
                pseudo_mask = F.interpolate(mask.float(), size=(height, width), mode='bicubic') > 0.5
                pseudo_mask = pseudo_mask[0, 0].numpy()
                annotation_info = create_annotation_info(
                    segmentation_id, image_id, category_info, pseudo_mask.astype(np.uint8), None)

                if annotation_info is not None:
                    output["annotations"].append(annotation_info)
                    segmentation_id += 1

            if is_new_image:
                image_id += 1
            print(image_id, segmentation_id)

    # save annotations
    with open(args.save_path, 'w') as output_json_file:
        json.dump(output, output_json_file)
    print(f'dumping {args.save_path}')
    print("Done: {} images; {} anns.".format(len(output['images']), len(output['annotations'])))
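
# Example invocation (the script name below is a placeholder for this file; only
# --base-dir and --save-path are real flags defined above):
#   python merge_to_coco.py --base-dir annotations/ --save-path coco_train_fixsize480_N3.json
#
# One way to sanity-check the merged file is to load it with the standard COCO API
# (a sketch, assuming pycocotools is installed):
#   from pycocotools.coco import COCO
#   coco = COCO("coco_train_fixsize480_N3.json")
#   print(len(coco.getImgIds()), "images,", len(coco.getAnnIds()), "annotations")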