File size: 4,781 Bytes
3dac99f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
'''
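Run SAM automatic mask generation over every image of a dataset split and save
the largest masks of each image as one pickle file under output/SAM_masks_pred/.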
'''
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import torch.multiprocessing as mp
import pickle
from tqdm import tqdm
import torch
import cv2
import os
import json
import argparse
import numpy as np
# Dataset registry: data_name -> [image directory, annotation JSON].
# Both entries are joined onto the dataset root by the script. When the
# annotation entry is empty, the image directory is listed instead; otherwise
# the JSON's 'images' list drives which images are processed.
img_anno = dict(
    ade20k_val=[
        'ADEChallengeData2016/images/validation',
        'ADEChallengeData2016/ade20k_panoptic_val.json',
    ],
    pc_val=['pascal_ctx_d2/images/validation', ''],
    pas_val=['pascal_voc_d2/images/validation', ''],
)

# SAM backbone name (the --sam_model value) -> checkpoint file path handed to
# sam_model_registry when the model is built.
sam_checkpoint_dict = dict(
    vit_b='pretrained_checkpoint/sam_vit_b_01ec64.pth',
    vit_h='pretrained_checkpoint/sam_vit_h_4b8939.pth',
    vit_l='pretrained_checkpoint/sam_vit_l_0b3195.pth',
    vit_t='pretrained_checkpoint/mobile_sam.pt',
)

def _build_sam(args, gpu, if_parallel):
    """Build the SAM model selected by ``args.sam_model`` and move it to CUDA.

    When ``if_parallel`` is false, device ``gpu`` (an index into the visible
    CUDA devices) is made current first; otherwise the current default CUDA
    device is used.
    """
    sam_checkpoint = sam_checkpoint_dict[args.sam_model]
    sam = sam_model_registry[args.sam_model](checkpoint=sam_checkpoint)
    if not if_parallel:
        torch.cuda.set_device(gpu)
    # The parallel branch used to wrap the model in DataParallel and then take
    # ``.module`` straight back out, which yields the same module object, so a
    # plain ``.cuda()`` is equivalent for both branches.
    return sam.cuda()


def _resolve_file_name(image_info):
    """Return the ``<stem>.jpg`` file name for one dataset entry, or None.

    ``image_info`` is either a COCO-style image record (a dict carrying
    ``coco_url`` or ``file_name``) or a bare file-name string. The stem is
    everything before the first '.', and the extension is always replaced by
    '.jpg'. Returns None for a dict that has neither key.

    Raises:
        TypeError: if ``image_info`` is neither a dict nor a str.
    """
    if isinstance(image_info, dict):
        if 'coco_url' in image_info:
            source = image_info['coco_url'].split('/')[-1]
        elif 'file_name' in image_info:
            source = image_info['file_name']
        else:
            return None
    elif isinstance(image_info, str):
        source = image_info
    else:
        raise TypeError(f"Unsupported image entry type: {type(image_info).__name__}")
    return source.split('.')[0] + '.jpg'


def process_images(args, gpu, data_chunk, save_path, if_parallel, *,
                   dataset_path=None, max_masks=50):
    """Run SAM automatic mask generation on ``data_chunk`` and pickle the results.

    Args:
        args: parsed CLI namespace; ``args.sam_model`` selects the checkpoint
            and ``args.data_name`` selects the image directory in ``img_anno``.
        gpu: visible-CUDA-device index to use when ``if_parallel`` is false.
        data_chunk: iterable of image entries, either COCO-style dicts (with
            ``coco_url`` or ``file_name``) or bare file-name strings.
        save_path: existing directory that receives one ``<image id>.pkl`` per
            image, containing the generator's masks sorted by area, descending.
        if_parallel: when true, do not call ``torch.cuda.set_device``.
        dataset_path: dataset root. Defaults to ``$DETECTRON2_DATASETS`` (same
            fallback as the script entry point), so the function does not
            depend on a global that only the ``__main__`` block defines.
        max_masks: keep at most this many masks per image, largest area first
            (``None`` keeps all).

    Images whose pickle already exists are skipped, so an interrupted run can
    be resumed. Per-image failures are printed and the loop continues.
    """
    if dataset_path is None:
        dataset_path = os.getenv("DETECTRON2_DATASETS", "/users/cx_xchen/DATASETS/")
    # Every entry resolves against the same image directory; compute it once.
    image_root = os.path.join(dataset_path, img_anno[args.data_name][0])

    sam = _build_sam(args, gpu, if_parallel)
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        pred_iou_thresh=0.8,
        stability_score_thresh=0.7,
        crop_n_layers=0,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,
        output_mode='coco_rle'
    )

    for image_info in tqdm(data_chunk):
        file_name = _resolve_file_name(image_info)
        if file_name is None:
            # Previously this left file_name unbound (or stale from the prior
            # iteration), silently reprocessing/skipping the wrong image.
            print(f"Skipping entry without 'coco_url' or 'file_name': {image_info}")
            continue
        image_path = f'{image_root}/{file_name}'
        # Output name is the stem with any sub-directories flattened into '_'.
        image_id = file_name.split('.')[0].replace('/', '_')
        savepath = f'{save_path}/{image_id}.pkl'
        if os.path.exists(savepath):
            continue
        try:
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise FileNotFoundError(f"cv2.imread could not read {image_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            everything_mask = mask_generator.generate(image, train=False)
            everything_mask = sorted(
                everything_mask, key=lambda m: m['area'], reverse=True
            )[:max_masks]
            # Write to a temp file and rename, so a crash mid-dump never leaves
            # a truncated pickle that the exists() check above would then skip.
            tmp_path = savepath + '.tmp'
            with open(tmp_path, 'wb') as f:
                pickle.dump(everything_mask, f)
            os.replace(tmp_path, savepath)
        except Exception as e:
            # Best-effort per image: report and move on to the next one.
            print(f"Failed to load or convert image at {image_path}. Error: {e}")

if __name__ == '__main__':
    # Entry point: split the chosen dataset across the visible GPUs and run one
    # process_images worker process per GPU.
    parser = argparse.ArgumentParser()
    # choices= turns an unknown name into a clear usage error instead of a
    # KeyError deep inside the script or a worker process.
    parser.add_argument('--data_name', type=str, default='pas_val',
                        choices=sorted(img_anno))
    parser.add_argument('--sam_model', type=str, default='vit_h',
                        choices=sorted(sam_checkpoint_dict))
    argss = parser.parse_args()
    # Assigned at module scope so forked worker processes inherit it.
    dataset_path = os.getenv("DETECTRON2_DATASETS", "/users/cx_xchen/DATASETS/")

    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        # Unset means no restriction; use device 0 rather than splitting into
        # zero chunks (np.array_split rejects 0 sections). Set
        # CUDA_VISIBLE_DEVICES to spread the work over more GPUs.
        num_gpus = 1
    else:
        num_gpus = len([x.strip() for x in visible_devices.split(",") if x.strip().isdigit()])
    if num_gpus == 0:
        raise SystemExit("CUDA_VISIBLE_DEVICES lists no GPU index; nothing to run on.")
    print(f"Using {num_gpus} GPUs")

    image_dir_rel, anno_file_rel = img_anno[argss.data_name]
    if anno_file_rel:
        json_file_path = os.path.join(dataset_path, anno_file_rel)
        with open(json_file_path, 'r') as file:
            data = json.load(file)
        data_chunks = np.array_split(data['images'], num_gpus)
    else:
        # No annotation file: process every entry of the image directory.
        # Sorted so the per-GPU split is deterministic across runs.
        image_dir = os.path.join(dataset_path, image_dir_rel)
        image_files = sorted(os.listdir(image_dir))
        data_chunks = np.array_split(image_files, num_gpus)

    save_path = f'output/SAM_masks_pred/{argss.sam_model}_{argss.data_name}'
    os.makedirs(save_path, exist_ok=True)

    processes = []
    for gpu, chunk in enumerate(data_chunks):
        p = mp.Process(target=process_images, args=(argss, gpu, chunk, save_path, False))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    # Surface worker crashes in the exit status instead of finishing "cleanly".
    failed = [gpu for gpu, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise SystemExit(f"Worker process(es) for GPU(s) {failed} exited abnormally.")