# Page header captured with this file (author: rahulvenkk; commit 6dfcb0f,
# "app.py updated"; 6.2 kB). Kept as a comment so the module parses.
import argparse
import datetime
import glob
import json
import os
import re

import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.nn.functional as F
def create_image_info(image_id, file_name, image_size,
                      date_captured=None,
                      license_id=1, coco_url="", flickr_url=""):
    """Return image_info in COCO style.

    Args:
        image_id: the image ID
        file_name: the file name of the image
        image_size: image size in the format of (width, height); any extra
            trailing entries are ignored
        date_captured: the date this image info is created. When omitted
            (or None), the current UTC time is used, formatted as a naive ISO
            timestamp with a space separator, computed at call time.
        license_id: license ID of this image
        coco_url: url to COCO images if there is any
        flickr_url: url to flickr if there is any

    Returns:
        dict with the fields of a COCO ``images`` entry.
    """
    if date_captured is None:
        # Computed per call: a default expression in the signature is evaluated
        # only once at import time, stamping every image with the same time.
        # now(timezone.utc) replaces the deprecated utcnow(); dropping tzinfo
        # keeps the previous output format (no "+00:00" suffix).
        date_captured = (datetime.datetime.now(datetime.timezone.utc)
                         .replace(tzinfo=None).isoformat(' '))
    return {
        "id": image_id,
        "file_name": file_name,
        "width": image_size[0],
        "height": image_size[1],
        "date_captured": date_captured,
        "license": license_id,
        "coco_url": coco_url,
        "flickr_url": flickr_url,
    }
def _resize_binary_mask(binary_mask, image_size):
    """Nearest-neighbour resize of a 2D mask to image_size = (width, height).

    Returns an array of shape (image_size[1], image_size[0]) with the same
    dtype as binary_mask.
    """
    new_width, new_height = int(image_size[0]), int(image_size[1])
    old_height, old_width = binary_mask.shape
    # For each target pixel centre, pick the source pixel that contains it.
    rows = np.minimum(((np.arange(new_height) + 0.5) * old_height / new_height)
                      .astype(np.int64), old_height - 1)
    cols = np.minimum(((np.arange(new_width) + 0.5) * old_width / new_width)
                      .astype(np.int64), old_width - 1)
    return binary_mask[rows[:, None], cols[None, :]]


def create_annotation_info(annotation_id, image_id, category_info, binary_mask,
                           image_size=None, bounding_box=None):
    """Return annotation info in COCO style.

    Args:
        annotation_id: the annotation ID
        image_id: the image ID
        category_info: category dict; only its "id" entry is read
        binary_mask: a 2D array; values above half of its maximum are treated
            as the object. The caller's array is not modified.
        image_size: optional target size (width, height). When given, the
            mask is nearest-neighbour resized to it before encoding.
        bounding_box: the bounding box for detection task (array-like). If
            bounding_box is not provided, it is computed from the mask.

    Returns:
        A COCO annotation dict whose "segmentation" is the pycocotools RLE
        with "counts" decoded to str, or None if the mask is empty.
    """
    binary_mask = np.asarray(binary_mask)
    if binary_mask.size == 0:
        return None
    # Binarise into a fresh 0/1 uint8 array (threshold: half the peak value).
    # Writing 0/1 instead of the mask's own min/max leaves the caller's array
    # untouched, and stops float masks with a peak below 1 from being
    # truncated to all-zero by the uint8 cast.
    binary_mask = (binary_mask > np.max(binary_mask) / 2.0).astype(np.uint8)
    if image_size is not None:
        binary_mask = _resize_binary_mask(binary_mask, image_size)

    # Encode once; pycocotools requires a Fortran-ordered uint8 array.
    rle = mask_util.encode(np.asfortranarray(binary_mask))
    area = mask_util.area(rle)
    if area < 1:
        return None
    if bounding_box is None:
        bounding_box = mask_util.toBbox(rle)

    # pycocotools returns 'counts' as bytes; decode so json.dump can write it.
    segmentation = dict(rle, counts=rle['counts'].decode('ascii'))
    return {
        "id": annotation_id,
        "image_id": image_id,
        "category_id": category_info["id"],
        "iscrowd": 0,
        "area": int(area),
        # asarray accepts a list as well as an ndarray bounding box.
        "bbox": np.asarray(bounding_box).tolist(),
        "segmentation": segmentation,
        "width": binary_mask.shape[1],
        "height": binary_mask.shape[0],
    }
# Dataset-level metadata written into the "info" field of the COCO file.
# NOTE(review): the description mentions ImageNet-1K / MaskCut while the
# script below merges results against COCO train2017 — confirm it is intended.
INFO = {
    "description": "ImageNet-1K: pseudo-masks with MaskCut",
    "url": "https://github.com/facebookresearch/CutLER",
    "version": "1.0",
    "year": 2023,
    "contributor": "Xudong Wang",
    # Naive UTC timestamp (same format as the deprecated utcnow()).
    "date_created": (datetime.datetime.now(datetime.timezone.utc)
                     .replace(tzinfo=None).isoformat(' ')),
}

LICENSES = [
    {
        "id": 1,
        "name": "Apache License",
        "url": "https://github.com/facebookresearch/CutLER/blob/main/LICENSE",
    }
]

# only one class, i.e. foreground
CATEGORIES = [
    {
        'id': 1,
        'name': 'fg',
        'supercategory': 'fg',
    },
]


def convert(text):
    """Return int(text) for a run of digits, otherwise the lower-cased text."""
    return int(text) if text.isdigit() else text.lower()


def natrual_key(key):
    """Natural-sort key: split into digit / non-digit runs so 'img2' < 'img10'."""
    # The capturing group keeps the digit runs in the split result.
    return [convert(c) for c in re.split('([0-9]+)', key)]


# Correctly spelled alias; the original name is kept for compatibility.
natural_key = natrual_key

# Accumulator for the merged COCO-style output; filled in by the script below.
output = {
    "info": INFO,
    "licenses": LICENSES,
    "categories": CATEGORIES,
    "images": [],
    "annotations": [],
}

# Category passed to create_annotation_info for every pseudo-mask.
category_info = {
    "is_crowd": 0,
    "id": 1,
}
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Merge pytorch results file into json')
    parser.add_argument('--base-dir', type=str,
                        default='annotations/',
                        help='Dir to the generated annotation .pt files with CWM')
    parser.add_argument('--save-path', type=str, default="coco_train_fixsize480_N3.json",
                        help='Path to save the merged annotation file')
    # Previously hard-coded; exposed as an option with the same default.
    parser.add_argument('--ann-file', type=str,
                        default='/ccn2/u/honglinc/datasets/coco/annotations/instances_train2017.json',
                        help='COCO ground-truth annotation file used to look up image sizes')
    args = parser.parse_args()

    # Sorted so image / annotation ids are reproducible across runs
    # (raw glob order depends on the filesystem).
    file_list = sorted(glob.glob(os.path.join(args.base_dir, '*', '*')))

    with open(args.ann_file, 'r') as file:
        gt_json = json.load(file)
    # file_name -> (height, width), built once instead of rescanning every
    # ground-truth image for each predicted image.
    gt_sizes = {img['file_name']: (img['height'], img['width'])
                for img in gt_json['images']}

    # img_name -> assigned COCO image id. Masks for an image that appears in
    # several result files are attached to that image's own id.
    image_ids = {}
    segmentation_id = 1
    for file_name in file_list:
        print('processing file name', file_name)
        # NOTE: torch.load unpickles its input; only run on trusted result files.
        # map_location='cpu' keeps the .numpy() call below working for GPU tensors.
        data = torch.load(file_name, map_location='cpu')
        for img_name, mask_list in data.items():
            if img_name not in gt_sizes:
                # Without a ground-truth entry there is no reliable size to use.
                print(f'skipping {img_name}: not found in {args.ann_file}')
                continue
            height, width = gt_sizes[img_name]

            image_id = image_ids.get(img_name)
            if image_id is None:
                image_id = len(image_ids) + 1
                image_ids[img_name] = image_id
                # create_image_info expects the size as (width, height).
                output["images"].append(
                    create_image_info(image_id, img_name, (width, height)))

            for mask in mask_list:
                # create coco-style annotation info; skip empty predictions
                if mask.sum() == 0:
                    continue
                # assumes mask is (1, 1, h, w) as F.interpolate needs a batched
                # input — TODO confirm; [0, 0] keeps only the first channel.
                pseudo_mask = F.interpolate(mask.float(), size=(height, width),
                                            mode='bicubic') > 0.5
                pseudo_mask = pseudo_mask[0, 0].numpy()
                annotation_info = create_annotation_info(
                    segmentation_id, image_id, category_info,
                    pseudo_mask.astype(np.uint8), None)
                if annotation_info is not None:
                    output["annotations"].append(annotation_info)
                    segmentation_id += 1
        # Progress: next image id and next annotation id.
        print(len(image_ids) + 1, segmentation_id)

    # save annotations
    print(f'dumping {args.save_path}')
    with open(args.save_path, 'w') as output_json_file:
        json.dump(output, output_json_file)
    print("Done: {} images; {} anns.".format(len(output['images']), len(output['annotations'])))