scdrand23 committed on
Commit 814a594 · 1 parent: a097349

not working version

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete changeset.
Files changed (50)
  1. README.md +37 -7
  2. README.pdf +0 -0
  3. __init__.py +0 -0
  4. app.py +242 -0
  5. configs/biomed_seg_lang_v1.yaml +330 -0
  6. configs/biomedparse_inference.yaml +198 -0
  7. datasets/__init__.py +2 -0
  8. datasets/build.py +630 -0
  9. datasets/dataset_mappers/__init__.py +1 -0
  10. datasets/dataset_mappers/biomed_dataset_mapper.py +378 -0
  11. datasets/evaluation/__init__.py +8 -0
  12. datasets/evaluation/captioning_evaluation.py +129 -0
  13. datasets/evaluation/classification_evaluation.py +76 -0
  14. datasets/evaluation/grounding_evaluation.py +173 -0
  15. datasets/evaluation/instance_evaluation.py +107 -0
  16. datasets/evaluation/interactive_evaluation.py +122 -0
  17. datasets/evaluation/panoptic_evaluation.py +199 -0
  18. datasets/evaluation/retrieval_evaluation.py +260 -0
  19. datasets/evaluation/segmentation_evaluation.py +195 -0
  20. datasets/refer.py +371 -0
  21. datasets/registration/__init__.py +3 -0
  22. datasets/registration/register_biomed_datasets.py +123 -0
  23. datasets/semseg_loader.py +10 -0
  24. datasets/utils/refcoco2json.py +41 -0
  25. datasets/utils/refer.py +372 -0
  26. datasets/visual_sampler/__init__.py +12 -0
  27. datasets/visual_sampler/circle.py +106 -0
  28. datasets/visual_sampler/mask_generators.py +215 -0
  29. datasets/visual_sampler/point.py +74 -0
  30. datasets/visual_sampler/polygon.py +137 -0
  31. datasets/visual_sampler/sampler.py +77 -0
  32. datasets/visual_sampler/scribble.py +96 -0
  33. datasets/visual_sampler/simpleclick_sampler.py +252 -0
  34. docker/Dockerfile +32 -0
  35. docker/README.md +9 -0
  36. docker/data_env.sh +1 -0
  37. docker/docker_build.sh +1 -0
  38. docker/docker_run.sh +1 -0
  39. docker/setup_inside_docker.sh +10 -0
  40. entry.py +92 -0
  41. environment.yml +149 -0
  42. figures/main_figure_1a.py +99 -0
  43. figures/main_figure_1b.py +101 -0
  44. figures/main_figure_2a.py +93 -0
  45. figures/main_figure_3b.py +83 -0
  46. figures/main_figure_3c.py +83 -0
  47. figures/main_figure_3d.py +83 -0
  48. figures/plots/IRI_mean_improvement_medsam.pdf +0 -0
  49. figures/plots/IRI_mean_improvement_medsam.png +0 -0
  50. figures/plots/IRI_mean_improvement_sam.pdf +0 -0
README.md CHANGED
@@ -1,14 +1,44 @@
---
- title: HakimAiV2
- emoji: 🐠
- colorFrom: green
colorTo: green
sdk: gradio
- sdk_version: 5.9.1
app_file: app.py
pinned: false
- license: cc-by-nc-4.0
- short_description: hakim ai by cbai
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

---
+ title: HakimAi
+ emoji: 🏥
+ colorFrom: blue
colorTo: green
sdk: gradio
+ sdk_version: "5.9.0"
app_file: app.py
pinned: false
---

+ # HakimAi
+
+ A medical imaging analysis platform powered by BiomedParse, offering comprehensive biomedical image analysis across multiple modalities.
+
+ ## Features
+
+ - **Multi-modal Analysis**: Support for various medical imaging types including X-ray, CT, MRI, pathology, and more
+ - **Advanced Detection**: Automated identification and segmentation of medical objects and conditions
+ - **Interactive Interface**: User-friendly Gradio interface for easy image upload and analysis
+ - **Powered by BiomedParse**: Utilizes Microsoft's BiomedParse foundation model for accurate medical image analysis
+
+ ## Usage
+
+ 1. Upload your medical image
+ 2. Select the analysis type
+ 3. View the results including segmentation masks and detection results
+
+ ## Technical Details
+
+ This space uses:
+ - BiomedParse foundation model for medical image analysis
+ - Gradio for the web interface
+ - Git LFS for handling large files
+ - Python 3.9+ environment
+
+ ## Model Information
+
+ Based on BiomedParse, capable of:
+ - Segmentation
+ - Detection
+ - Recognition across nine biomedical modalities
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
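
The Usage and Technical Details sections above describe the flow this Space implements. For reference, below is a minimal sketch of the same prompt-based inference path that app.py (added in this commit) wires into the Gradio UI. It assumes the bundled BiomedParse packages (modeling, utilities, inference_utils) and configs/biomedparse_inference.yaml are on the Python path; the local example path is hypothetical.

```python
# Minimal sketch of the prompt-driven inference flow used by this Space (see app.py).
import torch
from PIL import Image

from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image
from inference_utils.processing_utils import read_rgb

opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
opt['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pretrained BiomedParse checkpoint from the Hugging Face Hub.
model = BaseModel(opt, build_model(opt)).from_pretrained('hf_hub:microsoft/BiomedParse').eval()
with torch.no_grad():
    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
        BIOMED_CLASSES + ["background"], is_eval=True
    )

# One text prompt per structure of interest; one mask comes back per prompt.
image = read_rgb("examples/T0011.jpg")  # hypothetical local copy of the fundus example image
masks = interactive_infer_image(model, Image.fromarray(image), ["Optic disc in retinal Fundus"])
```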
README.pdf ADDED
Binary file (44.6 kB)
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,242 @@
+ import os
+ import shutil
+ import sys
+ from pathlib import Path
+ from typing import Tuple, Optional
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import spaces
+ # import supervision as sv
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+ from huggingface_hub import login
+ # from dotenv import load_dotenv
+
+ # For Hugging Face Spaces, secrets are automatically loaded as environment variables.
+ token = os.getenv("HF_TOKEN")
+ if token:
+     login(token=token)
+
+ # Clear Hugging Face caches (currently disabled).
+ # cache_dirs = [
+ #     "/home/user/.cache/huggingface/",
+ #     "/home/user/.cache/torch/",
+ #     "/home/user/.cache/pip/"
+ # ]
+ # for cache_dir in cache_dirs:
+ #     if os.path.exists(cache_dir):
+ #         print(f"Clearing cache: {cache_dir}")
+ #         shutil.rmtree(cache_dir, ignore_errors=True)
+
+ # Add the current directory to the Python path so the bundled BiomedParse
+ # packages (modeling, utilities, inference_utils) can be imported.
+ current_dir = Path(__file__).parent
+ sys.path.append(str(current_dir))
+ # sys.path.append("./BiomedParse/")
+ # BIOMEDPARSE_PATH = Path(__file__).parent / "BiomedParse"
+ # sys.path.append(str(BIOMEDPARSE_PATH))
+ # sys.path.append(str(BIOMEDPARSE_PATH / "BiomedParse"))  # Add the inner BiomedParse directory
+
+ from modeling.BaseModel import BaseModel
+ from modeling import build_model
+ from utilities.arguments import load_opt_from_config_files
+ from utilities.constants import BIOMED_CLASSES
+ from inference_utils.inference import interactive_infer_image
+ from inference_utils.output_processing import check_mask_stats
+ from inference_utils.processing_utils import read_rgb
+
+ MARKDOWN = """
+ # <div style="text-align: center; font-size: 2.5em;">ሀ<span style="color: #32CD32;">A</span>ኪ<span style="color: #FFD700;">i</span>ም <sup>AI</sup></div>
+
+ <div>
+ <a href="https://cyberbrainai.com/">
+ <img src="https://cyberbrainai.com/assets/logo.svg" alt="CyberBrain AI" style="display:inline-block; width:50px; height:50px;">
+ </a>
+ <a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="ድinቅneሽ" style="display:inline-block;">
+ </a>
+ <a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
+ <img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
+ </a>
+ </div>
+
+ This demo integrates BiomedParse, a foundation model for joint segmentation, detection, and recognition across 9 biomedical imaging modalities. The model supports:
+
+ - Segmentation/Detection/Recognition across multiple modalities (CT, MRI, X-Ray, etc.)
+ - Text-prompted object detection
+ - Recognition of anatomical structures and abnormalities
+ """
+
+ # Each example row matches the inputs of `process_image`: image, text prompt(s), modality.
+ IMAGE_PROCESSING_EXAMPLES = [
+     ["https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/T0011.jpg",
+      "Optic disc in retinal Fundus", "Fundus"],
+     ["https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/Part_3_226_pathology_breast.png",
+      "optic disc, optic cup", "Pathology"],
+     ["https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/covid_1585.png",
+      "COVID-19 infection in chest X-Ray", "X-Ray"],
+     ["https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/TCGA_HT_7856_19950831_8_MRI-FLAIR_brain.png",
+      "Lower-grade glioma in brain MRI", "MRI-FLAIR"],
+     ["https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/LIDC-IDRI-0140_143_280_CT_lung.png",
+      "COVID-19 infection in chest CT", "CT"],
+     ["https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/144DME_as_F.jpeg",
+      "Cystoid macular edema in retinal OCT", "OCT"],
+     ["https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/Part_1_516_pathology_breast.png",
+      "Glandular structure in colon Pathology", "Pathology"],
+     ["https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/ISIC_0015551.jpg",
+      "Melanoma in skin Dermoscopy", "Dermoscopy"],
+     ["https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/C3_EndoCV2021_00462.jpg",
+      "Neoplastic polyp in colon Endoscope", "Endoscope"],
+ ]
+
+ BIOMEDPARSE_MODES = {
+     "CT": ["abdomen", "colon", "liver", "lung", "pelvis"],
+     "MRI": ["brain", "heart", "prostate", "abdomen"],
+     "MRI-FLAIR": ["brain"],
+     "MRI-T1-Gd": ["brain"],
+     "MRI-T2": ["prostate"],
+     "OCT": ["retinal"],
+     "X-Ray": ["chest"],
+     "Dermoscopy": ["skin"],
+     "Endoscope": ["colon"],
+     "Fundus": ["retinal"],
+     "Pathology": ["bladder", "breast", "cervix", "colon", "esophagus", "kidney",
+                   "liver", "ovarian", "prostate", "stomach", "testis", "thyroid", "uterus"],
+     "Ultrasound": ["breast", "heart", "transperineal"],
+ }
+
+ IMAGE_INFERENCE_MODES = [
+     "BIOMED SEGMENTATION",
+     "BIOMED DETECTION",
+     "BIOMED RECOGNITION",
+     "BIOMED SEGMENTATION + DETECTION",
+     "BIOMED SEGMENTATION + RECOGNITION",
+     "BIOMED DETECTION + RECOGNITION",
+     "BIOMED SEGMENTATION + DETECTION + RECOGNITION",
+ ]
+
+
+ def on_mode_dropdown_change(selected_mode):
+     if selected_mode in IMAGE_INFERENCE_MODES:
+         # Show the modality dropdown and hide the free-text inputs initially.
+         return [
+             gr.Dropdown(visible=True, choices=list(BIOMEDPARSE_MODES.keys()), label="Modality"),
+             gr.Dropdown(visible=True, label="Anatomical Site"),
+             gr.Textbox(visible=False),
+             gr.Textbox(visible=False),
+         ]
+     else:
+         # Original behavior for other modes.
+         return [
+             gr.Dropdown(visible=False),
+             gr.Dropdown(visible=False),
+             gr.Textbox(visible=True),
+             gr.Textbox(visible=(selected_mode is None)),
+         ]
+
+
+ def on_modality_change(modality):
+     if modality:
+         return gr.Dropdown(choices=BIOMEDPARSE_MODES[modality], visible=True)
+     return gr.Dropdown(visible=False)
+
+
+ def initialize_model():
+     opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
+     pretrained_pth = 'hf_hub:microsoft/BiomedParse'
+     opt['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
+     model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval()
+     with torch.no_grad():
+         model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
+             BIOMED_CLASSES + ["background"], is_eval=True
+         )
+     return model
+
+
+ model = initialize_model()
+
+
+ # Utility functions
+ @spaces.GPU
+ @torch.inference_mode()
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+ def process_image(image_path, text_prompts, modality):
+     image = read_rgb(image_path)
+     text_prompts = [prompt.strip() for prompt in text_prompts.split(',')]
+
+     # Run inference: one predicted mask per text prompt.
+     pred_masks = interactive_infer_image(model, Image.fromarray(image), text_prompts)
+
+     # Prepare outputs
+     results = []
+     p_values = []
+
+     for i, prompt in enumerate(text_prompts):
+         # Calculate the p-value for the selected modality.
+         print("PROMPT: ", prompt, flush=True)
+         p_value = check_mask_stats(image, pred_masks[i] * 255, modality, prompt)
+         p_values.append(f"P-value for '{prompt}' ({modality}): {p_value:.4f}")
+
+         # Overlay predictions on the image.
+         overlay_image = image.copy()
+         overlay_image[pred_masks[i] > 0.5] = [255, 0, 0]  # Highlight predictions in red
+         results.append(overlay_image)
+
+     return results, "\n".join(p_values)
+
+
+ # Define the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown(MARKDOWN)
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="filepath", label="Input Image")
+             prompts_input = gr.Textbox(lines=2, placeholder="Enter prompts separated by commas...", label="Prompts")
+             modality_dropdown = gr.Dropdown(
+                 choices=list(BIOMEDPARSE_MODES.keys()),
+                 value=list(BIOMEDPARSE_MODES.keys())[0],
+                 label="Modality"
+             )
+             submit_btn = gr.Button("Submit")
+         with gr.Column():
+             output_gallery = gr.Gallery(label="Predicted Masks")
+             pvalue_output = gr.Textbox(label="P-values", interactive=False)
+
+     submit_btn.click(
+         process_image,
+         inputs=[image_input, prompts_input, modality_dropdown],
+         outputs=[output_gallery, pvalue_output]
+     )
+     with gr.Row():
+         gr.Examples(
+             fn=process_image,
+             examples=IMAGE_PROCESSING_EXAMPLES,
+             inputs=[image_input, prompts_input, modality_dropdown],
+             outputs=[output_gallery, pvalue_output],
+             run_on_click=True
+         )
+
+ # Launch the app
+ demo.launch()
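
For a quick sanity check of `process_image` outside the UI, a small sketch follows. It assumes the Space environment above (model loaded, GPU available for the `spaces.GPU`/autocast decorators) and reuses one of the bundled example URLs.

```python
# Sketch: exercise process_image directly on one of the example images above.
import urllib.request

url = "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/covid_1585.png"
local_path, _ = urllib.request.urlretrieve(url, "covid_1585.png")

overlays, p_values = process_image(local_path, "COVID-19 infection in chest X-Ray", "X-Ray")
print(p_values)  # one "P-value for ..." line per prompt
```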
configs/biomed_seg_lang_v1.yaml ADDED
@@ -0,0 +1,330 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ # Define Test/Trainer/Saving
9
+ PIPELINE: XDecoderPipeline
10
+ TRAINER: xdecoder
11
+ SAVE_DIR: './output'
12
+ base_path: "./"
13
+
14
+ # Resume Logistic
15
+ RESUME: false
16
+ WEIGHT: false
17
+ RESUME_FROM: ''
18
+ EVAL_AT_START: false
19
+ SAVE_CHECKPOINT: True
20
+
21
+ # Logging and Debug
22
+ WANDB: False
23
+ LOG_EVERY: 100
24
+ FIND_UNUSED_PARAMETERS: false
25
+
26
+ # Speed up training
27
+ FP16: false
28
+ PORT: '36873'
29
+
30
+ # misc
31
+ LOADER:
32
+ JOINT: True
33
+ KEY_DATASET: ""
34
+ SAMPLE_PROB: "prop" # sampling probability proportional to data size. Use "equal" to draw each batch from all datasets equally
35
+ MIXING_LEVEL: 1 # num of different datasets for batch mixing on each GPU
36
+
37
+ RANDOM_SEED: 2024
38
+
39
+ STANDARD_TEXT_FOR_EVAL: False
40
+
41
+ ##################
42
+ # Task settings
43
+ ##################
44
+ VERBOSE: true
45
+ MODEL:
46
+ DEVICE: "cuda" # or "cpu" if no GPU available
47
+ NAME: seem_model_v1
48
+ HEAD: xdecoder_head
49
+ MASK_ON: false
50
+ KEYPOINT_ON: false
51
+ LOAD_PROPOSALS: false
52
+ DIM_PROJ: 512
53
+ TEXT:
54
+ ARCH: vlpencoder
55
+ NAME: transformer
56
+ TOKENIZER: clip
57
+ CONTEXT_LENGTH: 77 #256 # 77
58
+ WIDTH: 512 # 768 # 512
59
+ HEADS: 8
60
+ LAYERS: 12 # 6
61
+ AUTOGRESSIVE: True
62
+ BACKBONE:
63
+ NAME: focal # focal_dw # focal
64
+ PRETRAINED: ''
65
+ LOAD_PRETRAINED: false
66
+ FOCAL:
67
+ PRETRAIN_IMG_SIZE: 224
68
+ PATCH_SIZE: 4
69
+ EMBED_DIM: 192 # 96 # 192
70
+ DEPTHS: [2, 2, 18, 2] # [2, 2, 6, 2] # [2, 2, 18, 2]
71
+ FOCAL_LEVELS: [4, 4, 4, 4] # [3, 3, 3, 3] # [4, 4, 4, 4]
72
+ FOCAL_WINDOWS: [3, 3, 3, 3]
73
+ DROP_PATH_RATE: 0.3
74
+ MLP_RATIO: 4.0
75
+ DROP_RATE: 0.0
76
+ PATCH_NORM: True
77
+ USE_CONV_EMBED: True
78
+ SCALING_MODULATOR: True
79
+ USE_CHECKPOINT: False
80
+ USE_POSTLN: true
81
+ USE_POSTLN_IN_MODULATION: false
82
+ USE_LAYERSCALE: True
83
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
84
+ OUT_INDICES: [0, 1, 2, 3]
85
+ ENCODER:
86
+ NAME: transformer_encoder_fpn
87
+ IGNORE_VALUE: 255
88
+ NUM_CLASSES: 16
89
+ BINARY_CLASSES: False
90
+ LOSS_WEIGHT: 1.0
91
+ CONVS_DIM: 512
92
+ MASK_DIM: 512
93
+ NORM: "GN"
94
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
95
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
96
+ COMMON_STRIDE: 4
97
+ TRANSFORMER_ENC_LAYERS: 6
98
+ DECODER:
99
+ NAME: seem_v1
100
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
101
+ MASK:
102
+ ENABLED: True
103
+ DETECTION: False
104
+ SPATIAL:
105
+ ENABLED: True
106
+ MAX_ITER: 1
107
+ GROUNDING:
108
+ ENABLED: True
109
+ MAX_LEN: 10
110
+ TEXT_WEIGHT: 2.0
111
+ CLASS_WEIGHT: 0.5
112
+ RETRIEVAL:
113
+ ENABLED: False
114
+ LVIS:
115
+ ENABLED: False
116
+ THRES: 0.7
117
+ OPENIMAGE:
118
+ ENABLED: False
119
+ NEGATIVE_SAMPLES: 5
120
+ GROUNDING:
121
+ ENABLED: False
122
+ MAX_LEN: 5
123
+ CAPTION:
124
+ ENABLED: False
125
+ PHRASE_PROB: 0.5
126
+ SIM_THRES: 0.95
127
+ DEEP_SUPERVISION: True
128
+ NO_OBJECT_WEIGHT: 0.1
129
+ GCLASS_WEIGHT: 0.4
130
+ GMASK_WEIGHT: 1.0
131
+ GDICE_WEIGHT: 1.0
132
+ SCLASS_WEIGHT: 0.4
133
+ SMASK_WEIGHT: 1.0
134
+ SDICE_WEIGHT: 1.0
135
+ OCLASS_WEIGHT: 0.4
136
+ OMASK_WEIGHT: 1.0
137
+ ODICE_WEIGHT: 1.0
138
+ CLASS_WEIGHT: 2.0
139
+ MASK_WEIGHT: 5.0
140
+ DICE_WEIGHT: 5.0
141
+ BBOX_WEIGHT: 5.0
142
+ GIOU_WEIGHT: 2.0
143
+ CAPTION_WEIGHT: 2.0
144
+ COST_SPATIAL:
145
+ CLASS_WEIGHT: 5.0
146
+ MASK_WEIGHT: 2.0
147
+ DICE_WEIGHT: 2.0
148
+ HIDDEN_DIM: 512
149
+ NUM_OBJECT_QUERIES: 101
150
+ NHEADS: 8
151
+ DROPOUT: 0.0
152
+ DIM_FEEDFORWARD: 2048
153
+ MAX_SPATIAL_LEN: [512, 512, 512, 512]
154
+ # ENC_LAYERS: 0
155
+ PRE_NORM: False
156
+ ENFORCE_INPUT_PROJ: False
157
+ SIZE_DIVISIBILITY: 32
158
+ TRAIN_NUM_POINTS: 12544
159
+ OVERSAMPLE_RATIO: 3.0
160
+ IMPORTANCE_SAMPLE_RATIO: 0.75
161
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
162
+ TOP_GROUNDING_LAYERS: 10
163
+ TOP_CAPTION_LAYERS: 10
164
+ TOP_SPATIAL_LAYERS: 10
165
+ TOP_OPENIMAGE_LAYERS: 10
166
+ TEST:
167
+ SEMANTIC_ON: False
168
+ INSTANCE_ON: False
169
+ PANOPTIC_ON: False
170
+ OVERLAP_THRESHOLD: 0.8
171
+ OBJECT_MASK_THRESHOLD: 0.8
172
+ SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: true
173
+
174
+ # Spatial sampler
175
+ STROKE_SAMPLER:
176
+ MAX_CANDIDATE: 1
177
+ CANDIDATE_PROBS: [0.25, 0.25, 0.25, 0.25] # for training only
178
+ CANDIDATE_NAMES: ["Point", "Polygon", "Scribble", "Circle"]
179
+ DILATION: 3
180
+ CIRCLE:
181
+ NUM_STROKES: 5
182
+ STROKE_PRESET: ['object_like', 'object_like_middle', 'object_like_small']
183
+ STROKE_PROB: [0.33, 0.33, 0.33]
184
+ SCRIBBLE:
185
+ NUM_STROKES: 5
186
+ STROKE_PRESET: ['rand_curve', 'rand_curve_small']
187
+ STROKE_PROB: [0.5, 0.5]
188
+ POINT:
189
+ NUM_POINTS: 20
190
+ POLYGON:
191
+ MAX_POINTS: 9
192
+ EVAL:
193
+ MODE: 'best' # best/random/best_random
194
+ NEGATIVE: False
195
+ MAX_ITER: 1
196
+ IOU_ITER: 1
197
+ GROUNDING: True
198
+
199
+ # Multi-modal Architecture, order matters
200
+ ATTENTION_ARCH:
201
+ VARIABLE:
202
+ queries: ['object', 'grounding', 'spatial']
203
+ tokens: ['grounding', 'spatial']
204
+ memories: ['spatial']
205
+ SELF_ATTENTION:
206
+ queries:
207
+ object: ['queries_object']
208
+ grounding: ['queries_grounding', 'tokens_grounding']
209
+ spatial: ['queries_spatial', 'tokens_spatial', 'memories_spatial']
210
+ tokens:
211
+ grounding: ['queries_grounding', 'tokens_grounding']
212
+ spatial: ['tokens_spatial']
213
+ memories:
214
+ spatial: ['memories_spatial']
215
+ CROSS_ATTENTION:
216
+ queries:
217
+ object: True
218
+ grounding: True
219
+ spatial: True
220
+ memories:
221
+ spatial: True
222
+ tokens:
223
+ grounding: False
224
+ spatial: False
225
+ MASKING: ['tokens_spatial', 'tokens_grounding']
226
+ DUPLICATION:
227
+ queries:
228
+ grounding: 'queries_object'
229
+ spatial: 'queries_object'
230
+ SPATIAL_MEMORIES: 32
231
+ QUERY_NUMBER: 3
232
+
233
+ DATASETS:
234
+ TRAIN: [
235
+ 'biomed_BiomedParseData-Demo_demo' # Add your registered training datasets here
236
+ ]
237
+
238
+
239
+
240
+ TEST: [
241
+ 'biomed_BiomedParseData-Demo_demo' # Add your registered test datasets here
242
+ ]
243
+
244
+ CLASS_CONCAT: false
245
+ SIZE_DIVISIBILITY: 32
246
+ PROPOSAL_FILES_TRAIN: []
247
+
248
+ INPUT:
249
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
250
+ PIXEL_STD: [58.395, 57.120, 57.375]
251
+
252
+ TRAIN:
253
+ ASPECT_RATIO_GROUPING: true
254
+ BATCH_SIZE_TOTAL: 4
255
+ BATCH_SIZE_PER_GPU: 4
256
+ SHUFFLE: true
257
+
258
+ TEST:
259
+ DETECTIONS_PER_IMAGE: 100
260
+ NAME: coco_eval
261
+ IOU_TYPE: ['bbox', 'segm']
262
+ USE_MULTISCALE: false
263
+ BATCH_SIZE_TOTAL: 4
264
+ MODEL_FILE: ''
265
+ AUG:
266
+ ENABLED: False
267
+
268
+ DATALOADER:
269
+ FILTER_EMPTY_ANNOTATIONS: False
270
+ NUM_WORKERS: 8
271
+ LOAD_PROPOSALS: False
272
+ SAMPLER_TRAIN: "TrainingSampler"
273
+ ASPECT_RATIO_GROUPING: True
274
+
275
+
276
+ BioMed:
277
+ INPUT:
278
+ PIXEL_MEAN: [64.284, 59.293, 59.962]
279
+ PIXEL_STD: [62.484, 60.865, 59.835]
280
+ DATASET_MAPPER_NAME: "biomed_interactive"
281
+ MIN_SIZE_TRAIN: 900
282
+ MAX_SIZE_TRAIN: 1100
283
+ MIN_SIZE_TRAIN_SAMPLING: 'choice'
284
+ MIN_SIZE_TEST: 900
285
+ MAX_SIZE_TEST: 1100
286
+ IMAGE_SIZE: 1024
287
+ MIN_SCALE: 0.9
288
+ MAX_SCALE: 1.1
289
+ IGNORE_VALUE: 255
290
+ COLOR_AUG_SSD: False
291
+ SIZE_DIVISIBILITY: 32
292
+ RANDOM_FLIP: "none"
293
+ RANDOM_ROTATE: False
294
+ MASK_FORMAT: "polygon"
295
+ MIN_AREA: 30
296
+ FORMAT: "RGB"
297
+ SPATIAL: True
298
+ CROP:
299
+ ENABLED: True
300
+ DATASET:
301
+ DATASET: "biomed"
302
+
303
+
304
+ # Detectron2 training config for optimizer and lr scheduler
305
+ SOLVER:
306
+ BASE_LR: 0.0001
307
+ STEPS: [0.88889, 0.96296]
308
+ MAX_ITER: 1
309
+ GAMMA: 0.1
310
+ WARMUP_FACTOR: 1.0
311
+ WARMUP_ITERS: 10
312
+ WARMUP_METHOD: "linear"
313
+ WEIGHT_DECAY: 0.05
314
+ OPTIMIZER: "ADAMW"
315
+ LR_SCHEDULER_NAME: "WarmupMultiStepLR"
316
+ LR_MULTIPLIER:
317
+ backbone: 0.1
318
+ lang_encoder: 0.1
319
+ FIX_PARAM:
320
+ backbone: True
321
+ lang_encoder: True
322
+ pixel_decoder: True
323
+ WEIGHT_DECAY_NORM: 0.0
324
+ WEIGHT_DECAY_EMBED: 0.0
325
+ CLIP_GRADIENTS:
326
+ ENABLED: True
327
+ CLIP_TYPE: "full_model"
328
+ CLIP_VALUE: 5.0 # 0.01
329
+ NORM_TYPE: 2.0
330
+ MAX_NUM_EPOCHS: 50
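
The LOADER block at the top of this config controls how JointLoader (defined in datasets/build.py below) mixes multiple training datasets: SAMPLE_PROB "prop" draws from each dataset in proportion to its size, while "equal" samples every dataset equally often (a "sqrt" option also appears in the loader code). A small sketch of that weighting, mirroring the loader's logic:

```python
import numpy as np

def sampling_probs(loader_sizes, mode="prop"):
    """Normalized per-dataset sampling probabilities, mirroring JointLoader."""
    if mode == "prop":      # proportional to dataset size
        weights = [float(n) for n in loader_sizes]
    elif mode == "equal":   # every dataset sampled equally often
        weights = [1.0 for _ in loader_sizes]
    elif mode == "sqrt":    # dampened proportional sampling
        weights = [float(np.sqrt(n)) for n in loader_sizes]
    else:
        raise ValueError(f"unknown SAMPLE_PROB mode: {mode}")
    total = sum(weights)
    return [w / total for w in weights]

print(sampling_probs([900, 100], mode="prop"))   # [0.9, 0.1]
print(sampling_probs([900, 100], mode="equal"))  # [0.5, 0.5]
```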
configs/biomedparse_inference.yaml ADDED
@@ -0,0 +1,198 @@
1
+ # Define Test/Trainer/Saving
2
+ PIPELINE: XDecoderPipeline
3
+ TRAINER: xdecoder
4
+ SAVE_DIR: '../../data/output/test'
5
+ base_path: "./"
6
+
7
+ # Resume Logistic
8
+ RESUME: false
9
+ WEIGHT: false
10
+ RESUME_FROM: ''
11
+ EVAL_AT_START: false
12
+
13
+ # Logging and Debug
14
+ WANDB: False
15
+ LOG_EVERY: 100
16
+ FIND_UNUSED_PARAMETERS: false
17
+
18
+ # Speed up training
19
+ FP16: false
20
+ PORT: '36873'
21
+
22
+ # misc
23
+ LOADER:
24
+ JOINT: False
25
+ KEY_DATASET: 'coco'
26
+
27
+ STANDARD_TEXT_FOR_EVAL: False
28
+
29
+ ##################
30
+ # Task settings
31
+ ##################
32
+ VERBOSE: true
33
+ MODEL:
34
+ device: "cuda" # or "cpu" if no GPU available
35
+ DEVICE: "cuda" # or "cpu" if no GPU available
36
+ NAME: seem_model_demo
37
+ HEAD: xdecoder_head
38
+ DIM_PROJ: 512
39
+ TEXT:
40
+ ARCH: vlpencoder
41
+ NAME: transformer
42
+ TOKENIZER: clip
43
+ CONTEXT_LENGTH: 77 # 77
44
+ WIDTH: 512
45
+ HEADS: 8
46
+ LAYERS: 12 # 6
47
+ AUTOGRESSIVE: True
48
+ BACKBONE:
49
+ NAME: focal
50
+ PRETRAINED: ''
51
+ LOAD_PRETRAINED: false
52
+ FOCAL:
53
+ PRETRAIN_IMG_SIZE: 224
54
+ PATCH_SIZE: 4
55
+ EMBED_DIM: 192
56
+ DEPTHS: [2, 2, 18, 2]
57
+ FOCAL_LEVELS: [4, 4, 4, 4]
58
+ FOCAL_WINDOWS: [3, 3, 3, 3]
59
+ DROP_PATH_RATE: 0.3
60
+ MLP_RATIO: 4.0
61
+ DROP_RATE: 0.0
62
+ PATCH_NORM: True
63
+ USE_CONV_EMBED: True
64
+ SCALING_MODULATOR: True
65
+ USE_CHECKPOINT: False
66
+ USE_POSTLN: true
67
+ USE_POSTLN_IN_MODULATION: false
68
+ USE_LAYERSCALE: True
69
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
70
+ OUT_INDICES: [0, 1, 2, 3]
71
+ ENCODER:
72
+ NAME: transformer_encoder_fpn
73
+ IGNORE_VALUE: 255
74
+ NUM_CLASSES: 16
75
+ BINARY_CLASSES: False
76
+ LOSS_WEIGHT: 1.0
77
+ CONVS_DIM: 512
78
+ MASK_DIM: 512
79
+ NORM: "GN"
80
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
81
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
82
+ COMMON_STRIDE: 4
83
+ TRANSFORMER_ENC_LAYERS: 6
84
+ DECODER:
85
+ NAME: seem_demo
86
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
87
+ MASK:
88
+ ENABLED: False
89
+ DETECTION: False
90
+ SPATIAL:
91
+ ENABLED: True
92
+ MAX_ITER: 1
93
+ GROUNDING:
94
+ ENABLED: True
95
+ MAX_LEN: 5
96
+ TEXT_WEIGHT: 2.0
97
+ CLASS_WEIGHT: 0.5
98
+ VISUAL:
99
+ ENABLED: False
100
+ AUDIO:
101
+ ENABLED: False
102
+ RETRIEVAL:
103
+ ENABLED: False
104
+ LVIS:
105
+ ENABLED: True
106
+ THRES: 0.7
107
+ OPENIMAGE:
108
+ ENABLED: False
109
+ NEGATIVE_SAMPLES: 5
110
+ GROUNDING:
111
+ ENABLED: False
112
+ MAX_LEN: 5
113
+ CAPTION:
114
+ ENABLED: False
115
+ PHRASE_PROB: 0.5
116
+ SIM_THRES: 0.95
117
+ DEEP_SUPERVISION: True
118
+ NO_OBJECT_WEIGHT: 0.1
119
+ GCLASS_WEIGHT: 0.4
120
+ GMASK_WEIGHT: 1.0
121
+ GDICE_WEIGHT: 1.0
122
+ SCLASS_WEIGHT: 0.4
123
+ SMASK_WEIGHT: 1.0
124
+ SDICE_WEIGHT: 1.0
125
+ OCLASS_WEIGHT: 0.4
126
+ OMASK_WEIGHT: 1.0
127
+ ODICE_WEIGHT: 1.0
128
+ CLASS_WEIGHT: 2.0
129
+ MASK_WEIGHT: 5.0
130
+ DICE_WEIGHT: 5.0
131
+ BBOX_WEIGHT: 5.0
132
+ GIOU_WEIGHT: 2.0
133
+ CAPTION_WEIGHT: 2.0
134
+ COST_SPATIAL:
135
+ CLASS_WEIGHT: 5.0
136
+ MASK_WEIGHT: 2.0
137
+ DICE_WEIGHT: 2.0
138
+ HIDDEN_DIM: 512
139
+ NUM_OBJECT_QUERIES: 101
140
+ NHEADS: 8
141
+ DROPOUT: 0.0
142
+ DIM_FEEDFORWARD: 2048
143
+ MAX_SPATIAL_LEN: [512, 512, 512, 512]
144
+ # ENC_LAYERS: 0
145
+ PRE_NORM: False
146
+ ENFORCE_INPUT_PROJ: False
147
+ SIZE_DIVISIBILITY: 32
148
+ TRAIN_NUM_POINTS: 12544
149
+ OVERSAMPLE_RATIO: 3.0
150
+ IMPORTANCE_SAMPLE_RATIO: 0.75
151
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
152
+ TOP_GROUNDING_LAYERS: 10
153
+ TOP_CAPTION_LAYERS: 10
154
+ TOP_SPATIAL_LAYERS: 10
155
+ TOP_OPENIMAGE_LAYERS: 10
156
+ TEST:
157
+ SEMANTIC_ON: True
158
+ INSTANCE_ON: True
159
+ PANOPTIC_ON: True
160
+ OVERLAP_THRESHOLD: 0.8
161
+ OBJECT_MASK_THRESHOLD: 0.4
162
+ SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false
163
+ DETECTIONS_PER_IMAGE: 100
164
+
165
+ # Multi-modal Architecture, order matters
166
+ ATTENTION_ARCH:
167
+ VARIABLE:
168
+ queries: ['object']
169
+ tokens: ['grounding', 'spatial', 'visual', 'audio']
170
+ SELF_ATTENTION:
171
+ queries:
172
+ object: ['queries_object', 'tokens_grounding', 'tokens_spatial', 'tokens_visual', 'tokens_audio']
173
+ tokens:
174
+ grounding: ['queries_object', 'tokens_grounding']
175
+ spatial: ['tokens_spatial']
176
+ visual: ['tokens_visual']
177
+ audio: ['queries_object', 'tokens_audio']
178
+ CROSS_ATTENTION:
179
+ queries:
180
+ object: True
181
+ tokens:
182
+ grounding: False
183
+ spatial: False
184
+ visual: False
185
+ audio: False
186
+ MASKING: ['tokens_spatial', 'tokens_grounding', 'tokens_visual', 'tokens_audio']
187
+ DUPLICATION:
188
+ queries:
189
+ grounding: 'queries_object'
190
+ spatial: 'queries_object'
191
+ SPATIAL_MEMORIES: 32
192
+
193
+ INPUT:
194
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
195
+ PIXEL_STD: [58.395, 57.120, 57.375]
196
+ # INPUT:
197
+ # PIXEL_MEAN: [64.284, 59.293, 59.962]
198
+ # PIXEL_STD: [62.484, 60.865, 59.835]
datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from . import registration
+ from .build import build_train_dataloader, build_eval_dataloader, build_evaluator
datasets/build.py ADDED
@@ -0,0 +1,630 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Modified by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+ # Copyright (c) Facebook, Inc. and its affiliates.
8
+
9
+ import os
10
+ import numpy as np
11
+ import itertools
12
+ import logging
13
+ from typing import Any, Callable, Dict, List, Optional, Union
14
+
15
+ import torch
16
+ import torch.utils.data
17
+ import torch.utils.data as torchdata
18
+
19
+ import detectron2.utils.comm as comm
20
+ from detectron2.data.build import (
21
+ build_batch_data_loader,
22
+ load_proposals_into_dataset,
23
+ trivial_batch_collator,
24
+ )
25
+ from detectron2.data import MetadataCatalog
26
+ from detectron2.data.catalog import DatasetCatalog
27
+ from detectron2.data.common import DatasetFromList, MapDataset
28
+ from detectron2.data.dataset_mapper import DatasetMapper
29
+ from detectron2.data.samplers import InferenceSampler, TrainingSampler
30
+ from detectron2.evaluation import (
31
+ CityscapesInstanceEvaluator,
32
+ CityscapesSemSegEvaluator,
33
+ COCOEvaluator,
34
+ DatasetEvaluators,
35
+ LVISEvaluator,
36
+ verify_results,
37
+ )
38
+ from fvcore.common.config import CfgNode
39
+
40
+ from .dataset_mappers import *
41
+ from .evaluation import (InstanceSegEvaluator,
42
+ ClassificationEvaluator,
43
+ SemSegEvaluator,
44
+ RetrievalEvaluator,
45
+ #CaptioningEvaluator,
46
+ COCOPanopticEvaluator,
47
+ GroundingEvaluator,
48
+ InteractiveEvaluator,
49
+ )
50
+ from modeling.utils import configurable
51
+ from utilities.distributed import get_world_size
52
+
53
+ class JointLoader(torchdata.IterableDataset):
54
+ """
55
+ Randomly sample from one of the dataloaders per worker in each iteration.
56
+ The sampling probability is determined by the size of each dataset.
57
+ All examples from one worker (GPU) are from the same dataset in the iteration.
58
+ Mixing is achieved through multiple workers (GPUs).
59
+ """
60
+ def __init__(self, loaders, key_dataset, sample_prob, mixing_level):
61
+ dataset_names = []
62
+ for key, loader in loaders.items():
63
+ name = "{}".format(key.split('_')[0])
64
+ setattr(self, name, loader)
65
+ dataset_names += [name]
66
+ self.dataset_names = dataset_names
67
+ self.key_dataset = key_dataset
68
+ if sample_prob == 'prop':
69
+ self.sample_prob = [len(getattr(self, key)) for key in self.dataset_names]
70
+ elif sample_prob == 'equal':
71
+ self.sample_prob = [1 for key in self.dataset_names]
72
+ elif sample_prob == 'sqrt':
73
+ self.sample_prob = [np.sqrt(len(getattr(self, key))) for key in self.dataset_names]
74
+ self.sample_prob = [p/sum(self.sample_prob) for p in self.sample_prob]
75
+ self.mixing_level = mixing_level
76
+
77
+ # Not sure how expensive `len(getattr(self, name))` is. computing this once and cache.
78
+ # this assumes the len of the underlying data loaders do not change.
79
+ self._len = sum(len(getattr(self, name)) for name in self.dataset_names)
80
+
81
+ def __iter__(self):
82
+ # Reset iterators at the start of each new epoch
83
+ self.iterators = {name: iter(getattr(self, name)) for name in self.dataset_names}
84
+ self._count = 0
85
+ return self
86
+
87
+ def __next__(self):
88
+ while self._count < self._len:
89
+ # Randomly select a dataloader
90
+ name = np.random.choice(self.dataset_names, size=None, replace=False, p=self.sample_prob)
91
+ iterator = self.iterators[name]
92
+
93
+ try:
94
+ # Get next batch from the selected dataloader
95
+ self._count += 1
96
+ return next(iterator)
97
+ except StopIteration:
98
+ # If the selected dataloader is exhausted, reinitialize it
99
+ self.iterators[name] = iter(getattr(self, name))
100
+ raise StopIteration
101
+
102
+ def __len__(self):
103
+ return self._len
104
+
105
+ def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
106
+ """
107
+ Filter out images with none annotations or only crowd annotations
108
+ (i.e., images without non-crowd annotations).
109
+ A common training-time preprocessing on COCO dataset.
110
+
111
+ Args:
112
+ dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
113
+
114
+ Returns:
115
+ list[dict]: the same format, but filtered.
116
+ """
117
+ num_before = len(dataset_dicts)
118
+
119
+ def valid(anns):
120
+ for ann in anns:
121
+ if isinstance(ann, list):
122
+ for instance in ann:
123
+ if instance.get("iscrowd", 0) == 0:
124
+ return True
125
+ else:
126
+ if ann.get("iscrowd", 0) == 0:
127
+ return True
128
+ return False
129
+
130
+ dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
131
+ num_after = len(dataset_dicts)
132
+ logger = logging.getLogger(__name__)
133
+ logger.info(
134
+ "Removed {} images with no usable annotations. {} images left.".format(
135
+ num_before - num_after, num_after
136
+ )
137
+ )
138
+ return dataset_dicts
139
+
140
+
141
+ def get_detection_dataset_dicts(
142
+ dataset_names, filter_empty=True, proposal_files=None
143
+ ):
144
+ """
145
+ Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
146
+
147
+ Args:
148
+ dataset_names (str or list[str]): a dataset name or a list of dataset names
149
+ filter_empty (bool): whether to filter out images without instance annotations
150
+ proposal_files (list[str]): if given, a list of object proposal files
151
+ that match each dataset in `dataset_names`.
152
+
153
+ Returns:
154
+ list[dict]: a list of dicts following the standard dataset dict format.
155
+ """
156
+ if isinstance(dataset_names, str):
157
+ dataset_names = [dataset_names]
158
+ assert len(dataset_names)
159
+
160
+ dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
161
+ for dataset_name, dicts in zip(dataset_names, dataset_dicts):
162
+ assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
163
+
164
+ if proposal_files is not None:
165
+ assert len(dataset_names) == len(proposal_files)
166
+ # load precomputed proposals from proposal files
167
+ dataset_dicts = [
168
+ load_proposals_into_dataset(dataset_i_dicts, proposal_file)
169
+ for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
170
+ ]
171
+
172
+ dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
173
+
174
+ has_instances = "annotations" in dataset_dicts[0]
175
+ if filter_empty and has_instances:
176
+ dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)
177
+
178
+ assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
179
+ return dataset_dicts
180
+
181
+
182
+ def _test_loader_from_config(cfg, dataset_name, mapper=None):
183
+ """
184
+ Uses the given `dataset_name` argument (instead of the names in cfg), because the
185
+ standard practice is to evaluate each test set individually (not combining them).
186
+ """
187
+ if isinstance(dataset_name, str):
188
+ dataset_name = [dataset_name]
189
+
190
+ dataset = get_detection_dataset_dicts(
191
+ dataset_name,
192
+ filter_empty=False,
193
+ proposal_files=None,
194
+ )
195
+ if mapper is None:
196
+ mapper_cfg = CfgNode({'INPUT': cfg['INPUT'], 'MODEL': cfg['MODEL'], 'DATASETS': cfg['DATASETS']})
197
+ mapper = DatasetMapper(mapper_cfg, False)
198
+ assert cfg['TEST']['BATCH_SIZE_TOTAL'] % get_world_size() == 0, "Evaluation total batchsize is not divisible by gpu number"
199
+ #batch_size = cfg['TEST']['BATCH_SIZE_TOTAL'] // get_world_size()
200
+ batch_size = 1
201
+
202
+ return {
203
+ "dataset": dataset,
204
+ "mapper": mapper,
205
+ "num_workers": cfg['DATALOADER']['NUM_WORKERS'],
206
+ "sampler": InferenceSampler(len(dataset)),
207
+ "batch_size": batch_size,
208
+ }
209
+
210
+
211
+ @configurable(from_config=_test_loader_from_config)
212
+ def build_detection_test_loader(
213
+ dataset: Union[List[Any], torchdata.Dataset],
214
+ *,
215
+ mapper: Callable[[Dict[str, Any]], Any],
216
+ sampler: Optional[torchdata.Sampler] = None,
217
+ batch_size: int = 1,
218
+ num_workers: int = 0,
219
+ collate_fn: Optional[Callable[[List[Any]], Any]] = None,
220
+ ) -> torchdata.DataLoader:
221
+ """
222
+ Similar to `build_detection_train_loader`, with default batch size = 1,
223
+ and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
224
+ to produce the exact set of all samples.
225
+
226
+ Args:
227
+ dataset: a list of dataset dicts,
228
+ or a pytorch dataset (either map-style or iterable). They can be obtained
229
+ by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
230
+ mapper: a callable which takes a sample (dict) from dataset
231
+ and returns the format to be consumed by the model.
232
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
233
+ sampler: a sampler that produces
234
+ indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
235
+ which splits the dataset across all workers. Sampler must be None
236
+ if `dataset` is iterable.
237
+ batch_size: the batch size of the data loader to be created.
238
+ Default to 1 image per worker since this is the standard when reporting
239
+ inference time in papers.
240
+ num_workers: number of parallel data loading workers
241
+ collate_fn: same as the argument of `torch.utils.data.DataLoader`.
242
+ Defaults to do no collation and return a list of data.
243
+
244
+ Returns:
245
+ DataLoader: a torch DataLoader, that loads the given detection
246
+ dataset, with test-time transformation and batching.
247
+
248
+ Examples:
249
+ ::
250
+ data_loader = build_detection_test_loader(
251
+ DatasetRegistry.get("my_test"),
252
+ mapper=DatasetMapper(...))
253
+
254
+ # or, instantiate with a CfgNode:
255
+ data_loader = build_detection_test_loader(cfg, "my_test")
256
+ """
257
+
258
+ if isinstance(dataset, list):
259
+ dataset = DatasetFromList(dataset, copy=False)
260
+ if mapper is not None:
261
+ dataset = MapDataset(dataset, mapper)
262
+ if isinstance(dataset, torchdata.IterableDataset):
263
+ assert sampler is None, "sampler must be None if dataset is IterableDataset"
264
+ else:
265
+ if sampler is None:
266
+ sampler = InferenceSampler(len(dataset))
267
+ return torchdata.DataLoader(
268
+ dataset,
269
+ batch_size=batch_size,
270
+ sampler=sampler,
271
+ drop_last=False,
272
+ num_workers=num_workers,
273
+ collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
274
+ )
275
+
276
+
277
+ def _train_loader_from_config(cfg, dataset_name, mapper, *, dataset=None, sampler=None):
278
+ cfg_datasets = cfg['DATASETS']
279
+ cfg_dataloader = cfg['DATALOADER']
280
+
281
+ if dataset is None:
282
+ dataset = get_detection_dataset_dicts(
283
+ dataset_name,
284
+ filter_empty=cfg_dataloader['FILTER_EMPTY_ANNOTATIONS'],
285
+ proposal_files=cfg_datasets['PROPOSAL_FILES_TRAIN'] if cfg_dataloader['LOAD_PROPOSALS'] else None,
286
+ )
287
+
288
+ if mapper is None:
289
+ mapper = DatasetMapper(cfg, True)
290
+
291
+ if sampler is None:
292
+ sampler_name = cfg_dataloader['SAMPLER_TRAIN']
293
+ logger = logging.getLogger(__name__)
294
+ logger.info("Using training sampler {}".format(sampler_name))
295
+ sampler = TrainingSampler(len(dataset))
296
+
297
+ return {
298
+ "dataset": dataset,
299
+ "sampler": sampler,
300
+ "mapper": mapper,
301
+ "total_batch_size": cfg['TRAIN']['BATCH_SIZE_TOTAL'],
302
+ "aspect_ratio_grouping": cfg_dataloader['ASPECT_RATIO_GROUPING'],
303
+ "num_workers": cfg_dataloader['NUM_WORKERS'],
304
+ }
305
+
306
+
307
+ @configurable(from_config=_train_loader_from_config)
308
+ def build_detection_train_loader(
309
+ dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
310
+ ):
311
+ """
312
+ Build a dataloader for object detection with some default features.
313
+ This interface is experimental.
314
+
315
+ Args:
316
+ dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
317
+ or a map-style pytorch dataset. They can be obtained by using
318
+ :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
319
+ mapper (callable): a callable which takes a sample (dict) from dataset and
320
+ returns the format to be consumed by the model.
321
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
322
+ sampler (torch.utils.data.sampler.Sampler or None): a sampler that
323
+ produces indices to be applied on ``dataset``.
324
+ Default to :class:`TrainingSampler`, which coordinates a random shuffle
325
+ sequence across all workers.
326
+ total_batch_size (int): total batch size across all workers. Batching
327
+ simply puts data into a list.
328
+ aspect_ratio_grouping (bool): whether to group images with similar
329
+ aspect ratio for efficiency. When enabled, it requires each
330
+ element in dataset be a dict with keys "width" and "height".
331
+ num_workers (int): number of parallel data loading workers
332
+
333
+ Returns:
334
+ torch.utils.data.DataLoader: a dataloader. Each output from it is a
335
+ ``list[mapped_element]`` of length ``total_batch_size / num_workers``,
336
+ where ``mapped_element`` is produced by the ``mapper``.
337
+ """
338
+ if isinstance(dataset, list):
339
+ dataset = DatasetFromList(dataset, copy=False)
340
+ if mapper is not None:
341
+ dataset = MapDataset(dataset, mapper)
342
+ if sampler is None:
343
+ sampler = TrainingSampler(len(dataset))
344
+ assert isinstance(sampler, torch.utils.data.sampler.Sampler)
345
+ return build_batch_data_loader(
346
+ dataset,
347
+ sampler,
348
+ total_batch_size,
349
+ aspect_ratio_grouping=aspect_ratio_grouping,
350
+ num_workers=num_workers,
351
+ )
352
+
353
+
354
+ def get_config_from_name(cfg, dataset_name):
355
+ # adjust config according to dataset
356
+ if 'refcoco' in dataset_name:
357
+ cfg.update(cfg['REF'])
358
+ return cfg
359
+ elif 'cocomini' in dataset_name:
360
+ cfg.update(cfg['DAVIS'])
361
+ return cfg
362
+ elif 'ytvos' in dataset_name:
363
+ cfg.update(cfg['VOS'])
364
+ return cfg
365
+ elif 'ade600' in dataset_name:
366
+ cfg.update(cfg['DAVIS'])
367
+ return cfg
368
+ elif 'openimage600' in dataset_name:
369
+ cfg.update(cfg['DAVIS'])
370
+ return cfg
371
+ elif 'ade' in dataset_name:
372
+ if 'ADE20K' in cfg.keys():
373
+ cfg.update(cfg['ADE20K'])
374
+ return cfg
375
+ elif 'imagenet' in dataset_name:
376
+ if 'IMAGENET' in cfg.keys():
377
+ cfg.update(cfg['IMAGENET'])
378
+ return cfg
379
+ elif 'vlp' in dataset_name:
380
+ cfg.update(cfg['VLP'])
381
+ return cfg
382
+ elif 'coco' in dataset_name:
383
+ if 'COCO' in cfg.keys():
384
+ cfg.update(cfg['COCO'])
385
+ return cfg
386
+ elif 'voc' in dataset_name:
387
+ cfg.update(cfg['VOC'])
388
+ return cfg
389
+ elif 'context' in dataset_name:
390
+ cfg.update(cfg['CONTEXT'])
391
+ return cfg
392
+ elif 'sun' in dataset_name:
393
+ cfg.update(cfg['SUN'])
394
+ return cfg
395
+ elif 'scan' in dataset_name:
396
+ cfg.update(cfg['SCAN'])
397
+ return cfg
398
+ elif 'cityscape' in dataset_name:
399
+ cfg.update(cfg['CITY'])
400
+ return cfg
401
+ elif 'bdd' in dataset_name:
402
+ cfg.update(cfg['BDD'])
403
+ return cfg
404
+ elif 'tsv' in dataset_name:
405
+ cfg.update(cfg['TSV'])
406
+ return cfg
407
+ elif 'phrasecut' in dataset_name:
408
+ cfg.update(cfg['PHRASE'])
409
+ return cfg
410
+ elif 'object365' in dataset_name:
411
+ cfg.update(cfg['OBJECT365'])
412
+ return cfg
413
+ elif 'openimage' in dataset_name:
414
+ cfg.update(cfg['OPENIMAGE'])
415
+ return cfg
416
+ elif 'lvis' in dataset_name:
417
+ cfg.update(cfg['LVIS'])
418
+ return cfg
419
+ elif 'seginw' in dataset_name:
420
+ cfg.update(cfg['SEGINW'])
421
+ return cfg
422
+ elif 'sbd' in dataset_name:
423
+ cfg.update(cfg['SBD'])
424
+ return cfg
425
+ elif 'davis' in dataset_name:
426
+ cfg.update(cfg['DAVIS'])
427
+ return cfg
428
+ elif 'med_sam' in dataset_name:
429
+ cfg.update(cfg['MedSAM'])
430
+ return cfg
431
+ elif 'biomed' in dataset_name:
432
+ cfg.update(cfg['BioMed'])
433
+ return cfg
434
+ elif 'sam' in dataset_name:
435
+ cfg.update(cfg['SAM'])
436
+ return cfg
437
+ else:
438
+ assert False, "dataset not supported."
439
+
440
+
441
+ def build_eval_dataloader(cfg, ):
442
+ dataloaders = []
443
+ for dataset_name in cfg['DATASETS']['TEST']:
444
+ cfg = get_config_from_name(cfg, dataset_name)
445
+ # adjust mapper according to dataset
446
+ if dataset_name == 'imagenet_val':
447
+ mapper = ImageNetDatasetMapper(cfg, False)
448
+ elif dataset_name == 'bdd10k_val_sem_seg':
449
+ mapper = BDDSemDatasetMapper(cfg, False)
450
+ elif dataset_name in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017"]:
451
+ mapper = VLPreDatasetMapper(cfg, False, dataset_name)
452
+ elif dataset_name in ["scannet_21_val_seg", "scannet_38_val_seg", "scannet_41_val_seg"]:
453
+ mapper = ScanNetSegDatasetMapper(cfg, False)
454
+ elif dataset_name in ["scannet_21_panoptic_val", 'bdd10k_40_panoptic_val']:
455
+ mapper = ScanNetPanoDatasetMapper(cfg, False)
456
+ elif "pascalvoc_val" in dataset_name:
457
+ mapper = PascalVOCSegDatasetMapperIX(cfg, False, dataset_name)
458
+ elif 'sun' in dataset_name:
459
+ mapper = SunRGBDSegDatasetMapper(cfg, False)
460
+ elif 'refcoco' in dataset_name:
461
+ mapper = RefCOCODatasetMapper(cfg, False)
462
+ elif 'med_sam' in dataset_name:
463
+ mapper = MedSAMDatasetMapper(cfg, False)
464
+ elif 'biomed' in dataset_name:
465
+ mapper = BioMedDatasetMapper(cfg, False)
466
+ else:
467
+ mapper = None
468
+ dataloaders += [build_detection_test_loader(cfg, dataset_name, mapper=mapper)]
469
+ return dataloaders
470
+
471
+
472
+ def build_train_dataloader(cfg, ):
473
+ dataset_names = cfg['DATASETS']['TRAIN']
474
+
475
+ loaders = {}
476
+ for dataset_name in dataset_names:
477
+ cfg = get_config_from_name(cfg, dataset_name)
478
+ mapper_name = cfg['INPUT']['DATASET_MAPPER_NAME']
479
+ # Semantic segmentation dataset mapper
480
+ if mapper_name == "mask_former_semantic":
481
+ mapper = MaskFormerSemanticDatasetMapper(cfg, True)
482
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
483
+ # Panoptic segmentation dataset mapper
484
+ elif mapper_name == "mask_former_panoptic":
485
+ mapper = MaskFormerPanopticDatasetMapper(cfg, True)
486
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
487
+ # Instance segmentation dataset mapper
488
+ elif mapper_name == "mask_former_instance":
489
+ mapper = MaskFormerInstanceDatasetMapper(cfg, True)
490
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
491
+ # coco instance segmentation lsj new baseline
492
+ elif mapper_name == "coco_instance_lsj":
493
+ mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)
494
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
495
+ # coco panoptic segmentation lsj new baseline
496
+ elif mapper_name == "coco_panoptic_lsj":
497
+ mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True)
498
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
499
+ elif mapper_name == "vlpretrain":
500
+ mapper = VLPreDatasetMapper(cfg, True, dataset_name)
501
+ loaders['vlp'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
502
+ elif mapper_name == "refcoco":
503
+ mapper = RefCOCODatasetMapper(cfg, True)
504
+ loaders['ref'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
505
+ elif mapper_name == "coco_interactive":
506
+ mapper = COCOPanopticInteractiveDatasetMapper(cfg, True)
507
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
508
+ elif mapper_name == "medsam_interactive":
509
+ mapper = MedSAMDatasetMapper(cfg, True)
510
+ loaders['med_sam'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
511
+ elif mapper_name == "biomed_interactive":
512
+ mapper = BioMedDatasetMapper(cfg, True)
513
+ name_key = dataset_name.split("_")[1]
514
+ loaders[name_key] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
515
+ else:
516
+ mapper = None
517
+ loaders[dataset_name] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
518
+
519
+ if len(loaders) == 1 or not cfg['LOADER'].get('JOINT', False):
520
+ return list(loaders.values())[0]
521
+ else:
522
+ sample_prob = cfg['LOADER'].get('SAMPLE_PROB', 'prop')
523
+ mixing_level = cfg['LOADER'].get('MIXING_LEVEL', 1)
524
+ return JointLoader(loaders, key_dataset=cfg['LOADER'].get('KEY_DATASET', 'coco'), sample_prob=sample_prob, mixing_level=mixing_level)
525
+
526
+
527
+ def build_evaluator(cfg, dataset_name, output_folder=None):
528
+ """
529
+ Create evaluator(s) for a given dataset.
530
+ This uses the special metadata "evaluator_type" associated with each
531
+ builtin dataset. For your own dataset, you can simply create an
532
+ evaluator manually in your script and do not have to worry about the
533
+ hacky if-else logic here.
534
+ """
535
+ if output_folder is None:
536
+ output_folder = os.path.join(cfg["SAVE_DIR"], "inference")
537
+ evaluator_list = []
538
+ evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
539
+
540
+ # semantic segmentation
541
+ if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
542
+ evaluator_list.append(
543
+ SemSegEvaluator(
544
+ dataset_name,
545
+ distributed=True,
546
+ output_dir=output_folder,
547
+ )
548
+ )
549
+ # instance segmentation
550
+ if evaluator_type == "coco":
551
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
552
+
553
+ cfg_model_decoder_test = cfg["MODEL"]["DECODER"]["TEST"]
554
+ # panoptic segmentation
555
+ if evaluator_type in [
556
+ "coco_panoptic_seg",
557
+ "ade20k_panoptic_seg",
558
+ "cityscapes_panoptic_seg",
559
+ "mapillary_vistas_panoptic_seg",
560
+ "scannet_panoptic_seg",
561
+ "bdd_panoptic_pano"
562
+ ]:
563
+ if cfg_model_decoder_test["PANOPTIC_ON"]:
564
+ evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
565
+ # COCO
566
+ if (evaluator_type == "coco_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]) or evaluator_type == "object365_od":
567
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
568
+ if (evaluator_type == "coco_panoptic_seg" and cfg_model_decoder_test["SEMANTIC_ON"]) or evaluator_type == "coco_sem_seg":
569
+ evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
570
+ # Mapillary Vistas
571
+ if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]:
572
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
573
+ if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg_model_decoder_test["SEMANTIC_ON"]:
574
+ evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
575
+ # Cityscapes
576
+ if evaluator_type == "cityscapes_instance":
577
+ assert (
578
+ torch.cuda.device_count() > comm.get_rank()
579
+ ), "CityscapesEvaluator currently do not work with multiple machines."
580
+ return CityscapesInstanceEvaluator(dataset_name)
581
+ if evaluator_type == "cityscapes_sem_seg":
582
+ assert (
583
+ torch.cuda.device_count() > comm.get_rank()
584
+ ), "CityscapesEvaluator currently do not work with multiple machines."
585
+ return CityscapesSemSegEvaluator(dataset_name)
586
+ if evaluator_type == "cityscapes_panoptic_seg":
587
+ if cfg_model_decoder_test["SEMANTIC_ON"]:
588
+ assert (
589
+ torch.cuda.device_count() > comm.get_rank()
590
+ ), "CityscapesEvaluator currently do not work with multiple machines."
591
+ evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
592
+ if cfg_model_decoder_test["INSTANCE_ON"]:
593
+ assert (
594
+ torch.cuda.device_count() > comm.get_rank()
595
+ ), "CityscapesEvaluator currently do not work with multiple machines."
596
+ evaluator_list.append(CityscapesInstanceEvaluator(dataset_name))
597
+ # ADE20K
598
+ if evaluator_type == "ade20k_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]:
599
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
600
+ # SEGINW
601
+ if evaluator_type == "seginw" and cfg_model_decoder_test["INSTANCE_ON"]:
602
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
603
+ # LVIS
604
+ if evaluator_type == "lvis":
605
+ return LVISEvaluator(dataset_name, output_dir=output_folder)
606
+ # Classification
607
+ if evaluator_type == "classification":
608
+ evaluator_list.append(ClassificationEvaluator(dataset_name, output_folder))
609
+ # Retrieval
610
+ if evaluator_type in ["retrieval"]:
611
+ evaluator_list.append(RetrievalEvaluator(dataset_name, output_folder, cfg['MODEL']['DECODER']['RETRIEVAL']['ENSEMBLE']))
612
+ if evaluator_type == "captioning":
613
+ evaluator_list.append(CaptioningEvaluator(dataset_name, output_folder, MetadataCatalog.get(dataset_name).gt_json))
614
+ if evaluator_type in ["grounding_refcoco", "grounding_phrasecut", "grounding_spatial", "grounding_entity"]:
615
+ evaluator_list.append(GroundingEvaluator(dataset_name))
616
+ # Interactive
617
+ if evaluator_type in ["interactive", "interactive_grounding"]:
618
+ evaluator_list.append(InteractiveEvaluator(dataset_name, output_dir=output_folder, max_clicks=cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'], iou_iter=cfg['STROKE_SAMPLER']['EVAL']['IOU_ITER']))
619
+
620
+ if len(evaluator_list) == 0:
621
+ raise NotImplementedError(
622
+ "no Evaluator for the dataset {} with the type {}".format(
623
+ dataset_name, evaluator_type
624
+ )
625
+ )
626
+ elif len(evaluator_list) == 1:
627
+ return evaluator_list[0]
628
+
629
+
630
+ return DatasetEvaluators(evaluator_list)
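
How the three builders exported from datasets/__init__.py are meant to be combined is easiest to see in a short sketch. This assumes `cfg` is the nested config dict loaded from configs/biomed_seg_lang_v1.yaml and that the biomed datasets have already been registered (datasets/registration/register_biomed_datasets.py).

```python
# Sketch: wiring the dataloader/evaluator builders together under the training config.
from utilities.arguments import load_opt_from_config_files
from datasets import build_train_dataloader, build_eval_dataloader, build_evaluator

cfg = load_opt_from_config_files(["configs/biomed_seg_lang_v1.yaml"])

train_loader = build_train_dataloader(cfg)   # a JointLoader when several TRAIN datasets are configured
eval_loaders = build_eval_dataloader(cfg)    # one test loader per entry in DATASETS.TEST
evaluators = [build_evaluator(cfg, name) for name in cfg['DATASETS']['TEST']]
```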
datasets/dataset_mappers/__init__.py ADDED
@@ -0,0 +1 @@
+ from .biomed_dataset_mapper import BioMedDatasetMapper
datasets/dataset_mappers/biomed_dataset_mapper.py ADDED
@@ -0,0 +1,378 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
3
+ import copy
4
+ import logging
5
+ import random
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from transformers import AutoTokenizer, LlamaForCausalLM
11
+
12
+ from detectron2.data import detection_utils as utils
13
+ from detectron2.data import transforms as T
14
+ from detectron2.data.transforms import TransformGen
15
+ from detectron2.structures import BitMasks, Boxes, Instances, BoxMode
16
+ from detectron2.structures.boxes import pairwise_iou
17
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
18
+ from detectron2.data import MetadataCatalog
19
+ from pycocotools import mask as coco_mask
20
+
21
+ from utilities import prompt_engineering
22
+ from modeling.language import build_tokenizer
23
+ from modeling.language.misc import text_noun_with_prompt_all
24
+ from modeling.utils import configurable
25
+
26
+ from ..visual_sampler.sampler import build_shape_sampler
27
+
28
+ __all__ = ["BioMedDatasetMapper"]
29
+
30
+
31
+ def build_transform_gen(cfg, is_train):
32
+ """
33
+ Create a list of default :class:`Augmentation` from config.
34
+ Now it includes resizing and flipping.
35
+ Returns:
36
+ list[Augmentation]
37
+ """
38
+ assert is_train, "Only support training augmentation"
39
+ cfg_input = cfg['INPUT']
40
+ image_size = cfg_input['IMAGE_SIZE']
41
+ min_scale = cfg_input['MIN_SCALE']
42
+ max_scale = cfg_input['MAX_SCALE']
43
+
44
+ augmentation = []
45
+
46
+ if cfg_input['RANDOM_FLIP'] != "none":
47
+ augmentation.append(
48
+ T.RandomFlip(
49
+ horizontal=cfg_input['RANDOM_FLIP'] == "horizontal",
50
+ vertical=cfg_input['RANDOM_FLIP'] == "vertical",
51
+ )
52
+ )
53
+
54
+ augmentation.extend([
55
+ T.ResizeScale(
56
+ min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
57
+ ),
58
+ T.FixedSizeCrop(crop_size=(image_size, image_size)),
59
+ ])
60
+
61
+ return augmentation
62
+
63
+ def build_transform_gen_se(cfg, is_train):
64
+ # min_scale = cfg['INPUT']['MIN_SIZE_TEST']
65
+ # max_scale = cfg['INPUT']['MAX_SIZE_TEST']
66
+
67
+ augmentation = []
68
+ # augmentation.extend([
69
+ # T.ResizeShortestEdge(
70
+ # min_scale, max_size=max_scale
71
+ # ),
72
+ # ])
73
+ return augmentation
74
+
75
+ def convert_coco_poly_to_mask(segmentations, height, width):
76
+ masks = []
77
+ for polygons in segmentations:
78
+ rles = coco_mask.frPyObjects(polygons, height, width)
79
+ mask = coco_mask.decode(rles)
80
+ if len(mask.shape) < 3:
81
+ mask = mask[..., None]
82
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
83
+ mask = mask.any(dim=2)
84
+ masks.append(mask)
85
+ if masks:
86
+ masks = torch.stack(masks, dim=0)
87
+ else:
88
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
89
+ return masks
90
+
91
+ # Adapted from a mapper originally designed for the COCO dataset.
92
+ class BioMedDatasetMapper:
93
+ """
94
+ A callable which takes a dataset dict in Detectron2 Dataset format,
95
+ and maps it into a format used by MaskFormer.
96
+
97
+ This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
98
+
99
+ The callable currently does the following:
100
+
101
+ 1. Reads the image from "file_name"
102
+ 2. Applies geometric transforms to the image and annotation
103
+ 3. Finds and applies suitable cropping to the image and annotation
104
+ 4. Prepares the image and annotation as Tensors
105
+ """
106
+
107
+ @configurable
108
+ def __init__(
109
+ self,
110
+ is_train=True,
111
+ *,
112
+ tfm_gens,
113
+ image_format,
114
+ caption_thres,
115
+ grounding,
116
+ lvis,
117
+ lvis_thres,
118
+ max_grounding_num,
119
+ shape_sampler,
120
+ retrieval,
121
+ max_token_num,
122
+ tokenizer,
123
+ binary_classes: bool,
124
+ rotate: bool,
125
+ ):
126
+ """
127
+ NOTE: this interface is experimental.
128
+ Args:
129
+ is_train: for training or inference
130
+ augmentations: a list of augmentations or deterministic transforms to apply
131
+ crop_gen: crop augmentation
132
+ tfm_gens: data augmentation
133
+ image_format: an image format supported by :func:`detection_utils.read_image`.
134
+ """
135
+ self.tfm_gens = tfm_gens
136
+ logging.getLogger(__name__).info(
137
+ "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
138
+ str(self.tfm_gens)
139
+ )
140
+ )
141
+
142
+ self.img_format = image_format
143
+ self.is_train = is_train
144
+ self.caption_thres = caption_thres
145
+ self.grounding = grounding
146
+ self.lvis = lvis
147
+ self.lvis_thres = lvis_thres
148
+ self.max_grounding_num = max_grounding_num
149
+
150
+ self.shape_sampler = shape_sampler
151
+
152
+ self.retrieval = retrieval
153
+ self.tokenizer = tokenizer
154
+ self.max_token_num = max_token_num
155
+
156
+ self.binary_classes = binary_classes
157
+ self.rotate = rotate
158
+
159
+ @classmethod
160
+ def from_config(cls, cfg, is_train=True):
161
+ # Build augmentation
162
+ if is_train:
163
+ tfm_gens = build_transform_gen(cfg, is_train)
164
+ else:
165
+ tfm_gens = build_transform_gen_se(cfg, is_train)
166
+
167
+ shape_sampler = build_shape_sampler(cfg)
168
+
169
+ retrieval = cfg['MODEL']['DECODER']['RETRIEVAL']['ENABLED']
170
+ tokenizer, max_token_num = None, None
171
+ if retrieval:
172
+ lang_model = cfg['MODEL']['TEXT']['NAME']
173
+ max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
174
+ if 'llama' in lang_model:
175
+ tokenizer = AutoTokenizer.from_pretrained(lang_model, padding_side='right')
176
+ tokenizer.pad_token = tokenizer.eos_token
177
+ else:
178
+ tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
179
+
180
+ ret = {
181
+ "is_train": is_train,
182
+ "tfm_gens": tfm_gens,
183
+ "image_format": cfg['INPUT']['FORMAT'],
184
+ "caption_thres": cfg['MODEL']['DECODER']['CAPTION']['SIM_THRES'],
185
+ "grounding": cfg['MODEL']['DECODER']['GROUNDING']['ENABLED'],
186
+ "lvis": cfg['MODEL']['DECODER']['LVIS']['ENABLED'],
187
+ "lvis_thres": cfg['MODEL']['DECODER']['LVIS']['THRES'],
188
+ "max_grounding_num": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'],
189
+ "shape_sampler": shape_sampler,
190
+ "retrieval": retrieval,
191
+ "max_token_num": max_token_num,
192
+ "tokenizer": tokenizer,
193
+ "binary_classes": cfg['MODEL']['ENCODER']['BINARY_CLASSES'],
194
+ "rotate": cfg['INPUT']['RANDOM_ROTATE'],
195
+ }
196
+ return ret
197
+
198
+ def __call__(self, dataset_dict):
199
+ """
200
+ Args:
201
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
202
+
203
+ Returns:
204
+ dict: a format that builtin models in detectron2 accept
205
+ """
206
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
207
+ while True:
208
+ try:
209
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
210
+ break
211
+ except:
212
+ print('Image loading error:', dataset_dict["file_name"])
213
+
214
+ utils.check_image_size(dataset_dict, image)
215
+
216
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
217
+ image_shape = image.shape[:2] # h, w
218
+
219
+ rotate_time = 0
220
+ if self.is_train and self.rotate and random.random() < 0.5:
221
+ rotate_time = random.randint(1, 3)
222
+ if rotate_time > 0:
223
+ image = np.rot90(image, rotate_time)
224
+
225
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
226
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
227
+ # Therefore it's important to use torch.Tensor.
228
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
229
+
230
+
231
+ grounding_anno = dataset_dict['grounding_info']
232
+ if len(grounding_anno) == 0:
233
+ print(dataset_dict['file_name'])
234
+ assert len(grounding_anno) > 0
235
+ masks_grd = []
236
+ texts_grd = []
237
+ boxes_grd = []
238
+ hash_grd = []
239
+ classes = []
240
+ masks_orig = []
241
+ for ann in grounding_anno:
242
+ if 'segmentation' in ann:
243
+ if len(ann['segmentation']) == 0:
244
+ print('Empty segmentation!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
245
+ continue
246
+ rle = coco_mask.frPyObjects(
247
+ ann['segmentation'], dataset_dict['height'], dataset_dict['width'])
248
+ m = coco_mask.decode(rle)
249
+ masks_orig.append(m)
250
+ # sometimes there are multiple binary maps (corresponding to multiple segmentations)
251
+ m = np.sum(m, axis=2)
252
+ else:
253
+ # directly read from mask file
254
+ while True:
255
+ try:
256
+ m = utils.read_image(ann["mask_file"], format=self.img_format)
257
+ break
258
+ except:
259
+ print('Image loading error:', ann["mask_file"])
260
+ m = np.sum(m, axis=2)
261
+ m = 1 * (m > 0)
262
+ m = m.astype(np.uint8) # convert to np.uint8
263
+ m = transforms.apply_segmentation(255*m[:,:,None])[:,:,0]
264
+ if rotate_time > 0:
265
+ m = np.rot90(m, rotate_time)
266
+ masks_grd += [m]
267
+ rand_id = random.randint(0, len(ann['sentences'])-1)
268
+ texts_grd.append(ann['sentences'][rand_id]['raw'].lower())
269
+ hash_grd.append(hash(ann['sentences'][rand_id]['raw'].lower()))
270
+ if self.binary_classes:
271
+ ann["category_id"] = 1 * (ann["category_id"] > 0)
272
+ classes.append(ann["category_id"])
273
+ #masks_grd = torch.from_numpy(np.stack(masks_grd))
274
+ boxes_grd = torch.tensor(boxes_grd)
275
+ groundings = {'masks': masks_grd, 'texts': texts_grd, 'hash': hash_grd, 'mode': 'text'}
276
+ dataset_dict["groundings"] = groundings
277
+
278
+ masks_grd = torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks_grd])
279
+
280
+ instances = Instances(image_shape)
281
+
282
+ instances.gt_masks = BitMasks(masks_grd)
283
+ instances.gt_boxes = BitMasks(masks_grd).get_bounding_boxes()
284
+
285
+ classes = np.array(classes)
286
+ is_things = np.array([1 for _ in classes])
287
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
288
+ instances.is_things = torch.tensor(is_things, dtype=torch.int64)
289
+
290
+ dataset_dict["instances"] = instances
291
+
292
+
293
+ spatial_query_utils = self.shape_sampler(instances)
294
+ dataset_dict['spatial_query'] = spatial_query_utils
295
+
296
+ if self.retrieval:
297
+ captions = dataset_dict['captions']
298
+ tokens = self.tokenizer(
299
+ captions, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
300
+ )
301
+ dataset_dict['tokens'] = {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}
302
+
303
+ if self.grounding:
304
+ grounding_anno = dataset_dict['grounding_info']
305
+ grounding_len = random.randint(1, self.max_grounding_num-1)
306
+ if len(grounding_anno) > 0:
307
+ masks_grd = []
308
+ texts_grd = []
309
+ mode = 'text'
310
+ random.shuffle(grounding_anno)
311
+ for ann in grounding_anno:
312
+ if 'segmentation' in ann:
313
+ if len(ann['segmentation']) == 0:
314
+ print('Empty segmentation!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
315
+ continue
316
+ rle = coco_mask.frPyObjects(
317
+ ann['segmentation'], dataset_dict['height'], dataset_dict['width'])
318
+ m = coco_mask.decode(rle)
319
+ # sometimes there are multiple binary maps (corresponding to multiple segmentations)
320
+ m = np.sum(m, axis=2)
321
+ else:
322
+ # directly read from mask file
323
+ while True:
324
+ try:
325
+ m = utils.read_image(ann["mask_file"], format=self.img_format)
326
+ break
327
+ except:
328
+ print('Image loading error:', ann["mask_file"])
329
+ m = np.sum(m, axis=2)
330
+ m = 1 * (m > 0)
331
+
332
+ m = m.astype(np.uint8) # convert to np.uint8
333
+ m = transforms.apply_segmentation(m[:,:,None])[:,:,0]
334
+ if rotate_time > 0:
335
+ m = np.rot90(m, rotate_time)
336
+ masks_grd += [m]
337
+ # random select a sentence of a single annotation.
338
+ rand_index = random.randint(0, len(ann['sentences'])-1)
339
+ texts_grd += [ann['sentences'][rand_index]['raw'].lower()]
340
+ # max_len = min(grounding_len, len(texts_grd))
341
+ max_len = len(masks_grd)
342
+ indices = np.random.permutation(max_len)
343
+ texts_grd = list(np.array(texts_grd)[indices])
344
+ masks_grd = torch.tensor(np.stack(masks_grd)[indices])
345
+ hash_grd = np.array([hash(txt) for txt in texts_grd])
346
+ else:
347
+ masks_grd = instances.gt_masks.tensor
348
+ mode = 'class'
349
+ if len(masks_grd) == 0:
350
+ masks_grd = torch.tensor([])
351
+ texts_grd = ['none']
352
+ hash_grd = np.array([hash(txt) for txt in texts_grd])
353
+ else:
354
+ biomed_classes = ['liver', 'lung', 'kidney', 'pancreas', 'heart anatomies', 'brain anatomies',
355
+ 'eye anatomies', 'vessel', 'other organ', 'tumor', 'infection', 'other lesion',
356
+ 'fluid disturbance', 'other abnormality', 'histology structure', 'other']
357
+ if self.binary_classes:
358
+ biomed_classes = ['target']
359
+ texts_grd = np.array(biomed_classes)
360
+ hash_grd = np.array([hash(txt) for txt in texts_grd])
361
+ unique_hash_grd = np.unique(hash_grd)
362
+ np.random.shuffle(unique_hash_grd)
363
+ max_len = min(grounding_len, len(unique_hash_grd))
364
+ indices = np.random.permutation(max_len)
365
+ selected_unique_hash_grd = unique_hash_grd[indices]
366
+ selected_mask = np.in1d(hash_grd, selected_unique_hash_grd)
367
+ texts_grd = texts_grd[selected_mask]
368
+ hash_grd = hash_grd[selected_mask]
369
+ masks_grd = masks_grd[selected_mask]
370
+ texts_grd = [prompt_engineering(text.replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
371
+ for text in texts_grd]
372
+ groundings = {'masks': masks_grd, 'texts': texts_grd, 'mode': mode, 'hash': hash_grd}
373
+ dataset_dict["groundings"] = groundings
374
+ assert len(masks_grd) == len(dataset_dict['grounding_info']), f"len(masks_grd)={len(masks_grd)}, len(dataset_dict['grounding_info'])={len(dataset_dict['grounding_info'])}, mask shape={masks_grd.shape}, max_len={max_len}, grounding_len={grounding_len}, len(texts_grd)={len(texts_grd)}, len(hash_grd)={len(hash_grd)}"
375
+ # gt_masks_orisize = torch.stack([torch.from_numpy(m.squeeze(-1)) for m in masks_orig])
376
+ # dataset_dict['gt_masks_orisize'] = gt_masks_orisize # (nm,h,w)
377
+
378
+ return dataset_dict
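Note (illustrative only, not part of the commit): a rough sketch of what one mapped sample looks like. It assumes the repo's dict-style `cfg` and a Detectron2-format `record` with `file_name`, `height`, `width`, and a non-empty `grounding_info` list; the output keys are read from the mapper above but not verified here.

```python
from datasets.dataset_mappers.biomed_dataset_mapper import BioMedDatasetMapper

mapper = BioMedDatasetMapper(cfg, is_train=True)  # built via the @configurable from_config above
sample = mapper(record)
# Keys produced by the mapper above (as read from the code):
# sample["image"]         -> CHW tensor after resize/crop/flip (and optional rotation)
# sample["instances"]     -> Instances with gt_masks, gt_boxes, gt_classes
# sample["groundings"]    -> {'masks', 'texts', 'hash', 'mode'}
# sample["spatial_query"] -> output of the visual shape sampler
```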
datasets/evaluation/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .instance_evaluation import *
2
+ from .classification_evaluation import *
3
+ from .segmentation_evaluation import *
4
+ from .retrieval_evaluation import *
5
+ #from .captioning_evaluation import *
6
+ from .panoptic_evaluation import *
7
+ from .grounding_evaluation import *
8
+ from .interactive_evaluation import *
datasets/evaluation/captioning_evaluation.py ADDED
@@ -0,0 +1,129 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # --------------------------------------------------------
3
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Modified by Xueyan Zou ([email protected])
7
+ # --------------------------------------------------------
8
+
9
+ import os
10
+ import json
11
+ import logging
12
+ import itertools
13
+
14
+ import detectron2.utils.comm as comm
15
+ from detectron2.evaluation.evaluator import DatasetEvaluator
16
+
17
+ from caption_pycocotools.coco import COCO
18
+ from pycocoevalcap.eval import COCOEvalCap
19
+
20
+
21
+ class CaptioningEvaluator(DatasetEvaluator):
22
+ """
23
+ Evaluate AR for object proposals, AP for instance detection/segmentation, AP
24
+ for keypoint detection outputs using COCO's metrics.
25
+ See http://cocodataset.org/#detection-eval and
26
+ http://cocodataset.org/#keypoints-eval to understand its metrics.
27
+ The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
28
+ the metric cannot be computed (e.g. due to no predictions made).
29
+ In addition to COCO, this evaluator is able to support any bounding box detection,
30
+ instance segmentation, or keypoint detection dataset.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ distributed=True,
36
+ output_dir=None,
37
+ gt_json=None,
38
+ ):
39
+ """
40
+ Args:
41
+ dataset_name (str): name of the dataset to be evaluated.
42
+ It must have either the following corresponding metadata:
43
+ "json_file": the path to the COCO format annotation
44
+ Or it must be in detectron2's standard dataset format
45
+ so it can be converted to COCO format automatically.
46
+ tasks (tuple[str]): tasks that can be evaluated under the given
47
+ configuration. A task is one of "bbox", "segm", "keypoints".
48
+ By default, will infer this automatically from predictions.
49
+ distributed (True): if True, will collect results from all ranks and run evaluation
50
+ in the main process.
51
+ Otherwise, will only evaluate the results in the current process.
52
+ output_dir (str): optional, an output directory to dump all
53
+ results predicted on the dataset. The dump contains two files:
54
+ 1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
55
+ contains all the results in the format they are produced by the model.
56
+ 2. "coco_instances_results.json" a json file in COCO's result format.
57
+ max_dets_per_image (int): limit on the maximum number of detections per image.
58
+ By default in COCO, this limit is to 100, but this can be customized
59
+ to be greater, as is needed in evaluation metrics AP fixed and AP pool
60
+ (see https://arxiv.org/pdf/2102.01066.pdf)
61
+ This doesn't affect keypoint evaluation.
62
+ use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
63
+ Although the results should be very close to the official implementation in COCO
64
+ API, it is still recommended to compute results with the official API for use in
65
+ papers. The faster implementation also uses more RAM.
66
+ kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
67
+ See http://cocodataset.org/#keypoints-eval
68
+ When empty, it will use the defaults in COCO.
69
+ Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
70
+ allow_cached_coco (bool): Whether to use cached coco json from previous validation
71
+ runs. You should set this to False if you need to use different validation data.
72
+ Defaults to True.
73
+ """
74
+ self._logger = logging.getLogger(__name__)
75
+ self._distributed = distributed
76
+ self._output_dir = output_dir
77
+ self._gt_json = COCO(gt_json)
78
+
79
+ def reset(self):
80
+ self._gen_captions = []
81
+ self._image_ids = []
82
+
83
+ def process(self, inputs, outputs):
84
+ """
85
+ Args:
86
+ inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
87
+ It is a list of dict. Each dict corresponds to an image and
88
+ contains keys like "height", "width", "file_name", "image_id".
89
+ outputs: the outputs of a COCO model. It is a list of dicts with key
90
+ "instances" that contains :class:`Instances`.
91
+ """
92
+ for output in outputs:
93
+ self._image_ids.append(output['image_id'])
94
+ self._gen_captions.append(output['captioning_text'])
95
+
96
+ def evaluate(self, img_ids=None):
97
+ """
98
+ Args:
99
+ img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
100
+ """
101
+
102
+ if self._distributed:
103
+ comm.synchronize()
104
+ def gather(x, move=False):
105
+ x = comm.gather(x)
106
+ x = list(itertools.chain(*x))
107
+ if move:
108
+ x = [xx.to(self._gen_captions[0].device) for xx in x]
109
+ return x
110
+ gen_captions = gather(self._gen_captions)
111
+ image_ids = gather(self._image_ids)
112
+ if not comm.is_main_process():
113
+ return {}
114
+ else:
115
+ gen_captions = self._gen_captions
116
+ image_ids = self._image_ids
117
+
118
+ assert len(gen_captions) == len(image_ids)
119
+ pred_captions = [{"image_id": image_id, "caption": gen_caption} for image_id, gen_caption in zip(image_ids, gen_captions)]
120
+ pred_pth = os.path.join(self._output_dir, 'results.json')
121
+ json.dump(pred_captions, open(pred_pth, "w"))
122
+
123
+ gt_captions = self._gt_json
124
+ pred_captions = gt_captions.loadRes(pred_pth)
125
+
126
+ cocoEval = COCOEvalCap(gt_captions, pred_captions)
127
+ cocoEval.params['image_id'] = pred_captions.getImgIds()
128
+ cocoEval.evaluate()
129
+ return cocoEval.eval
datasets/evaluation/classification_evaluation.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # --------------------------------------------------------
3
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Modified by Xueyan Zou ([email protected])
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import logging
11
+
12
+ from detectron2.evaluation.evaluator import DatasetEvaluator
13
+
14
+ from utilities.misc import AverageMeter
15
+ from utilities.distributed import get_world_size
16
+
17
+
18
+ @torch.no_grad()
19
+ def accuracy(output, target, topk=(1,)):
20
+ """Computes the precision@k for the specified values of k"""
21
+ if isinstance(output, list):
22
+ output = output[-1]
23
+
24
+ n_classes = output.size()[1]
25
+ maxk = min(max(topk), n_classes)
26
+ batch_size = target.size(0)
27
+ _, pred = output.topk(maxk, 1, True, True)
28
+ pred = pred.t()
29
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
30
+
31
+ res = []
32
+ for k in topk:
33
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
34
+ res.append(correct_k.mul_(100.0 / batch_size).item())
35
+ return res
36
+
37
+ class ClassificationEvaluator(DatasetEvaluator):
38
+ def __init__(self, *args):
39
+ self.top1 = AverageMeter()
40
+ self.top5 = AverageMeter()
41
+ self._logger = logging.getLogger(__name__)
42
+
43
+ def reset(self):
44
+ self.top1.reset()
45
+ self.top5.reset()
46
+
47
+ def process(self, inputs, outputs):
48
+ logits = torch.stack([o['pred_class'] for o in outputs])
49
+ y = torch.tensor([t['class_id'] for t in inputs], device=logits.device)
50
+ prec1, prec5 = accuracy(logits, y, (1, 5))
51
+ self.top1.update(prec1, y.size(0))
52
+ self.top5.update(prec5, y.size(0))
53
+
54
+ def evaluate(self):
55
+ if get_world_size() > 1:
56
+ tmp_tensor = torch.tensor(
57
+ [self.top1.sum, self.top5.sum, self.top1.count],
58
+ device=torch.cuda.current_device()
59
+ )
60
+ torch.distributed.all_reduce(
61
+ tmp_tensor, torch.distributed.ReduceOp.SUM
62
+ )
63
+ top1_sum, top5_sum, count = tmp_tensor.tolist()
64
+ else:
65
+ top1_sum = self.top1.sum
66
+ top5_sum = self.top5.sum
67
+ count = self.top1.count
68
+
69
+ results = {}
70
+ scores = {
71
+ 'top1': top1_sum / count,
72
+ "top5": top5_sum / count
73
+ }
74
+ results['class'] = scores
75
+ self._logger.info(results)
76
+ return results
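Note (illustrative only): a quick sanity check of the top-k `accuracy` helper above on toy logits, assuming the package is importable as `datasets.evaluation.classification_evaluation`.

```python
import torch
from datasets.evaluation.classification_evaluation import accuracy

logits = torch.tensor([[0.10, 0.70, 0.20],   # sample 0: top-1 is class 1 (correct)
                       [0.80, 0.05, 0.15]])  # sample 1: top-1 is class 0 (wrong), class 2 is in top-2
target = torch.tensor([1, 2])
top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1, top2)  # expected: 50.0 100.0
```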
datasets/evaluation/grounding_evaluation.py ADDED
@@ -0,0 +1,173 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Modified by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+ import logging
8
+ import torch
9
+ from torchvision.ops import box_iou
10
+
11
+ from detectron2.structures import BoxMode
12
+ from detectron2.data import MetadataCatalog
13
+ from detectron2.utils.comm import all_gather, is_main_process, synchronize
14
+ from detectron2.evaluation.evaluator import DatasetEvaluator
15
+
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ import os
19
+
20
+ import copy
21
+
22
+ class GroundingEvaluator(DatasetEvaluator):
23
+ """
24
+ Evaluate grounding segmentation metrics.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ dataset_name,
30
+ compute_box=False,
31
+ distributed=True,
32
+ ):
33
+ self._logger = logging.getLogger(__name__)
34
+ self._dataset_name = dataset_name
35
+ self._distributed = distributed
36
+ self._cpu_device = torch.device("cpu")
37
+ self._compute_box = compute_box
38
+ meta = MetadataCatalog.get(dataset_name)
39
+
40
+ def reset(self):
41
+ self.cum_I = 0
42
+ self.cum_U = 0
43
+ self.mIoU = 0
44
+ self.mDice = 0
45
+ self.cum_mean_area = 0
46
+ self.eval_seg_iou_list = [.5, .6, .7, .8, .9]
47
+ self.seg_correct = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
48
+ self.seg_total = 0
49
+ self.instance_results = []
50
+ if self._compute_box:
51
+ self.mIoU_box = 0
52
+ self.seg_correct_box = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
53
+
54
+ @staticmethod
55
+ def computeIoU(pred_seg, gd_seg):
56
+ I = (pred_seg & gd_seg)
57
+ U = (pred_seg | gd_seg)
58
+ return I, U
59
+
60
+ def get_metadata(self, _input):
61
+ """
62
+ Extracts and returns specific metadata from the input dictionary.
63
+
64
+ Parameters:
65
+ _input (dict): A dictionary containing keys like 'file_name', 'image_id', and 'grounding_info'.
66
+ The 'grounding_info' is a list of dictionaries with keys like 'area', 'iscrowd', etc.
67
+
68
+ Returns:
69
+ dict: A dictionary containing filtered metadata.
70
+ """
71
+
72
+ _input = copy.deepcopy(_input)
73
+
74
+ selected_input_keys = ['file_name', 'image_id', 'grounding_info']
75
+ selected_grounding_info_keys = ['area', 'mask_file', 'iscrowd', 'image_id', 'category_id', 'id', 'file_name', 'split', 'ann_id', 'ref_id']
76
+
77
+ filtered_input = {key: _input[key] for key in selected_input_keys if key in _input}
78
+
79
+ # Check if grounding_info is present and is a list
80
+ if 'grounding_info' in filtered_input and isinstance(filtered_input['grounding_info'], list):
81
+ # Filter each grounding_info dictionary
82
+ filtered_input['grounding_info'] = [
83
+ {key: info[key] for key in selected_grounding_info_keys if key in info}
84
+ for info in filtered_input['grounding_info']
85
+ ]
86
+
87
+ return filtered_input
88
+
89
+ def process(self, inputs, outputs):
90
+ for input, output in zip(inputs, outputs):
91
+ pred = output['grounding_mask'].sigmoid() > 0.5
92
+ # # save pixel probability
93
+ # prob = output['grounding_mask'].sigmoid().cpu().numpy()[0] * 255
94
+ # pred_file = input['file_name'].split('.')[0].replace('test/', 'test_pred/') + '_' + input['groundings']['texts'][0].replace(' ', '+') + '.png'
95
+ # if not os.path.exists('/'.join(pred_file.split('/')[:-1])):
96
+ # os.makedirs('/'.join(pred_file.split('/')[:-1]), exist_ok=True)
97
+ # plt.imsave(pred_file,
98
+ # prob.astype(np.uint8), cmap='gray')
99
+
100
+ gt = input['groundings']['masks'].bool()
101
+ bsi = len(pred)
102
+ I, U = self.computeIoU(pred, gt)
103
+ self.cum_I += I.sum().cpu()
104
+ self.cum_U += U.sum().cpu()
105
+ IoU = I.reshape(bsi,-1).sum(-1)*1.0 / (U.reshape(bsi,-1).sum(-1) + 1e-6)
106
+ self.mIoU += IoU.sum().cpu()
107
+ # Add Dice score in eval
108
+ Dice = I.reshape(bsi,-1).sum(-1)*2.0 / (gt.reshape(bsi,-1).sum(-1) + pred.reshape(bsi,-1).sum(-1) + 1e-6)
109
+ self.mDice += Dice.sum().cpu()
110
+ self.cum_mean_area += ((gt.reshape(bsi,-1).sum(-1) + pred.reshape(bsi,-1).sum(-1)) / 2.0).sum().cpu()
111
+
112
+ if self._compute_box:
113
+ pred_box = BoxMode.convert(output['grounding_box'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
114
+ gt_box = BoxMode.convert(input['groundings']['boxes'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS).cpu()
115
+ IoU_box = box_iou(pred_box, gt_box).diagonal()
116
+ self.mIoU_box += IoU_box.sum()
117
+
118
+ for idx in range(len(self.eval_seg_iou_list)):
119
+ eval_seg_iou = self.eval_seg_iou_list[idx]
120
+ self.seg_correct[idx] += (IoU >= eval_seg_iou).sum().cpu()
121
+ if self._compute_box:
122
+ self.seg_correct_box[idx] += (IoU_box >= eval_seg_iou).sum().cpu()
123
+ self.seg_total += bsi
124
+
125
+ instance_result = {
126
+ 'metadata': self.get_metadata(input),
127
+ 'IoU': IoU.cpu().numpy().tolist(),
128
+ 'Dice': Dice.cpu().numpy().tolist(),
129
+ 'I': I.sum(dim=(1, 2)).cpu().numpy().tolist(),
130
+ 'U': U.sum(dim=(1, 2)).cpu().numpy().tolist(),
131
+ 'IoU_box': IoU_box.cpu().numpy().tolist() if self._compute_box else '',
132
+ 'pred_area': pred.reshape(bsi,-1).sum(-1).cpu().numpy().tolist(),
133
+ }
134
+
135
+ iou_len = IoU.shape[0]
136
+ grounding_info_len = len(self.get_metadata(input)['grounding_info'])
137
+ assert iou_len == grounding_info_len, f'Number of IoU scores ({iou_len}) and grounding info ({grounding_info_len}) do not match.'
138
+ self.instance_results.append(instance_result)
139
+
140
+ def evaluate(self):
141
+ if self._distributed:
142
+ synchronize()
143
+ self.cum_I = torch.stack(all_gather(self.cum_I)).sum()
144
+ self.cum_U = torch.stack(all_gather(self.cum_U)).sum()
145
+ self.mIoU = torch.stack(all_gather(self.mIoU)).sum()
146
+ self.mDice = torch.stack(all_gather(self.mDice)).sum()
147
+ self.cum_mean_area = torch.stack(all_gather(self.cum_mean_area)).sum()
148
+ self.seg_correct = torch.stack(all_gather(self.seg_correct)).sum(0)
149
+ self.seg_total = sum(all_gather(self.seg_total))
150
+ self.instance_results = sum(all_gather(self.instance_results), [])
151
+ if self._compute_box:
152
+ self.mIoU_box = torch.stack(all_gather(self.mIoU_box)).sum()
153
+ self.seg_correct_box = torch.stack(all_gather(self.seg_correct_box)).sum(0)
154
+ if not is_main_process():
155
+ return
156
+
157
+ results = {}
158
+ for idx in range(len(self.eval_seg_iou_list)):
159
+ result_str = 'precision@{}'.format(self.eval_seg_iou_list[idx])
160
+ results[result_str] = (self.seg_correct[idx]*100 / self.seg_total).item()
161
+ results['cIoU'] = (self.cum_I*100./self.cum_U).item()
162
+ results['mIoU'] = (self.mIoU*100./self.seg_total).item()
163
+ results['cDice'] = (self.cum_I*100./self.cum_mean_area).item()
164
+ results['mDice'] = (self.mDice*100./self.seg_total).item()
165
+
166
+ if self._compute_box:
167
+ for idx in range(len(self.eval_seg_iou_list)):
168
+ result_str = 'precisionB@{}'.format(self.eval_seg_iou_list[idx])
169
+ results[result_str] = (self.seg_correct_box[idx]*100 / self.seg_total).item()
170
+ results['mBIoU'] = (self.mIoU_box*100./self.seg_total).item()
171
+
172
+ self._logger.info(results)
173
+ return {'grounding': {'scores': results, 'instance_results': self.instance_results}}
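Note (toy example, not part of the commit): the per-mask arithmetic behind `computeIoU` and the Dice term above, spelled out on a single boolean mask pair.

```python
import torch

pred = torch.zeros(4, 4, dtype=torch.bool); pred[:2, :2] = True  # 4 predicted pixels
gt   = torch.zeros(4, 4, dtype=torch.bool); gt[:2, :3] = True    # 6 ground-truth pixels

I = (pred & gt).sum().float()                  # intersection = 4
U = (pred | gt).sum().float()                  # union = 6
iou  = I / (U + 1e-6)                          # ~0.667
dice = 2 * I / (pred.sum() + gt.sum() + 1e-6)  # 8 / 10 = 0.8
print(iou.item(), dice.item())
```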
datasets/evaluation/instance_evaluation.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import contextlib
3
+ import copy
4
+ import io
5
+ import itertools
6
+ import json
7
+ import logging
8
+ import numpy as np
9
+ import os
10
+ import pickle
11
+ from collections import OrderedDict
12
+ import pycocotools.mask as mask_util
13
+ import torch
14
+ from pycocotools.coco import COCO
15
+ from pycocotools.cocoeval import COCOeval
16
+ from tabulate import tabulate
17
+
18
+ import detectron2.utils.comm as comm
19
+ from detectron2.config import CfgNode
20
+ from detectron2.data import MetadataCatalog
21
+ from detectron2.data.datasets.coco import convert_to_coco_json
22
+ from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
23
+ from detectron2.evaluation.fast_eval_api import COCOeval_opt
24
+ from detectron2.structures import Boxes, BoxMode, pairwise_iou
25
+ from detectron2.utils.file_io import PathManager
26
+ from detectron2.utils.logger import create_small_table
27
+
28
+
29
+ # Modified from COCOEvaluator for instance segmentation
30
+ class InstanceSegEvaluator(COCOEvaluator):
31
+ """
32
+ Evaluate AR for object proposals, AP for instance detection/segmentation, AP
33
+ for keypoint detection outputs using COCO's metrics.
34
+ See http://cocodataset.org/#detection-eval and
35
+ http://cocodataset.org/#keypoints-eval to understand its metrics.
36
+ The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
37
+ the metric cannot be computed (e.g. due to no predictions made).
38
+
39
+ In addition to COCO, this evaluator is able to support any bounding box detection,
40
+ instance segmentation, or keypoint detection dataset.
41
+ """
42
+
43
+ def _eval_predictions(self, predictions, img_ids=None):
44
+ """
45
+ Evaluate predictions. Fill self._results with the metrics of the tasks.
46
+ """
47
+ self._logger.info("Preparing results for COCO format ...")
48
+ coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
49
+ tasks = self._tasks or self._tasks_from_predictions(coco_results)
50
+
51
+ # unmap the category ids for COCO
52
+ if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
53
+ dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
54
+ # all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
55
+ # num_classes = len(all_contiguous_ids)
56
+ # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
57
+
58
+ reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
59
+ for result in coco_results:
60
+ category_id = result["category_id"]
61
+ # assert category_id < num_classes, (
62
+ # f"A prediction has class={category_id}, "
63
+ # f"but the dataset only has {num_classes} classes and "
64
+ # f"predicted class id should be in [0, {num_classes - 1}]."
65
+ # )
66
+ assert category_id in reverse_id_mapping, (
67
+ f"A prediction has class={category_id}, "
68
+ f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
69
+ )
70
+ result["category_id"] = reverse_id_mapping[category_id]
71
+
72
+ if self._output_dir:
73
+ file_path = os.path.join(self._output_dir, "coco_instances_results.json")
74
+ self._logger.info("Saving results to {}".format(file_path))
75
+ with PathManager.open(file_path, "w") as f:
76
+ f.write(json.dumps(coco_results))
77
+ f.flush()
78
+
79
+ if not self._do_evaluation:
80
+ self._logger.info("Annotations are not available for evaluation.")
81
+ return
82
+
83
+ self._logger.info(
84
+ "Evaluating predictions with {} COCO API...".format(
85
+ "unofficial" if self._use_fast_impl else "official"
86
+ )
87
+ )
88
+ for task in sorted(tasks):
89
+ assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
90
+ coco_eval = (
91
+ _evaluate_predictions_on_coco(
92
+ self._coco_api,
93
+ coco_results,
94
+ task,
95
+ kpt_oks_sigmas=self._kpt_oks_sigmas,
96
+ use_fast_impl=self._use_fast_impl,
97
+ img_ids=img_ids,
98
+ max_dets_per_image=self._max_dets_per_image,
99
+ )
100
+ if len(coco_results) > 0
101
+ else None # cocoapi does not handle empty results very well
102
+ )
103
+
104
+ res = self._derive_coco_results(
105
+ coco_eval, task, class_names=self._metadata.get("thing_classes")
106
+ )
107
+ self._results[task] = res
datasets/evaluation/interactive_evaluation.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import os
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torchvision.ops import box_iou
8
+
9
+ from detectron2.structures import BoxMode
10
+ from detectron2.data import MetadataCatalog
11
+ from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize
12
+ from detectron2.evaluation.evaluator import DatasetEvaluator
13
+
14
+
15
+ class InteractiveEvaluator(DatasetEvaluator):
16
+ """
17
+ Evaluate point interactive IoU metrics.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ dataset_name,
23
+ output_dir,
24
+ max_clicks=20,
25
+ iou_iter=1,
26
+ compute_box=False,
27
+ distributed=True,
28
+ ):
29
+ self._logger = logging.getLogger(__name__)
30
+ self._dataset_name = dataset_name
31
+ self._distributed = distributed
32
+ self._cpu_device = torch.device("cpu")
33
+ self._output_dir = output_dir
34
+
35
+ self.max_clicks = max_clicks
36
+ self.iou_iter = iou_iter
37
+ meta = MetadataCatalog.get(dataset_name)
38
+
39
+ def reset(self):
40
+ self.iou_list = []
41
+ self.num_samples = 0
42
+ self.all_ious = [0.5, 0.8, 0.85, 0.9]
43
+
44
+ def process(self, inputs, outputs):
45
+ self.iou_list += [o['mask_iou'] for o in outputs]
46
+ self.num_samples += len(outputs)
47
+
48
+ def compute_noc(self):
49
+ def _get_noc(iou_arr, iou_thr):
50
+ vals = iou_arr >= iou_thr
51
+ return vals.max(dim=0)[1].item() + 1 if vals.any() else self.max_clicks
52
+
53
+ noc_list = {}
54
+ for iou_thr in self.all_ious:
55
+ scores_arr = [_get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list]
56
+ noc_list[str(iou_thr)] = scores_arr
57
+
58
+ iou_before_max_iter = torch.stack(self.iou_list)[:,self.iou_iter-1]
59
+ noc_list_sum = {key:sum(value)*1.0 for key, value in noc_list.items()}
60
+
61
+ if self._distributed:
62
+ num_samples = sum(all_gather(self.num_samples))
63
+ noc_list_sum_gather = all_gather(noc_list_sum)
64
+ iou_before_max_gather = all_gather(iou_before_max_iter.sum().cpu())
65
+
66
+ noc_list_sum = {key: 0 for key in noc_list_sum_gather[0]}
67
+ for nlg in noc_list_sum_gather:
68
+ for key, value in nlg.items():
69
+ noc_list_sum[key] += value
70
+
71
+ pred_noc = {}
72
+ if self._distributed and (not is_main_process()):
73
+ return pred_noc
74
+
75
+ for key, value in noc_list_sum.items():
76
+ pred_noc[key] = value / num_samples
77
+
78
+ pred_noc['iou_max_iter'] = sum([x.item() for x in iou_before_max_gather]) / num_samples
79
+ return pred_noc
80
+
81
+ def evaluate(self):
82
+ pred_noc = self.compute_noc()
83
+
84
+ if self._distributed and (not is_main_process()):
85
+ return
86
+
87
+ def draw_iou_curve(iou_list, save_dir):
88
+ iou_list = torch.stack(iou_list, dim=0)
89
+ iou_list = iou_list.mean(dim=0).cpu().numpy()
90
+ # draw iou curve, with x-axis as number of clicks, y-axis as iou using matplotlib
91
+ import matplotlib.pyplot as plt
92
+ plt.figure()
93
+ plt.plot(range(1, self.max_clicks+1), iou_list)
94
+ plt.xlabel('Number of clicks')
95
+ plt.ylabel('IoU')
96
+
97
+
98
+ # create directory if not exist
99
+ import os
100
+ output_dir = os.path.join(save_dir, 'iou_by_clicks')
101
+ if not os.path.exists(output_dir):
102
+ os.makedirs(output_dir)
103
+
104
+ # get current time and format in 10 digits
105
+ import time
106
+ current_time = time.time()
107
+ current_time = int(current_time)
108
+ current_time = str(current_time)
109
+
110
+ # save iou curve
111
+ plt.savefig(os.path.join(output_dir, '{}.png'.format(current_time)))
112
+
113
+ draw_iou_curve(self.iou_list, self._output_dir)
114
+ results = {}
115
+ for idx in range(len(self.all_ious)):
116
+ result_str = 'noc@{}'.format(self.all_ious[idx])
117
+ results[result_str] = pred_noc[str(self.all_ious[idx])]
118
+
119
+ results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter']
120
+
121
+ self._logger.info(results)
122
+ return {'interactive': results}
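Note (illustrative only): the NoC (number-of-clicks) logic in `compute_noc` above, reduced to a single simulated IoU-per-click trajectory.

```python
import torch

iou_per_click = torch.tensor([0.42, 0.61, 0.78, 0.86, 0.91])  # IoU after clicks 1..5
max_clicks = 5

def noc(iou_arr, thr):
    # First click index at which IoU reaches the threshold, else max_clicks.
    hits = iou_arr >= thr
    return hits.max(dim=0)[1].item() + 1 if hits.any() else max_clicks

print(noc(iou_per_click, 0.85))  # 4 -> threshold first reached on the 4th click
print(noc(iou_per_click, 0.95))  # 5 -> never reached, falls back to max_clicks
```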
datasets/evaluation/panoptic_evaluation.py ADDED
@@ -0,0 +1,199 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import contextlib
3
+ import io
4
+ import itertools
5
+ import json
6
+ import logging
7
+ import numpy as np
8
+ import os
9
+ import tempfile
10
+ from collections import OrderedDict
11
+ from typing import Optional
12
+ from PIL import Image
13
+ from tabulate import tabulate
14
+
15
+ from detectron2.data import MetadataCatalog
16
+ from detectron2.utils import comm
17
+ from detectron2.utils.file_io import PathManager
18
+
19
+ from detectron2.evaluation.evaluator import DatasetEvaluator
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class COCOPanopticEvaluator(DatasetEvaluator):
25
+ """
26
+ Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
27
+ It saves panoptic segmentation prediction in `output_dir`
28
+
29
+ It contains a synchronize call and has to be called from all workers.
30
+ """
31
+
32
+ def __init__(self, dataset_name: str, output_dir: Optional[str] = None):
33
+ """
34
+ Args:
35
+ dataset_name: name of the dataset
36
+ output_dir: output directory to save results for evaluation.
37
+ """
38
+ self._metadata = MetadataCatalog.get(dataset_name)
39
+ self._thing_contiguous_id_to_dataset_id = {
40
+ v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
41
+ }
42
+ self._stuff_contiguous_id_to_dataset_id = {
43
+ v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items()
44
+ }
45
+
46
+ self._output_dir = output_dir
47
+ if self._output_dir is not None:
48
+ PathManager.mkdirs(self._output_dir)
49
+
50
+ def reset(self):
51
+ self._predictions = []
52
+
53
+ def _convert_category_id(self, segment_info):
54
+ isthing = segment_info.pop("isthing", None)
55
+ if isthing is None:
56
+ # the model produces panoptic category id directly. No more conversion needed
57
+ return segment_info
58
+ if isthing is True:
59
+ segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[
60
+ segment_info["category_id"]
61
+ ]
62
+ else:
63
+ segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[
64
+ segment_info["category_id"]
65
+ ]
66
+ return segment_info
67
+
68
+ def process(self, inputs, outputs):
69
+ from panopticapi.utils import id2rgb
70
+
71
+ for input, output in zip(inputs, outputs):
72
+ panoptic_img, segments_info = output["panoptic_seg"]
73
+ panoptic_img = panoptic_img.cpu().numpy()
74
+ if segments_info is None:
75
+ # If "segments_info" is None, we assume "panoptic_img" is a
76
+ # H*W int32 image storing the panoptic_id in the format of
77
+ # category_id * label_divisor + instance_id. We reserve -1 for
78
+ # VOID label, and add 1 to panoptic_img since the official
79
+ # evaluation script uses 0 for VOID label.
80
+ label_divisor = self._metadata.label_divisor
81
+ segments_info = []
82
+ for panoptic_label in np.unique(panoptic_img):
83
+ if panoptic_label == -1:
84
+ # VOID region.
85
+ continue
86
+ pred_class = panoptic_label // label_divisor
87
+ isthing = (
88
+ pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
89
+ )
90
+ segments_info.append(
91
+ {
92
+ "id": int(panoptic_label) + 1,
93
+ "category_id": int(pred_class),
94
+ "isthing": bool(isthing),
95
+ }
96
+ )
97
+ # Official evaluation script uses 0 for VOID label.
98
+ panoptic_img += 1
99
+
100
+ file_name = os.path.basename(input["file_name"])
101
+ file_name_png = os.path.splitext(file_name)[0] + ".png"
102
+ with io.BytesIO() as out:
103
+ Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
104
+ segments_info = [self._convert_category_id(x) for x in segments_info]
105
+ self._predictions.append(
106
+ {
107
+ "image_id": input["image_id"],
108
+ "file_name": file_name_png,
109
+ "png_string": out.getvalue(),
110
+ "segments_info": segments_info,
111
+ }
112
+ )
113
+
114
+ def evaluate(self):
115
+ comm.synchronize()
116
+
117
+ self._predictions = comm.gather(self._predictions)
118
+ self._predictions = list(itertools.chain(*self._predictions))
119
+ if not comm.is_main_process():
120
+ return
121
+
122
+ # PanopticApi requires local files
123
+ gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
124
+ gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)
125
+
126
+ with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
127
+ logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
128
+ for p in self._predictions:
129
+ with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
130
+ f.write(p.pop("png_string"))
131
+
132
+ with open(gt_json, "r") as f:
133
+ json_data = json.load(f)
134
+ json_data["annotations"] = self._predictions
135
+
136
+ output_dir = self._output_dir or pred_dir
137
+ predictions_json = os.path.join(output_dir, "predictions.json")
138
+ with PathManager.open(predictions_json, "w") as f:
139
+ f.write(json.dumps(json_data))
140
+
141
+ from panopticapi.evaluation import pq_compute
142
+
143
+ with contextlib.redirect_stdout(io.StringIO()):
144
+ pq_res = pq_compute(
145
+ gt_json,
146
+ PathManager.get_local_path(predictions_json),
147
+ gt_folder=gt_folder,
148
+ pred_folder=pred_dir,
149
+ )
150
+
151
+ res = {}
152
+ res["PQ"] = 100 * pq_res["All"]["pq"]
153
+ res["SQ"] = 100 * pq_res["All"]["sq"]
154
+ res["RQ"] = 100 * pq_res["All"]["rq"]
155
+ res["PQ_th"] = 100 * pq_res["Things"]["pq"]
156
+ res["SQ_th"] = 100 * pq_res["Things"]["sq"]
157
+ res["RQ_th"] = 100 * pq_res["Things"]["rq"]
158
+ res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
159
+ res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
160
+ res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]
161
+
162
+ results = OrderedDict({"panoptic_seg": res})
163
+ _print_panoptic_results(pq_res)
164
+
165
+ return results
166
+
167
+
168
+ def _print_panoptic_results(pq_res):
169
+ headers = ["", "PQ", "SQ", "RQ", "#categories"]
170
+ data = []
171
+ for name in ["All", "Things", "Stuff"]:
172
+ row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
173
+ data.append(row)
174
+ table = tabulate(
175
+ data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
176
+ )
177
+ logger.info("Panoptic Evaluation Results:\n" + table)
178
+
179
+
180
+ if __name__ == "__main__":
181
+ from detectron2.utils.logger import setup_logger
182
+
183
+ logger = setup_logger()
184
+ import argparse
185
+
186
+ parser = argparse.ArgumentParser()
187
+ parser.add_argument("--gt-json")
188
+ parser.add_argument("--gt-dir")
189
+ parser.add_argument("--pred-json")
190
+ parser.add_argument("--pred-dir")
191
+ args = parser.parse_args()
192
+
193
+ from panopticapi.evaluation import pq_compute
194
+
195
+ with contextlib.redirect_stdout(io.StringIO()):
196
+ pq_res = pq_compute(
197
+ args.gt_json, args.pred_json, gt_folder=args.gt_dir, pred_folder=args.pred_dir
198
+ )
199
+ _print_panoptic_results(pq_res)
datasets/evaluation/retrieval_evaluation.py ADDED
@@ -0,0 +1,260 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Modified by Xueyan Zou ([email protected]), Ziyi Dou ([email protected])
6
+ # --------------------------------------------------------
7
+ import copy
8
+ import itertools
9
+ import logging
10
+ from collections import OrderedDict
11
+ import torch
12
+ from pycocotools.cocoeval import COCOeval
13
+
14
+ import detectron2.utils.comm as comm
15
+ from detectron2.evaluation.evaluator import DatasetEvaluator
16
+
17
+ try:
18
+ from detectron2.evaluation.fast_eval_api import COCOeval_opt
19
+ except ImportError:
20
+ COCOeval_opt = COCOeval
21
+
22
+
23
+ class RetrievalEvaluator(DatasetEvaluator):
24
+ """
25
+ Evaluate image-text retrieval with recall metrics (R@1, R@5, R@10) in both
+ image-to-text and text-to-image directions, plus patch-to-image or
+ interactive-to-image retrieval when the dataset name selects those modes.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ dataset_name=None,
38
+ output_dir=None,
39
+ ensemble=False,
40
+ distributed=True,
41
+ ):
42
+ """
43
+ Args:
+ dataset_name (str): name of the dataset to be evaluated; a name containing
+ 'p2i' or 'interactive2i' switches the evaluator to patch-to-image or
+ interactive-to-image retrieval mode.
+ output_dir (str): optional directory for dumping results.
+ ensemble (bool): if True, average the scores from a second set of image
+ embeddings with the first.
+ distributed (bool): if True, gather predictions from all ranks and run
+ evaluation only in the main process.
76
+ """
77
+ self._logger = logging.getLogger(__name__)
78
+ self._dataset_name = dataset_name
79
+ self._output_dir = output_dir
80
+ self._ensemble = ensemble
81
+ self._distributed = distributed
82
+
83
+ if 'p2i' in dataset_name:
84
+ self.mode = 'patch2image'
85
+ elif 'interactive2i' in dataset_name:
86
+ self.mode = 'interactive2image'
87
+ else:
88
+ self.mode = 'default'
89
+
90
+ def reset(self):
91
+ self._text_embeds = []
92
+ self._image_embeds = []
93
+ self._image_embeds2 = []
94
+ self._text_ids = []
95
+ self._image_ids = []
96
+
97
+ def process(self, inputs, outputs):
98
+ """
99
+ Args:
100
+ inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
101
+ It is a list of dict. Each dict corresponds to an image and
102
+ contains keys like "height", "width", "file_name", "image_id".
103
+ outputs: the outputs of a COCO model. It is a list of dicts with key
104
+ "instances" that contains :class:`Instances`.
105
+ """
106
+ for output in outputs:
107
+ self._text_ids.extend(output['caption']['caption_ids'])
108
+ self._image_ids.append(output['caption']['image_ids'])
109
+ self._text_embeds.append(output['caption']['text_embeds'])
110
+ self._image_embeds.append(output['caption']['image_embeds'][0])
111
+ if self._ensemble:
112
+ self._image_embeds2.append(output['caption']['image_embeds'][1])
113
+
114
+ def evaluate(self, img_ids=None):
115
+ if self.mode == 'default':
116
+ return self.evaluate_default(img_ids)
117
+ elif self.mode in ['patch2image', 'interactive2image']:
118
+ return self.evaluate_p2i(img_ids)
119
+ else:
120
+ assert False, "Unknown mode for retrieval evaluation"
121
+
122
+ def evaluate_default(self, img_ids=None):
123
+ """
124
+ Args:
125
+ img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
126
+ """
127
+
128
+ if self._distributed:
129
+ comm.synchronize()
130
+ def gather(x, move=False):
131
+ x = comm.gather(x)
132
+ x = list(itertools.chain(*x))
133
+ if move:
134
+ x = [xx.to(self._text_embeds[0].device) for xx in x]
135
+ return x
136
+ text_embeds = gather(self._text_embeds, move=True)
137
+ image_embeds = gather(self._image_embeds, move=True)
138
+ if self._ensemble:
139
+ image_embeds2 = gather(self._image_embeds2, move=True)
140
+ text_ids = gather(self._text_ids)
141
+ image_ids = gather(self._image_ids)
142
+ if not comm.is_main_process():
143
+ return {}
144
+ else:
145
+ text_embeds = self._text_embeds
146
+ image_embeds = self._image_embeds
147
+ if self._ensemble:
148
+ image_embeds2 = self._image_embeds2
149
+ text_ids = self._text_ids
150
+ image_ids = self._image_ids
151
+ if len(text_embeds) == 0:
152
+ self._logger.warning("[COCOCaptionEvaluator] Did not receive valid predictions.")
153
+ return {}
154
+ iids = torch.tensor(image_ids).view(-1)
155
+ tiids = torch.tensor(text_ids).view(-1)
156
+ image_embeds = torch.cat(image_embeds)
157
+ text_embeds = torch.cat(text_embeds)
158
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
159
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
160
+ scores = image_embeds @ text_embeds.t()
161
+
162
+ if self._ensemble:
163
+ image_embeds2 = torch.cat(image_embeds2)
164
+ image_embeds2 = image_embeds2 / image_embeds2.norm(dim=-1, keepdim=True)
165
+ scores2 = image_embeds2 @ text_embeds.t()
166
+ scores = scores2 * 0.5 + scores * 0.5
167
+
168
+ topk10 = scores.topk(10, dim=1)
169
+ topk5 = scores.topk(5, dim=1)
170
+ topk1 = scores.topk(1, dim=1)
171
+ topk10_iids = tiids[topk10.indices]
172
+ topk5_iids = tiids[topk5.indices]
173
+ topk1_iids = tiids[topk1.indices]
174
+ tr_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
175
+ tr_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
176
+ tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
177
+ topk10 = scores.topk(10, dim=0)
178
+ topk5 = scores.topk(5, dim=0)
179
+ topk1 = scores.topk(1, dim=0)
180
+ topk10_iids = iids[topk10.indices]
181
+ topk5_iids = iids[topk5.indices]
182
+ topk1_iids = iids[topk1.indices]
183
+ ir_r10 = (tiids.unsqueeze(0) == topk10_iids).float().max(dim=0)[0].mean()
184
+ ir_r5 = (tiids.unsqueeze(0) == topk5_iids).float().max(dim=0)[0].mean()
185
+ ir_r1 = (tiids.unsqueeze(0) == topk1_iids).float().max(dim=0)[0].mean()
186
+ self._results = OrderedDict()
187
+ # Copy so the caller can do whatever with results
188
+ self._results['recall'] = {}
189
+ self._results['recall']['irtr'] = float("{:.3f}".format((ir_r1 + tr_r1).item() * 100))
190
+ self._results['recall']['ir1'] = float("{:.3f}".format(ir_r1.item() * 100))
191
+ self._results['recall']['ir5'] = float("{:.3f}".format(ir_r5.item() * 100))
192
+ self._results['recall']['ir10'] = float("{:.3f}".format(ir_r10.item() * 100))
193
+ self._results['recall']['tr1'] = float("{:.3f}".format(tr_r1.item() * 100))
194
+ self._results['recall']['tr5'] = float("{:.3f}".format(tr_r5.item() * 100))
195
+ self._results['recall']['tr10'] = float("{:.3f}".format(tr_r10.item() * 100))
196
+ self._logger.info(self._results)
197
+ return copy.deepcopy(self._results)
198
+
199
+ def evaluate_p2i(self, img_ids=None):
200
+ """
201
+ Args:
202
+ img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
203
+ """
204
+
205
+ if self._distributed:
206
+ comm.synchronize()
207
+ def gather(x, move=False):
208
+ x = comm.gather(x)
209
+ x = list(itertools.chain(*x))
210
+ if move:
211
+ x = [xx.to(self._text_embeds[0].device) for xx in x]
212
+ return x
213
+ text_embeds = gather(self._text_embeds, move=True)
214
+ image_embeds = gather(self._image_embeds, move=True)
215
+ image_embeds2 = gather(self._image_embeds2, move=True)
216
+ text_ids = gather(self._text_ids)
217
+ image_ids = gather(self._image_ids)
218
+ if not comm.is_main_process():
219
+ return {}
220
+ else:
221
+ text_embeds = self._text_embeds
222
+ image_embeds = self._image_embeds
223
+ image_embeds2 = self._image_embeds2
224
+ text_ids = self._text_ids
225
+ image_ids = self._image_ids
226
+
227
+ if len(text_embeds) == 0:
228
+ self._logger.warning("[COCOCaptionEvaluator] Did not receive valid predictions.")
229
+ return {}
230
+
231
+ iids = torch.tensor(image_ids).view(-1)
232
+ tiids = torch.tensor(text_ids).view(-1)
233
+ image_embeds = torch.cat(image_embeds)
234
+ text_embeds = torch.cat(text_embeds)
235
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
236
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
237
+
238
+ image_embeds2 = torch.cat(image_embeds2)
239
+ image_embeds2 = image_embeds2 / image_embeds2.norm(dim=-1, keepdim=True)
240
+
241
+ # compute image to image retrieval
242
+ self._results = OrderedDict()
243
+ self._results['recall'] = {}
244
+ ii_scores = image_embeds2 @ image_embeds.t()
245
+
246
+ topk10 = ii_scores.topk(10, dim=1)
247
+ topk5 = ii_scores.topk(5, dim=1)
248
+ topk1 = ii_scores.topk(1, dim=1)
249
+ topk10_iids = iids[topk10.indices]
250
+ topk5_iids = iids[topk5.indices]
251
+ topk1_iids = iids[topk1.indices]
252
+ iir_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
253
+ iir_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
254
+ iir_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
255
+ # Copy so the caller can do whatever with results
256
+ self._results['recall']['p2ir1'] = float("{:.3f}".format(iir_r1.item() * 100))
257
+ self._results['recall']['p2ir5'] = float("{:.3f}".format(iir_r5.item() * 100))
258
+ self._results['recall']['p2ir10'] = float("{:.3f}".format(iir_r10.item() * 100))
259
+ self._logger.info(self._results)
260
+ return copy.deepcopy(self._results)
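
Note on the recall computation above: both `evaluate_default` and `evaluate_p2i` reduce to the same pattern of ranking one set of L2-normalized embeddings against another and checking whether the ground-truth id appears in the top-k results. A minimal standalone sketch of that pattern (toy embeddings and one caption per image, which is an assumption; the evaluator itself handles multiple captions per image id) is:

```python
import torch

def recall_at_k(image_embeds, text_embeds, image_ids, text_ids, k=5):
    # Normalize both sides so the dot product becomes cosine similarity.
    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    scores = image_embeds @ text_embeds.t()                   # [n_images, n_texts]
    topk_text_ids = text_ids[scores.topk(k, dim=1).indices]   # ids of the k best texts per image
    # Hit if the image's own id shows up among its top-k retrieved text ids.
    hits = (image_ids.unsqueeze(1) == topk_text_ids).float().max(dim=1)[0]
    return hits.mean().item()

# Toy data: 4 images, 4 captions, matching ids 0..3.
torch.manual_seed(0)
img = torch.randn(4, 8)
txt = img + 0.05 * torch.randn(4, 8)   # each caption embedding close to its image
ids = torch.arange(4)
print(recall_at_k(img, txt, ids, ids, k=1))   # close pairs -> recall@1 of 1.0
```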
datasets/evaluation/segmentation_evaluation.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import itertools
3
+ import json
4
+ import logging
5
+ import numpy as np
6
+ import os
7
+ from collections import OrderedDict
8
+ import PIL.Image as Image
9
+ import pycocotools.mask as mask_util
10
+ import torch
11
+
12
+ from detectron2.data import DatasetCatalog, MetadataCatalog
13
+ from detectron2.utils.comm import all_gather, is_main_process
14
+ from detectron2.utils.file_io import PathManager
15
+ from detectron2.evaluation.evaluator import DatasetEvaluator
16
+ from utilities.distributed import synchronize
17
+
18
+ from ..semseg_loader import load_semseg
19
+
20
+
21
+ class SemSegEvaluator(DatasetEvaluator):
22
+ """
23
+ Evaluate semantic segmentation metrics.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dataset_name,
29
+ distributed=True,
30
+ output_dir=None,
31
+ *,
32
+ num_classes=None,
33
+ ignore_label=None,
34
+ ):
35
+ """
36
+ Args:
37
+ dataset_name (str): name of the dataset to be evaluated.
38
+ distributed (bool): if True, will collect results from all ranks for evaluation.
39
+ Otherwise, will evaluate the results in the current process.
40
+ output_dir (str): an output directory to dump results.
41
+ num_classes, ignore_label: deprecated argument
42
+ """
43
+ self._logger = logging.getLogger(__name__)
44
+ if num_classes is not None:
45
+ self._logger.warn(
46
+ "SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata."
47
+ )
48
+ if ignore_label is not None:
49
+ self._logger.warn(
50
+ "SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata."
51
+ )
52
+ self._dataset_name = dataset_name
53
+ self._distributed = distributed
54
+ self._output_dir = output_dir
55
+
56
+ self._cpu_device = torch.device("cpu")
57
+
58
+ self.input_file_to_gt_file = {
59
+ dataset_record["file_name"]: dataset_record["sem_seg_file_name"]
60
+ for dataset_record in DatasetCatalog.get(dataset_name)
61
+ }
62
+
63
+ meta = MetadataCatalog.get(dataset_name)
64
+ # Dict that maps contiguous training ids to COCO category ids
65
+ try:
66
+ c2d = meta.stuff_dataset_id_to_contiguous_id
67
+ self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()}
68
+ except AttributeError:
69
+ self._contiguous_id_to_dataset_id = None
70
+ self._class_names = meta.stuff_classes
71
+ self._class_offset = meta.class_offset if hasattr(meta, 'class_offset') else 0
72
+ self._num_classes = len(meta.stuff_classes)
73
+ self._semseg_loader = meta.semseg_loader if hasattr(meta, 'semseg_loader') else 'PIL'
74
+
75
+ if num_classes is not None:
76
+ assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}"
77
+ self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label
78
+
79
+ def reset(self):
80
+ self._conf_matrix = np.zeros((self._num_classes + 1, self._num_classes + 1), dtype=np.int64)
81
+ self._predictions = []
82
+
83
+ def process(self, inputs, outputs):
84
+ """
85
+ Args:
86
+ inputs: the inputs to a model.
87
+ It is a list of dicts. Each dict corresponds to an image and
88
+ contains keys like "height", "width", "file_name".
89
+ outputs: the outputs of a model. It is either list of semantic segmentation predictions
90
+ (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
91
+ segmentation prediction in the same format.
92
+ """
93
+ for input, output in zip(inputs, outputs):
94
+ output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
95
+ pred = np.array(output, dtype=int)  # np.int alias was removed in NumPy 1.24
96
+
97
+ with PathManager.open(self.input_file_to_gt_file[input["file_name"]], "rb") as f:
98
+ gt = load_semseg(f, self._semseg_loader) - self._class_offset
99
+
100
+ if isinstance(self._ignore_label, int):
101
+ ignore_label = self._ignore_label - self._class_offset
102
+ gt[gt == self._ignore_label] = self._num_classes
103
+ elif isinstance(self._ignore_label, list):
104
+ for ignore_label in self._ignore_label:
105
+ ignore_label = ignore_label - self._class_offset
106
+ gt[gt == ignore_label] = self._num_classes
107
+
108
+ self._conf_matrix += np.bincount(
109
+ (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
110
+ minlength=self._conf_matrix.size,
111
+ ).reshape(self._conf_matrix.shape)
112
+
113
+ self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))
114
+
115
+ def evaluate(self):
116
+ """
117
+ Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):
118
+
119
+ * Mean intersection-over-union averaged across classes (mIoU)
120
+ * Frequency Weighted IoU (fwIoU)
121
+ * Mean pixel accuracy averaged across classes (mACC)
122
+ * Pixel Accuracy (pACC)
123
+ """
124
+ if self._distributed:
125
+ synchronize()
126
+ conf_matrix_list = all_gather(self._conf_matrix)
127
+ self._predictions = all_gather(self._predictions)
128
+ self._predictions = list(itertools.chain(*self._predictions))
129
+ if not is_main_process():
130
+ return
131
+ self._conf_matrix = np.zeros_like(self._conf_matrix)
132
+ for conf_matrix in conf_matrix_list:
133
+ self._conf_matrix += conf_matrix
134
+
135
+ if self._output_dir:
136
+ PathManager.mkdirs(self._output_dir)
137
+ file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
138
+ with PathManager.open(file_path, "w") as f:
139
+ f.write(json.dumps(self._predictions))
140
+
141
+ acc = np.full(self._num_classes, np.nan, dtype=np.float64)
+ iou = np.full(self._num_classes, np.nan, dtype=np.float64)
+ tp = self._conf_matrix.diagonal()[:-1].astype(np.float64)
+ pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float64)
+ class_weights = pos_gt / np.sum(pos_gt)
+ pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float64)
147
+ acc_valid = pos_gt > 0
148
+ acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
149
+ iou_valid = (pos_gt + pos_pred) > 0
150
+ union = pos_gt + pos_pred - tp
151
+ iou[iou_valid] = tp[iou_valid] / union[iou_valid]
+ macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
+ miou = np.sum(iou[iou_valid]) / np.sum(iou_valid)
+ fiou = np.sum(iou[iou_valid] * class_weights[iou_valid])
155
+ pacc = np.sum(tp) / np.sum(pos_gt)
156
+
157
+ res = {}
158
+ res["mIoU"] = 100 * miou
159
+ res["fwIoU"] = 100 * fiou
160
+ for i, name in enumerate(self._class_names):
161
+ res["IoU-{}".format(name)] = 100 * iou[i]
162
+ res["mACC"] = 100 * macc
163
+ res["pACC"] = 100 * pacc
164
+ for i, name in enumerate(self._class_names):
165
+ res["ACC-{}".format(name)] = 100 * acc[i]
166
+
167
+ if self._output_dir:
168
+ file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
169
+ with PathManager.open(file_path, "wb") as f:
170
+ torch.save(res, f)
171
+ results = OrderedDict({"sem_seg": res})
172
+ self._logger.info(results)
173
+ return results
174
+
175
+ def encode_json_sem_seg(self, sem_seg, input_file_name):
176
+ """
177
+ Convert semantic segmentation to COCO stuff format with segments encoded as RLEs.
178
+ See http://cocodataset.org/#format-results
179
+ """
180
+ json_list = []
181
+ for label in np.unique(sem_seg):
182
+ if self._contiguous_id_to_dataset_id is not None:
183
+ assert (
184
+ label in self._contiguous_id_to_dataset_id
185
+ ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name)
186
+ dataset_id = self._contiguous_id_to_dataset_id[label]
187
+ else:
188
+ dataset_id = int(label)
189
+ mask = (sem_seg == label).astype(np.uint8)
190
+ mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0]
191
+ mask_rle["counts"] = mask_rle["counts"].decode("utf-8")
192
+ json_list.append(
193
+ {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle}
194
+ )
195
+ return json_list
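
For reference, every metric reported by `SemSegEvaluator.evaluate` is derived from the single `(num_classes + 1) x (num_classes + 1)` confusion matrix accumulated in `process`, where the extra row/column absorbs ignored pixels. A small self-contained sketch of the mIoU reduction (plain NumPy, hypothetical toy counts rather than real predictions) is:

```python
import numpy as np

def miou_from_confusion(conf):
    # conf[i, j] counts pixels predicted as class i whose ground truth is class j;
    # the last row/column holds ignored pixels and is dropped from the metric.
    tp = conf.diagonal()[:-1].astype(np.float64)
    pos_gt = conf[:-1, :-1].sum(axis=0).astype(np.float64)    # ground-truth pixels per class
    pos_pred = conf[:-1, :-1].sum(axis=1).astype(np.float64)  # predicted pixels per class
    union = pos_gt + pos_pred - tp
    valid = union > 0
    iou = np.full(len(tp), np.nan)
    iou[valid] = tp[valid] / union[valid]
    return np.nansum(iou) / valid.sum()

# Two classes plus the ignore slot: class 0 is partly polluted by class-1 pixels.
conf = np.array([[10, 5, 0],
                 [0, 5, 0],
                 [0, 0, 0]])
print(miou_from_confusion(conf))  # (10/15 + 5/10) / 2 ~= 0.583
```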
datasets/refer.py ADDED
@@ -0,0 +1,371 @@
1
+ __author__ = 'licheng'
2
+
3
+ """
4
+ This interface provides access to four datasets:
5
+ 1) refclef
6
+ 2) refcoco
7
+ 3) refcoco+
8
+ 4) refcocog
9
+ split by unc and google
10
+
11
+ The following API functions are defined:
12
+ REFER - REFER api class
13
+ getRefIds - get ref ids that satisfy given filter conditions.
14
+ getAnnIds - get ann ids that satisfy given filter conditions.
15
+ getImgIds - get image ids that satisfy given filter conditions.
16
+ getCatIds - get category ids that satisfy given filter conditions.
17
+ loadRefs - load refs with the specified ref ids.
18
+ loadAnns - load anns with the specified ann ids.
19
+ loadImgs - load images with the specified image ids.
20
+ loadCats - load category names with the specified category ids.
21
+ getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
22
+ showRef - show image, segmentation or box of the referred object with the ref
23
+ getMask - get mask and area of the referred object given ref
24
+ showMask - show mask of the referred object given ref
25
+ """
26
+
27
+ from doctest import REPORT_ONLY_FIRST_FAILURE
28
+ import sys
29
+ import os.path as osp
30
+ import json
31
+ import pickle
32
+ import time
33
+ import itertools
34
+ import skimage.io as io
35
+ import matplotlib.pyplot as plt
36
+ from matplotlib.collections import PatchCollection
37
+ from matplotlib.patches import Polygon, Rectangle
38
+ from pprint import pprint
39
+ import numpy as np
40
+ from pycocotools import mask
41
+ # import cv2
42
+ # from skimage.measure import label, regionprops
43
+
44
+
45
+ class REFER:
46
+ def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
47
+ # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
48
+ # also provide dataset name and splitBy information
49
+ # e.g., dataset = 'refcoco', splitBy = 'unc'
50
+ print('loading dataset {} into memory...'.format(dataset))
51
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
52
+ self.DATA_DIR = osp.join(data_root, dataset)
53
+ if dataset in ['refcoco', 'refcoco+', 'refcocog']:
54
+ self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
55
+ elif dataset == 'refclef':
56
+ self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
57
+ else:
58
+ print('No refer dataset is called [{}]'.format(dataset))
59
+ sys.exit()
60
+
61
+ # load refs from data/dataset/refs(dataset).json
62
+ tic = time.time()
63
+ ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
64
+ self.data = {}
65
+ self.data['dataset'] = dataset
66
+ self.data['refs'] = pickle.load(open(ref_file, 'rb'))
67
+
68
+ # load annotations from data/dataset/instances.json
69
+ instances_file = osp.join(self.DATA_DIR, 'instances.json')
70
+ instances = json.load(open(instances_file, 'r'))
71
+ self.data['images'] = instances['images']
72
+ self.data['annotations'] = instances['annotations']
73
+ self.data['categories'] = instances['categories']
74
+
75
+ # create index
76
+ self.createIndex()
77
+ print('DONE (t={:.2f}s)'.format(time.time()-tic))
78
+
79
+ def createIndex(self):
80
+ # create sets of mapping
81
+ # 1) Refs: {ref_id: ref}
82
+ # 2) Anns: {ann_id: ann}
83
+ # 3) Imgs: {image_id: image}
84
+ # 4) Cats: {category_id: category_name}
85
+ # 5) Sents: {sent_id: sent}
86
+ # 6) imgToRefs: {image_id: refs}
87
+ # 7) imgToAnns: {image_id: anns}
88
+ # 8) refToAnn: {ref_id: ann}
89
+ # 9) annToRef: {ann_id: ref}
90
+ # 10) catToRefs: {category_id: refs}
91
+ # 11) sentToRef: {sent_id: ref}
92
+ # 12) sentToTokens: {sent_id: tokens}
93
+ print('creating index...')
94
+ # fetch info from instances
95
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
96
+ for ann in self.data['annotations']:
97
+ Anns[ann['id']] = ann
98
+ imgToAnns[ann['image_id']] = imgToAnns.get(
99
+ ann['image_id'], []) + [ann]
100
+ for img in self.data['images']:
101
+ Imgs[img['id']] = img
102
+ for cat in self.data['categories']:
103
+ Cats[cat['id']] = cat['name']
104
+
105
+ # fetch info from refs
106
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
107
+ Sents, sentToRef, sentToTokens = {}, {}, {}
108
+ for ref in self.data['refs']:
109
+ # ids
110
+ ref_id = ref['ref_id']
111
+ ann_id = ref['ann_id']
112
+ category_id = ref['category_id']
113
+ image_id = ref['image_id']
114
+
115
+ # add mapping related to ref
116
+ Refs[ref_id] = ref
117
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
118
+ catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
119
+ refToAnn[ref_id] = Anns[ann_id]
120
+ annToRef[ann_id] = ref
121
+
122
+ # add mapping of sent
123
+ for sent in ref['sentences']:
124
+ Sents[sent['sent_id']] = sent
125
+ sentToRef[sent['sent_id']] = ref
126
+ sentToTokens[sent['sent_id']] = sent['tokens']
127
+
128
+ # create class members
129
+ self.Refs = Refs
130
+ self.Anns = Anns
131
+ self.Imgs = Imgs
132
+ self.Cats = Cats
133
+ self.Sents = Sents
134
+ self.imgToRefs = imgToRefs
135
+ self.imgToAnns = imgToAnns
136
+ self.refToAnn = refToAnn
137
+ self.annToRef = annToRef
138
+ self.catToRefs = catToRefs
139
+ self.sentToRef = sentToRef
140
+ self.sentToTokens = sentToTokens
141
+ print('index created.')
142
+
143
+ def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
144
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
145
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
146
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
147
+
148
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
149
+ refs = self.data['refs']
150
+ else:
151
+ if not len(image_ids) == 0:
152
+ refs = [self.imgToRefs[image_id] for image_id in image_ids]
153
+ else:
154
+ refs = self.data['refs']
155
+ if not len(cat_ids) == 0:
156
+ refs = [ref for ref in refs if ref['category_id'] in cat_ids]
157
+ if not len(ref_ids) == 0:
158
+ refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
159
+ if not len(split) == 0:
160
+ if split in ['testA', 'testB', 'testC']:
161
+ # we also consider testAB, testBC, ...
162
+ refs = [ref for ref in refs if split[-1] in ref['split']]
163
+ elif split in ['testAB', 'testBC', 'testAC']:
164
+ # rarely used I guess...
165
+ refs = [ref for ref in refs if ref['split'] == split]
166
+ elif split == 'test':
167
+ refs = [ref for ref in refs if 'test' in ref['split']]
168
+ elif split == 'train' or split == 'val':
169
+ refs = [ref for ref in refs if ref['split'] == split]
170
+ else:
171
+ print('No such split [{}]'.format(split))
172
+ sys.exit()
173
+ ref_ids = [ref['ref_id'] for ref in refs]
174
+ return ref_ids
175
+
176
+ def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
177
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
178
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
179
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
180
+
181
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
182
+ ann_ids = [ann['id'] for ann in self.data['annotations']]
183
+ else:
184
+ if not len(image_ids) == 0:
185
+ lists = [self.imgToAnns[image_id]
186
+ for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
187
+ anns = list(itertools.chain.from_iterable(lists))
188
+ else:
189
+ anns = self.data['annotations']
190
+ if not len(cat_ids) == 0:
191
+ anns = [ann for ann in anns if ann['category_id'] in cat_ids]
192
+ ann_ids = [ann['id'] for ann in anns]
193
+ if not len(ref_ids) == 0:
194
+ ids = set(ann_ids).intersection(
195
+ set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
196
+ return ann_ids
197
+
198
+ def getImgIds(self, ref_ids=[]):
199
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
200
+
201
+ if not len(ref_ids) == 0:
202
+ image_ids = list(set([self.Refs[ref_id]['image_id']
203
+ for ref_id in ref_ids]))
204
+ else:
205
+ image_ids = self.Imgs.keys()
206
+ return image_ids
207
+
208
+ def getCatIds(self):
209
+ return self.Cats.keys()
210
+
211
+ def loadRefs(self, ref_ids=[]):
212
+ if type(ref_ids) == list:
213
+ return [self.Refs[ref_id] for ref_id in ref_ids]
214
+ elif type(ref_ids) == int:
215
+ return [self.Refs[ref_ids]]
216
+
217
+ def loadAnns(self, ann_ids=[]):
218
+ if type(ann_ids) == list:
219
+ return [self.Anns[ann_id] for ann_id in ann_ids]
220
+ elif type(ann_ids) == int or type(ann_ids) == str:
221
+ return [self.Anns[ann_ids]]
222
+
223
+ def loadImgs(self, image_ids=[]):
224
+ if type(image_ids) == list:
225
+ return [self.Imgs[image_id] for image_id in image_ids]
226
+ elif type(image_ids) == int:
227
+ return [self.Imgs[image_ids]]
228
+
229
+ def loadCats(self, cat_ids=[]):
230
+ if type(cat_ids) == list:
231
+ return [self.Cats[cat_id] for cat_id in cat_ids]
232
+ elif type(cat_ids) == int:
233
+ return [self.Cats[cat_ids]]
234
+
235
+ def getRefBox(self, ref_id):
236
+ ref = self.Refs[ref_id]
237
+ ann = self.refToAnn[ref_id]
238
+ return ann['bbox'] # [x, y, w, h]
239
+
240
+ def showRef(self, ref, seg_box='seg'):
241
+ ax = plt.gca()
242
+ # show image
243
+ image = self.Imgs[ref['image_id']]
244
+ I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
245
+ ax.imshow(I)
246
+ # show refer expression
247
+ for sid, sent in enumerate(ref['sentences']):
248
+ print('{}. {}'.format(sid+1, sent['sent']))
249
+ # show segmentations
250
+ if seg_box == 'seg':
251
+ ann_id = ref['ann_id']
252
+ ann = self.Anns[ann_id]
253
+ polygons = []
254
+ color = []
255
+ c = 'none'
256
+ if type(ann['segmentation'][0]) == list:
257
+ # polygon used for refcoco*
258
+ for seg in ann['segmentation']:
259
+ poly = np.array(seg).reshape((len(seg)//2, 2))
260
+ polygons.append(Polygon(poly, True, alpha=0.4))
261
+ color.append(c)
262
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(
263
+ 1, 1, 0, 0), linewidths=3, alpha=1)
264
+ ax.add_collection(p) # thick yellow polygon
265
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(
266
+ 1, 0, 0, 0), linewidths=1, alpha=1)
267
+ ax.add_collection(p) # thin red polygon
268
+ else:
269
+ # mask used for refclef
270
+ rle = ann['segmentation']
271
+ m = mask.decode(rle)
272
+ img = np.ones((m.shape[0], m.shape[1], 3))
273
+ color_mask = np.array([2.0, 166.0, 101.0])/255
274
+ for i in range(3):
275
+ img[:, :, i] = color_mask[i]
276
+ ax.imshow(np.dstack((img, m*0.5)))
277
+ # show bounding-box
278
+ elif seg_box == 'box':
279
+ ann_id = ref['ann_id']
280
+ ann = self.Anns[ann_id]
281
+ bbox = self.getRefBox(ref['ref_id'])
282
+ box_plot = Rectangle(
283
+ (bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
284
+ ax.add_patch(box_plot)
285
+
286
+ def getMask(self, ref):
287
+ # return mask, area and mask-center
288
+ ann = self.refToAnn[ref['ref_id']]
289
+ image = self.Imgs[ref['image_id']]
290
+ if type(ann['segmentation'][0]) == list: # polygon
291
+ rle = mask.frPyObjects(
292
+ ann['segmentation'], image['height'], image['width'])
293
+ else:
294
+ rle = ann['segmentation']
295
+ m = mask.decode(rle)
296
+ # sometimes there are multiple binary map (corresponding to multiple segs)
297
+ m = np.sum(m, axis=2)
298
+ m = m.astype(np.uint8) # convert to np.uint8
299
+ # compute area
300
+ area = sum(mask.area(rle)) # should be close to ann['area']
301
+ return {'mask': m, 'area': area}
302
+ # # position
303
+ # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
304
+ # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
305
+ # # mass position (if there were multiple regions, we use the largest one.)
306
+ # label_m = label(m, connectivity=m.ndim)
307
+ # regions = regionprops(label_m)
308
+ # if len(regions) > 0:
309
+ # largest_id = np.argmax(np.array([props.filled_area for props in regions]))
310
+ # largest_props = regions[largest_id]
311
+ # mass_y, mass_x = largest_props.centroid
312
+ # else:
313
+ # mass_x, mass_y = position_x, position_y
314
+ # # if centroid is not in mask, we find the closest point to it from mask
315
+ # if m[mass_y, mass_x] != 1:
316
+ # print 'Finding closes mask point ...'
317
+ # kernel = np.ones((10, 10),np.uint8)
318
+ # me = cv2.erode(m, kernel, iterations = 1)
319
+ # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
320
+ # points = np.array(points)
321
+ # dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
322
+ # id = np.argsort(dist)[0]
323
+ # mass_y, mass_x = points[id]
324
+ # # return
325
+ # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
326
+ # # show image and mask
327
+ # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
328
+ # plt.figure()
329
+ # plt.imshow(I)
330
+ # ax = plt.gca()
331
+ # img = np.ones( (m.shape[0], m.shape[1], 3) )
332
+ # color_mask = np.array([2.0,166.0,101.0])/255
333
+ # for i in range(3):
334
+ # img[:,:,i] = color_mask[i]
335
+ # ax.imshow(np.dstack( (img, m*0.5) ))
336
+ # plt.show()
337
+
338
+ def showMask(self, ref):
339
+ M = self.getMask(ref)
340
+ msk = M['mask']
341
+ ax = plt.gca()
342
+ ax.imshow(msk)
343
+
344
+
345
+ if __name__ == '__main__':
346
+ refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg',
347
+ dataset='refcocog', splitBy='google')
348
+ ref_ids = refer.getRefIds()
349
+ print(len(ref_ids))
350
+
351
+ print(len(refer.Imgs))
352
+ print(len(refer.imgToRefs))
353
+
354
+ ref_ids = refer.getRefIds(split='train')
355
+ print('There are {} training referred objects.'.format(len(ref_ids)))
356
+
357
+ for ref_id in ref_ids:
358
+ ref = refer.loadRefs(ref_id)[0]
359
+ if len(ref['sentences']) < 2:
360
+ continue
361
+
362
+ pprint(ref)
363
+ print('The label is {}.'.format(refer.Cats[ref['category_id']]))
364
+
365
+ # plt.figure()
366
+ # refer.showRef(ref, seg_box='box')
367
+ # plt.show()
368
+
369
+ # plt.figure()
370
+ # refer.showMask(ref)
371
+ # plt.show()
datasets/registration/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import (
2
+ register_biomed_datasets
3
+ )
datasets/registration/register_biomed_datasets.py ADDED
@@ -0,0 +1,123 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Modified by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+ import json
8
+ import os
9
+ import collections
10
+
11
+ from detectron2.data import DatasetCatalog, MetadataCatalog
12
+ from detectron2.data.datasets import load_sem_seg
13
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
14
+ from detectron2.utils.file_io import PathManager
15
+
16
+
17
+ _PREDEFINED_SPLITS_BIOMED = {}
18
+
19
+ # example of registering a dataset
20
+ datasets = ['BiomedParseData-Demo', ] # provide name of the dataset under biomedparse_datasets
21
+ splits = ['demo'] # provide split name, e.g., train, test, val. Here there is only one 'demo' split in the example demo dataset
22
+
23
+ # Here we register all the splits of the dataset
24
+ for name in datasets:
25
+ for split in splits:
26
+ dataname = f'biomed_{name.replace("/", "-")}_{split}'
27
+ image_root = f"{name}/{split}"
28
+ ann_root = f"{name}/{split}.json"
29
+ _PREDEFINED_SPLITS_BIOMED[dataname] = (image_root, ann_root)
30
+ # The resulting dataset name is: biomed_BiomedParseData-Demo_demo
31
+
32
+ # # Add your dataset here
33
+ # datasets = ['YOUR_DATASET_NAME', ] # provide name of the dataset under biomedparse_datasets
34
+ # splits = ['train', 'test'] # provide split name, e.g., train, test, val
35
+
36
+ # # Here we register all the splits of the dataset
37
+ # for name in datasets:
38
+ # for split in splits:
39
+ # dataname = f'biomed_{name.replace("/", "-")}_{split}'
40
+ # image_root = f"{name}/{split}"
41
+ # ann_root = f"{name}/{split}.json"
42
+ # _PREDEFINED_SPLITS_BIOMED[dataname] = (image_root, ann_root)
43
+ # # The resulting dataset names are: biomed_YOUR_DATASET_NAME_train, biomed_YOUR_DATASET_NAME_test
44
+
45
+
46
+ def get_metadata():
47
+ meta = {}
48
+ return meta
49
+
50
+
51
+ def load_biomed_json(image_root, annot_json, metadata):
52
+ """
53
+ Args:
54
+ image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
55
+ gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
56
+ json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
57
+ Returns:
58
+ list[dict]: a list of dicts in Detectron2 standard format. (See
59
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
60
+ """
61
+
62
+ with PathManager.open(annot_json) as f:
63
+ json_info = json.load(f)
64
+
65
+ # build dictionary for grounding
66
+ grd_dict = collections.defaultdict(list)
67
+ for grd_ann in json_info['annotations']:
68
+ image_id = int(grd_ann["image_id"])
69
+ grd_dict[image_id].append(grd_ann)
70
+
71
+ mask_root = image_root + '_mask'
72
+ ret = []
73
+ for image in json_info["images"]:
74
+ image_id = int(image["id"])
75
+ image_file = os.path.join(image_root, image['file_name'])
76
+ grounding_anno = grd_dict[image_id]
77
+ for ann in grounding_anno:
78
+ if 'mask_file' not in ann:
79
+ ann['mask_file'] = image['file_name']
80
+ ann['mask_file'] = os.path.join(mask_root, ann['mask_file'])
81
+ ret.append(
82
+ {
83
+ "file_name": image_file,
84
+ "image_id": image_id,
85
+ "grounding_info": [ann],
86
+ }
87
+ )
88
+ assert len(ret), f"No images found in {image_root}!"
89
+ assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
90
+ return ret
91
+
92
+
93
+ def register_biomed(
94
+ name, metadata, image_root, annot_json):
95
+ DatasetCatalog.register(
96
+ name,
97
+ lambda: load_biomed_json(image_root, annot_json, metadata),
98
+ )
99
+ MetadataCatalog.get(name).set(
100
+ image_root=image_root,
101
+ json_file=annot_json,
102
+ evaluator_type="grounding_refcoco",
103
+ ignore_label=255,
104
+ label_divisor=1000,
105
+ **metadata,
106
+ )
107
+
108
+
109
+ def register_all_biomed(root):
110
+ for (
111
+ prefix,
112
+ (image_root, annot_root),
113
+ ) in _PREDEFINED_SPLITS_BIOMED.items():
114
+ register_biomed(
115
+ prefix,
116
+ get_metadata(),
117
+ os.path.join(root, image_root),
118
+ os.path.join(root, annot_root),
119
+ )
120
+
121
+
122
+ _root = os.getenv("DATASET", "datasets")
123
+ register_all_biomed(_root)
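
A quick way to sanity-check the registration above is to fetch the demo split back from Detectron2's catalogs. The sketch below assumes the demo data lives under a local `biomedparse_datasets/` folder (set via the `DATASET` environment variable before the registration module is imported); the dataset name follows the `f'biomed_{name}_{split}'` pattern defined in this file:

```python
import os

# Registration runs at import time, so DATASET must point at the data root first.
os.environ["DATASET"] = "biomedparse_datasets"  # assumed local path containing BiomedParseData-Demo/

from detectron2.data import DatasetCatalog, MetadataCatalog
from datasets.registration import register_biomed_datasets  # noqa: F401  (triggers register_all_biomed)

name = "biomed_BiomedParseData-Demo_demo"     # f'biomed_{name}_{split}' from the loop above
records = DatasetCatalog.get(name)            # list of {"file_name", "image_id", "grounding_info"}
meta = MetadataCatalog.get(name)
print(len(records), meta.evaluator_type)      # number of demo images, 'grounding_refcoco'
```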
datasets/semseg_loader.py ADDED
@@ -0,0 +1,10 @@
1
+ from PIL import Image
2
+ import scipy.io
3
+ import numpy as np
4
+
5
+ def load_semseg(filename, loader_type):
6
+ if loader_type == 'PIL':
7
+ semseg = np.array(Image.open(filename), dtype=int)  # np.int alias was removed in NumPy 1.24
8
+ elif loader_type == 'MAT':
9
+ semseg = scipy.io.loadmat(filename)['LabelMap']
10
+ return semseg
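
The loader above is intentionally minimal: `'PIL'` reads an indexed PNG label map, `'MAT'` reads a MATLAB file with a `LabelMap` variable. A tiny round-trip check for the PIL branch (the file name here is made up for the example):

```python
import numpy as np
from PIL import Image
from datasets.semseg_loader import load_semseg

# Write a 2x2 label map to disk and read it back with the 'PIL' loader.
label_map = np.array([[0, 1], [2, 255]], dtype=np.uint8)   # 255 might serve as an ignore label
Image.fromarray(label_map).save("toy_sem_seg.png")
print(load_semseg("toy_sem_seg.png", "PIL"))
```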
datasets/utils/refcoco2json.py ADDED
@@ -0,0 +1,41 @@
1
+ import os
2
+ import json
3
+ from refer import REFER
4
+
5
+ coco_root = '/pth/to/coco'
6
+ ref_root = '/pth/to/refcocoseg'
7
+
8
+ coco_train_annot = json.load(open(os.path.join(coco_root, 'annotations/instances_train2017.json')))
9
+ coco_train_id = []
10
+ image_annot = {}
11
+ for i in range(len(coco_train_annot['images'])):
12
+ coco_train_id.append(coco_train_annot['images'][i]['id'])
13
+ image_annot[coco_train_annot['images'][i]['id']] = coco_train_annot['images'][i]
14
+
15
+ refg = REFER(data_root=ref_root,
16
+ dataset='refcocog', splitBy='umd')
17
+ refg_val_ids = refg.getRefIds(split='val')
18
+
19
+ full_anno = []
20
+ for ref_id in refg_val_ids:
21
+ ref = refg.loadRefs(ref_id)[0]
22
+ anno = refg.refToAnn[ref_id]
23
+ anno.update(ref)
24
+ full_anno.append(anno)
25
+
26
+ imageid_list = []
27
+ final_anno = {}
28
+ for anno in full_anno:
29
+ imageid_list += [anno['image_id']]
30
+ final_anno[anno['ann_id']] = anno
31
+
32
+ annotations = [value for key, value in final_anno.items()]
33
+
34
+ images = []
+ for image_id in list(set(imageid_list)):
+ images += [image_annot[image_id]]
+
+ outputs = {'images': images, 'annotations': annotations}
+ print(len(images))
40
+ print(len(annotations))
41
+ json.dump(outputs, open(os.path.join(coco_root, 'annotations/refcocog_umd_train.json'), 'w'))
datasets/utils/refer.py ADDED
@@ -0,0 +1,372 @@
1
+ # This code is modified from https://github.com/lichengunc/refer, and with minor modification of python2/3 format
2
+ __author__ = 'licheng'
3
+
4
+ """
5
+ This interface provides access to four datasets:
6
+ 1) refclef
7
+ 2) refcoco
8
+ 3) refcoco+
9
+ 4) refcocog
10
+ split by unc and google
11
+
12
+ The following API functions are defined:
13
+ REFER - REFER api class
14
+ getRefIds - get ref ids that satisfy given filter conditions.
15
+ getAnnIds - get ann ids that satisfy given filter conditions.
16
+ getImgIds - get image ids that satisfy given filter conditions.
17
+ getCatIds - get category ids that satisfy given filter conditions.
18
+ loadRefs - load refs with the specified ref ids.
19
+ loadAnns - load anns with the specified ann ids.
20
+ loadImgs - load images with the specified image ids.
21
+ loadCats - load category names with the specified category ids.
22
+ getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
23
+ showRef - show image, segmentation or box of the referred object with the ref
24
+ getMask - get mask and area of the referred object given ref
25
+ showMask - show mask of the referred object given ref
26
+ """
27
+
28
+ from doctest import REPORT_ONLY_FIRST_FAILURE
29
+ import sys
30
+ import os.path as osp
31
+ import json
32
+ import pickle
33
+ import time
34
+ import itertools
35
+ import skimage.io as io
36
+ import matplotlib.pyplot as plt
37
+ from matplotlib.collections import PatchCollection
38
+ from matplotlib.patches import Polygon, Rectangle
39
+ from pprint import pprint
40
+ import numpy as np
41
+ from pycocotools import mask
42
+ # import cv2
43
+ # from skimage.measure import label, regionprops
44
+
45
+
46
+ class REFER:
47
+ def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
48
+ # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
49
+ # also provide dataset name and splitBy information
50
+ # e.g., dataset = 'refcoco', splitBy = 'unc'
51
+ print('loading dataset {} into memory...'.format(dataset))
52
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
53
+ self.DATA_DIR = osp.join(data_root, dataset)
54
+ if dataset in ['refcoco', 'refcoco+', 'refcocog']:
55
+ self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
56
+ elif dataset == 'refclef':
57
+ self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
58
+ else:
59
+ print('No refer dataset is called [{}]'.format(dataset))
60
+ sys.exit()
61
+
62
+ # load refs from data/dataset/refs(dataset).json
63
+ tic = time.time()
64
+ ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
65
+ self.data = {}
66
+ self.data['dataset'] = dataset
67
+ self.data['refs'] = pickle.load(open(ref_file, 'rb'))
68
+
69
+ # load annotations from data/dataset/instances.json
70
+ instances_file = osp.join(self.DATA_DIR, 'instances.json')
71
+ instances = json.load(open(instances_file, 'r'))
72
+ self.data['images'] = instances['images']
73
+ self.data['annotations'] = instances['annotations']
74
+ self.data['categories'] = instances['categories']
75
+
76
+ # create index
77
+ self.createIndex()
78
+ print('DONE (t={:.2f}s)'.format(time.time()-tic))
79
+
80
+ def createIndex(self):
81
+ # create sets of mapping
82
+ # 1) Refs: {ref_id: ref}
83
+ # 2) Anns: {ann_id: ann}
84
+ # 3) Imgs: {image_id: image}
85
+ # 4) Cats: {category_id: category_name}
86
+ # 5) Sents: {sent_id: sent}
87
+ # 6) imgToRefs: {image_id: refs}
88
+ # 7) imgToAnns: {image_id: anns}
89
+ # 8) refToAnn: {ref_id: ann}
90
+ # 9) annToRef: {ann_id: ref}
91
+ # 10) catToRefs: {category_id: refs}
92
+ # 11) sentToRef: {sent_id: ref}
93
+ # 12) sentToTokens: {sent_id: tokens}
94
+ print('creating index...')
95
+ # fetch info from instances
96
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
97
+ for ann in self.data['annotations']:
98
+ Anns[ann['id']] = ann
99
+ imgToAnns[ann['image_id']] = imgToAnns.get(
100
+ ann['image_id'], []) + [ann]
101
+ for img in self.data['images']:
102
+ Imgs[img['id']] = img
103
+ for cat in self.data['categories']:
104
+ Cats[cat['id']] = cat['name']
105
+
106
+ # fetch info from refs
107
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
108
+ Sents, sentToRef, sentToTokens = {}, {}, {}
109
+ for ref in self.data['refs']:
110
+ # ids
111
+ ref_id = ref['ref_id']
112
+ ann_id = ref['ann_id']
113
+ category_id = ref['category_id']
114
+ image_id = ref['image_id']
115
+
116
+ # add mapping related to ref
117
+ Refs[ref_id] = ref
118
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
119
+ catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
120
+ refToAnn[ref_id] = Anns[ann_id]
121
+ annToRef[ann_id] = ref
122
+
123
+ # add mapping of sent
124
+ for sent in ref['sentences']:
125
+ Sents[sent['sent_id']] = sent
126
+ sentToRef[sent['sent_id']] = ref
127
+ sentToTokens[sent['sent_id']] = sent['tokens']
128
+
129
+ # create class members
130
+ self.Refs = Refs
131
+ self.Anns = Anns
132
+ self.Imgs = Imgs
133
+ self.Cats = Cats
134
+ self.Sents = Sents
135
+ self.imgToRefs = imgToRefs
136
+ self.imgToAnns = imgToAnns
137
+ self.refToAnn = refToAnn
138
+ self.annToRef = annToRef
139
+ self.catToRefs = catToRefs
140
+ self.sentToRef = sentToRef
141
+ self.sentToTokens = sentToTokens
142
+ print('index created.')
143
+
144
+ def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
145
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
146
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
147
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
148
+
149
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
150
+ refs = self.data['refs']
151
+ else:
152
+ if not len(image_ids) == 0:
153
+ refs = [self.imgToRefs[image_id] for image_id in image_ids]
154
+ else:
155
+ refs = self.data['refs']
156
+ if not len(cat_ids) == 0:
157
+ refs = [ref for ref in refs if ref['category_id'] in cat_ids]
158
+ if not len(ref_ids) == 0:
159
+ refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
160
+ if not len(split) == 0:
161
+ if split in ['testA', 'testB', 'testC']:
162
+ # we also consider testAB, testBC, ...
163
+ refs = [ref for ref in refs if split[-1] in ref['split']]
164
+ elif split in ['testAB', 'testBC', 'testAC']:
165
+ # rarely used I guess...
166
+ refs = [ref for ref in refs if ref['split'] == split]
167
+ elif split == 'test':
168
+ refs = [ref for ref in refs if 'test' in ref['split']]
169
+ elif split == 'train' or split == 'val':
170
+ refs = [ref for ref in refs if ref['split'] == split]
171
+ else:
172
+ print('No such split [{}]'.format(split))
173
+ sys.exit()
174
+ ref_ids = [ref['ref_id'] for ref in refs]
175
+ return ref_ids
176
+
177
+ def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
178
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
179
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
180
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
181
+
182
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
183
+ ann_ids = [ann['id'] for ann in self.data['annotations']]
184
+ else:
185
+ if not len(image_ids) == 0:
186
+ lists = [self.imgToAnns[image_id]
187
+ for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
188
+ anns = list(itertools.chain.from_iterable(lists))
189
+ else:
190
+ anns = self.data['annotations']
191
+ if not len(cat_ids) == 0:
192
+ anns = [ann for ann in anns if ann['category_id'] in cat_ids]
193
+ ann_ids = [ann['id'] for ann in anns]
194
+ if not len(ref_ids) == 0:
195
+ ids = set(ann_ids).intersection(
196
+ set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
197
+ return ann_ids
198
+
199
+ def getImgIds(self, ref_ids=[]):
200
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
201
+
202
+ if not len(ref_ids) == 0:
203
+ image_ids = list(set([self.Refs[ref_id]['image_id']
204
+ for ref_id in ref_ids]))
205
+ else:
206
+ image_ids = self.Imgs.keys()
207
+ return image_ids
208
+
209
+ def getCatIds(self):
210
+ return self.Cats.keys()
211
+
212
+ def loadRefs(self, ref_ids=[]):
213
+ if type(ref_ids) == list:
214
+ return [self.Refs[ref_id] for ref_id in ref_ids]
215
+ elif type(ref_ids) == int:
216
+ return [self.Refs[ref_ids]]
217
+
218
+ def loadAnns(self, ann_ids=[]):
219
+ if type(ann_ids) == list:
220
+ return [self.Anns[ann_id] for ann_id in ann_ids]
221
+ elif type(ann_ids) == int or type(ann_ids) == str:
222
+ return [self.Anns[ann_ids]]
223
+
224
+ def loadImgs(self, image_ids=[]):
225
+ if type(image_ids) == list:
226
+ return [self.Imgs[image_id] for image_id in image_ids]
227
+ elif type(image_ids) == int:
228
+ return [self.Imgs[image_ids]]
229
+
230
+ def loadCats(self, cat_ids=[]):
231
+ if type(cat_ids) == list:
232
+ return [self.Cats[cat_id] for cat_id in cat_ids]
233
+ elif type(cat_ids) == int:
234
+ return [self.Cats[cat_ids]]
235
+
236
+ def getRefBox(self, ref_id):
237
+ ref = self.Refs[ref_id]
238
+ ann = self.refToAnn[ref_id]
239
+ return ann['bbox'] # [x, y, w, h]
240
+
241
+ def showRef(self, ref, seg_box='seg'):
242
+ ax = plt.gca()
243
+ # show image
244
+ image = self.Imgs[ref['image_id']]
245
+ I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
246
+ ax.imshow(I)
247
+ # show refer expression
248
+ for sid, sent in enumerate(ref['sentences']):
249
+ print('{}. {}'.format(sid+1, sent['sent']))
250
+ # show segmentations
251
+ if seg_box == 'seg':
252
+ ann_id = ref['ann_id']
253
+ ann = self.Anns[ann_id]
254
+ polygons = []
255
+ color = []
256
+ c = 'none'
257
+ if type(ann['segmentation'][0]) == list:
258
+ # polygon used for refcoco*
259
+ for seg in ann['segmentation']:
260
+ poly = np.array(seg).reshape((len(seg)//2, 2))
261
+ polygons.append(Polygon(poly, True, alpha=0.4))
262
+ color.append(c)
263
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(
264
+ 1, 1, 0, 0), linewidths=3, alpha=1)
265
+ ax.add_collection(p) # thick yellow polygon
266
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(
267
+ 1, 0, 0, 0), linewidths=1, alpha=1)
268
+ ax.add_collection(p) # thin red polygon
269
+ else:
270
+ # mask used for refclef
271
+ rle = ann['segmentation']
272
+ m = mask.decode(rle)
273
+ img = np.ones((m.shape[0], m.shape[1], 3))
274
+ color_mask = np.array([2.0, 166.0, 101.0])/255
275
+ for i in range(3):
276
+ img[:, :, i] = color_mask[i]
277
+ ax.imshow(np.dstack((img, m*0.5)))
278
+ # show bounding-box
279
+ elif seg_box == 'box':
280
+ ann_id = ref['ann_id']
281
+ ann = self.Anns[ann_id]
282
+ bbox = self.getRefBox(ref['ref_id'])
283
+ box_plot = Rectangle(
284
+ (bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
285
+ ax.add_patch(box_plot)
286
+
287
+ def getMask(self, ref):
288
+ # return mask, area and mask-center
289
+ ann = self.refToAnn[ref['ref_id']]
290
+ image = self.Imgs[ref['image_id']]
291
+ if type(ann['segmentation'][0]) == list: # polygon
292
+ rle = mask.frPyObjects(
293
+ ann['segmentation'], image['height'], image['width'])
294
+ else:
295
+ rle = ann['segmentation']
296
+ m = mask.decode(rle)
297
+ # sometimes there are multiple binary map (corresponding to multiple segs)
298
+ m = np.sum(m, axis=2)
299
+ m = m.astype(np.uint8) # convert to np.uint8
300
+ # compute area
301
+ area = sum(mask.area(rle)) # should be close to ann['area']
302
+ return {'mask': m, 'area': area}
303
+ # # position
304
+ # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
305
+ # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
306
+ # # mass position (if there were multiple regions, we use the largest one.)
307
+ # label_m = label(m, connectivity=m.ndim)
308
+ # regions = regionprops(label_m)
309
+ # if len(regions) > 0:
310
+ # largest_id = np.argmax(np.array([props.filled_area for props in regions]))
311
+ # largest_props = regions[largest_id]
312
+ # mass_y, mass_x = largest_props.centroid
313
+ # else:
314
+ # mass_x, mass_y = position_x, position_y
315
+ # # if centroid is not in mask, we find the closest point to it from mask
316
+ # if m[mass_y, mass_x] != 1:
317
+ # print 'Finding closes mask point ...'
318
+ # kernel = np.ones((10, 10),np.uint8)
319
+ # me = cv2.erode(m, kernel, iterations = 1)
320
+ # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
321
+ # points = np.array(points)
322
+ # dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
323
+ # id = np.argsort(dist)[0]
324
+ # mass_y, mass_x = points[id]
325
+ # # return
326
+ # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
327
+ # # show image and mask
328
+ # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
329
+ # plt.figure()
330
+ # plt.imshow(I)
331
+ # ax = plt.gca()
332
+ # img = np.ones( (m.shape[0], m.shape[1], 3) )
333
+ # color_mask = np.array([2.0,166.0,101.0])/255
334
+ # for i in range(3):
335
+ # img[:,:,i] = color_mask[i]
336
+ # ax.imshow(np.dstack( (img, m*0.5) ))
337
+ # plt.show()
338
+
339
+ def showMask(self, ref):
340
+ M = self.getMask(ref)
341
+ msk = M['mask']
342
+ ax = plt.gca()
343
+ ax.imshow(msk)
344
+
345
+
346
+ if __name__ == '__main__':
347
+ refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg',
348
+ dataset='refcocog', splitBy='google')
349
+ ref_ids = refer.getRefIds()
350
+ print(len(ref_ids))
351
+
352
+ print(len(refer.Imgs))
353
+ print(len(refer.imgToRefs))
354
+
355
+ ref_ids = refer.getRefIds(split='train')
356
+ print('There are {} training referred objects.'.format(len(ref_ids)))
357
+
358
+ for ref_id in ref_ids:
359
+ ref = refer.loadRefs(ref_id)[0]
360
+ if len(ref['sentences']) < 2:
361
+ continue
362
+
363
+ pprint(ref)
364
+ print('The label is {}.'.format(refer.Cats[ref['category_id']]))
365
+
366
+ # plt.figure()
367
+ # refer.showRef(ref, seg_box='box')
368
+ # plt.show()
369
+
370
+ # plt.figure()
371
+ # refer.showMask(ref)
372
+ # plt.show()
datasets/visual_sampler/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from .sampler import ShapeSampler
2
+ from .simpleclick_sampler import SimpleClickSampler
3
+
4
+
5
+ def build_shape_sampler(cfg, **kwargs):
6
+ sampler_name = cfg['STROKE_SAMPLER']['EVAL']['MODE']
7
+ if sampler_name == 'random':
8
+ return ShapeSampler(cfg, **kwargs)
9
+ elif sampler_name in ['best', 'best_random']:
10
+ return SimpleClickSampler(cfg, **kwargs)
11
+ else:
12
+ assert False, "not implemented"
datasets/visual_sampler/circle.py ADDED
@@ -0,0 +1,106 @@
1
+ import random
2
+ import torch
3
+
4
+ from .mask_generators import get_mask_by_input_strokes
5
+
6
+ class Circle:
7
+ def __init__(self, cfg, is_train=True):
8
+ self.num_stroke = cfg['STROKE_SAMPLER']['CIRCLE']['NUM_STROKES']
9
+ self.stroke_preset = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PRESET']
10
+ self.stroke_prob = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PROB']
11
+ self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
12
+ self.is_train = is_train
13
+
14
+ @staticmethod
15
+ def get_stroke_preset(stroke_preset):
16
+ if stroke_preset == 'object_like':
17
+ return {
18
+ "nVertexBound": [5, 30],
19
+ "maxHeadSpeed": 15,
20
+ "maxHeadAcceleration": (10, 1.5),
21
+ "brushWidthBound": (20, 50),
22
+ "nMovePointRatio": 0.5,
23
+ "maxPiontMove": 10,
24
+ "maxLineAcceleration": (5, 0.5),
25
+ "boarderGap": None,
26
+ "maxInitSpeed": 10,
27
+ }
28
+ elif stroke_preset == 'object_like_middle':
29
+ return {
30
+ "nVertexBound": [5, 15],
31
+ "maxHeadSpeed": 8,
32
+ "maxHeadAcceleration": (4, 1.5),
33
+ "brushWidthBound": (20, 50),
34
+ "nMovePointRatio": 0.5,
35
+ "maxPiontMove": 5,
36
+ "maxLineAcceleration": (5, 0.5),
37
+ "boarderGap": None,
38
+ "maxInitSpeed": 10,
39
+ }
40
+ elif stroke_preset == 'object_like_small':
41
+ return {
42
+ "nVertexBound": [5, 20],
43
+ "maxHeadSpeed": 7,
44
+ "maxHeadAcceleration": (3.5, 1.5),
45
+ "brushWidthBound": (10, 30),
46
+ "nMovePointRatio": 0.5,
47
+ "maxPiontMove": 5,
48
+ "maxLineAcceleration": (3, 0.5),
49
+ "boarderGap": None,
50
+ "maxInitSpeed": 4,
51
+ }
52
+ else:
53
+ raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
54
+
55
+ def get_random_points_from_mask(self, mask, n=5):
56
+ h,w = mask.shape
57
+ view_mask = mask.reshape(h*w)
58
+ non_zero_idx = view_mask.nonzero()[:,0]
59
+ selected_idx = torch.randperm(len(non_zero_idx))[:n]
60
+ non_zero_idx = non_zero_idx[selected_idx]
61
+ y = (non_zero_idx // w)*1.0
62
+ x = (non_zero_idx % w)*1.0
63
+ return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
64
+
65
+ def draw(self, mask=None, box=None):
66
+ if mask.sum() < 10: # if mask is nearly empty
67
+ return torch.zeros(mask.shape).bool()
68
+ if not self.is_train:
69
+ return self.draw_eval(mask=mask, box=box)
70
+ stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
71
+ preset = Circle.get_stroke_preset(stroke_preset_name)
72
+ nStroke = min(random.randint(1, self.num_stroke), mask.sum().item())
73
+ h,w = mask.shape
74
+ points = self.get_random_points_from_mask(mask, n=nStroke)
75
+ rand_mask = get_mask_by_input_strokes(
76
+ init_points=points,
77
+ imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
78
+ rand_mask = (~torch.from_numpy(rand_mask)) * mask
79
+ return rand_mask
80
+
81
+ def draw_eval(self, mask=None, box=None):
82
+ stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
83
+ preset = Circle.get_stroke_preset(stroke_preset_name)
84
+ nStroke = min(self.max_eval, mask.sum().item())
85
+ h,w = mask.shape
86
+ points = self.get_random_points_from_mask(mask, n=nStroke)
87
+ rand_masks = []
88
+ for i in range(len(points)):
89
+ rand_mask = get_mask_by_input_strokes(
90
+ init_points=points[:i+1],
91
+ imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points[:i+1])), **preset)
92
+ rand_masks += [(~torch.from_numpy(rand_mask)) * mask]
93
+ return torch.stack(rand_masks)
94
+
95
+ @staticmethod
96
+ def draw_by_points(points, mask, h, w):
97
+ stroke_preset_name = random.choices(['object_like', 'object_like_middle', 'object_like_small'], weights=[0.33,0.33,0.33], k=1)[0] # select which kind of object to use
98
+ preset = Circle.get_stroke_preset(stroke_preset_name)
99
+ rand_mask = get_mask_by_input_strokes(
100
+ init_points=points,
101
+ imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
102
+ rand_masks = (~torch.from_numpy(rand_mask)) * mask
103
+ return rand_masks
104
+
105
+ def __repr__(self,):
106
+ return 'circle'
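
Beyond the config-driven `draw`/`draw_eval` paths, `draw_by_points` is a static helper that needs no config at all, which makes it handy for quick visual checks. A hedged sketch with a toy target mask (the point coordinates are arbitrary examples):

```python
import torch
from datasets.visual_sampler.circle import Circle

# Toy 64x64 target mask with a filled square in the middle.
mask = torch.zeros(64, 64, dtype=torch.bool)
mask[20:44, 20:44] = True

# A few (x, y) seed points; the synthetic circle-style scribble is clipped to the mask.
points = [[32, 32], [25, 40], [40, 25]]
scribble = Circle.draw_by_points(points, mask, 64, 64)
print(scribble.shape, int(scribble.sum()))   # torch.Size([1, 64, 64]), number of scribbled pixels
```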
datasets/visual_sampler/mask_generators.py ADDED
@@ -0,0 +1,215 @@
1
+ import numpy as np
2
+ import random
3
+ from PIL import Image, ImageDraw
4
+
5
+
6
+ def get_mask_by_input_strokes(
7
+ init_points, imageWidth=320, imageHeight=180, nStroke=5,
8
+ nVertexBound=[10, 30], maxHeadSpeed=15, maxHeadAcceleration=(15, 0.5),
9
+ brushWidthBound=(5, 20), boarderGap=None, nMovePointRatio=0.5, maxPiontMove=10,
10
+ maxLineAcceleration=5, maxInitSpeed=5
11
+ ):
12
+ '''
13
+ Get video masks by random strokes which move randomly between each
14
+ frame, including the whole stroke and its control points
15
+
16
+ Parameters
17
+ ----------
18
+ imageWidth: Image width
19
+ imageHeight: Image height
20
+ nStroke: Number of drawed lines
21
+ nVertexBound: Lower/upper bound of number of control points for each line
22
+ maxHeadSpeed: Max head speed when creating control points
23
+ maxHeadAcceleration: Max acceleration applying on the current head point (
24
+ a head point and its velosity decides the next point)
25
+ brushWidthBound (min, max): Bound of width for each stroke
26
+ boarderGap: The minimum gap between image boarder and drawed lines
27
+ nMovePointRatio: The ratio of control points to move for next frames
28
+ maxPiontMove: The magnitude of movement for control points for next frames
29
+ maxLineAcceleration: The magnitude of acceleration for the whole line
30
+
31
+ Examples
32
+ ----------
33
+ object_like_setting = {
34
+ "nVertexBound": [5, 20],
35
+ "maxHeadSpeed": 15,
36
+ "maxHeadAcceleration": (15, 3.14),
37
+ "brushWidthBound": (30, 50),
38
+ "nMovePointRatio": 0.5,
39
+ "maxPiontMove": 10,
40
+ "maxLineAcceleration": (5, 0.5),
41
+ "boarderGap": 20,
42
+ "maxInitSpeed": 10,
43
+ }
44
+ rand_curve_setting = {
45
+ "nVertexBound": [10, 30],
46
+ "maxHeadSpeed": 20,
47
+ "maxHeadAcceleration": (15, 0.5),
48
+ "brushWidthBound": (3, 10),
49
+ "nMovePointRatio": 0.5,
50
+ "maxPiontMove": 3,
51
+ "maxLineAcceleration": (5, 0.5),
52
+ "boarderGap": 20,
53
+ "maxInitSpeed": 6
54
+ }
55
+ get_video_masks_by_moving_random_stroke(video_len=5, nStroke=3, **object_like_setting)
56
+ '''
57
+ # Initilize a set of control points to draw the first mask
58
+ mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
59
+ control_points_set = []
60
+ for i in range(nStroke):
61
+ brushWidth = np.random.randint(brushWidthBound[0], brushWidthBound[1])
62
+ Xs, Ys, velocity = get_random_stroke_control_points(
63
+ init_point=init_points[i],
64
+ imageWidth=imageWidth, imageHeight=imageHeight,
65
+ nVertexBound=nVertexBound, maxHeadSpeed=maxHeadSpeed,
66
+ maxHeadAcceleration=maxHeadAcceleration, boarderGap=boarderGap,
67
+ maxInitSpeed=maxInitSpeed
68
+ )
69
+ control_points_set.append((Xs, Ys, velocity, brushWidth))
70
+ draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
71
+
72
+ # Generate the following masks by randomly move strokes and their control points
73
+ mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
74
+ for j in range(len(control_points_set)):
75
+ Xs, Ys, velocity, brushWidth = control_points_set[j]
76
+ new_Xs, new_Ys = random_move_control_points(
77
+ Xs, Ys, velocity, nMovePointRatio, maxPiontMove,
78
+ maxLineAcceleration, boarderGap
79
+ )
80
+ control_points_set[j] = (new_Xs, new_Ys, velocity, brushWidth)
81
+ for Xs, Ys, velocity, brushWidth in control_points_set:
82
+ draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
83
+
84
+ return np.array(mask)
85
+
86
+
87
+ def random_accelerate(velocity, maxAcceleration, dist='uniform'):
88
+ speed, angle = velocity
89
+ d_speed, d_angle = maxAcceleration
90
+
91
+ if dist == 'uniform':
92
+ speed += np.random.uniform(-d_speed, d_speed)
93
+ angle += np.random.uniform(-d_angle, d_angle)
94
+ elif dist == 'guassian':
95
+ speed += np.random.normal(0, d_speed / 2)
96
+ angle += np.random.normal(0, d_angle / 2)
97
+ else:
98
+ raise NotImplementedError(f'Distribution type {dist} is not supported.')
99
+
100
+ return (speed, angle)
101
+
102
+
103
+ def random_move_control_points(Xs, Ys, lineVelocity, nMovePointRatio, maxPiontMove, maxLineAcceleration, boarderGap=15):
104
+ new_Xs = Xs.copy()
105
+ new_Ys = Ys.copy()
106
+
107
+ # move the whole line and accelerate
108
+ speed, angle = lineVelocity
109
+ new_Xs += int(speed * np.cos(angle))
110
+ new_Ys += int(speed * np.sin(angle))
111
+ lineVelocity = random_accelerate(lineVelocity, maxLineAcceleration, dist='guassian')
112
+
113
+ # choose points to move
114
+ chosen = np.arange(len(Xs))
115
+ np.random.shuffle(chosen)
116
+ chosen = chosen[:int(len(Xs) * nMovePointRatio)]
117
+ for i in chosen:
118
+ new_Xs[i] += np.random.randint(-maxPiontMove, maxPiontMove)
119
+ new_Ys[i] += np.random.randint(-maxPiontMove, maxPiontMove)
120
+ return new_Xs, new_Ys
121
+
122
+
123
+ def get_random_stroke_control_points(
124
+ init_point,
125
+ imageWidth, imageHeight,
126
+ nVertexBound=(10, 30), maxHeadSpeed=10, maxHeadAcceleration=(5, 0.5), boarderGap=20,
127
+ maxInitSpeed=10
128
+ ):
129
+ '''
130
+ Implementation of the free-form training mask generation algorithm
131
+ proposed by Jiahui Yu et al. in "Free-Form Image Inpainting with Gated Convolution"
132
+ '''
133
+ startX = init_point[0]
134
+ startY = init_point[1]
135
+
136
+ Xs = [init_point[0]]
137
+ Ys = [init_point[1]]
138
+
139
+ numVertex = np.random.randint(nVertexBound[0], nVertexBound[1])
140
+
141
+ angle = np.random.uniform(0, 2 * np.pi)
142
+ speed = np.random.uniform(0, maxHeadSpeed)
143
+
144
+ for i in range(numVertex):
145
+ speed, angle = random_accelerate((speed, angle), maxHeadAcceleration)
146
+ speed = np.clip(speed, 0, maxHeadSpeed)
147
+
148
+ nextX = startX + speed * np.sin(angle)
149
+ nextY = startY + speed * np.cos(angle)
150
+
151
+ if boarderGap is not None:
152
+ nextX = np.clip(nextX, boarderGap, imageWidth - boarderGap)
153
+ nextY = np.clip(nextY, boarderGap, imageHeight - boarderGap)
154
+
155
+ startX, startY = nextX, nextY
156
+ Xs.append(nextX)
157
+ Ys.append(nextY)
158
+
159
+ velocity = get_random_velocity(maxInitSpeed, dist='guassian')
160
+
161
+ return np.array(Xs), np.array(Ys), velocity
162
+
163
+
164
+ def get_random_velocity(max_speed, dist='uniform'):
165
+ if dist == 'uniform':
166
+ speed = np.random.uniform(0, max_speed)
167
+ elif dist == 'guassian':
168
+ speed = np.abs(np.random.normal(0, max_speed / 2))
169
+ else:
170
+ raise NotImplementedError(f'Distribution type {dist} is not supported.')
171
+
172
+ angle = np.random.uniform(0, 2 * np.pi)
173
+ return (speed, angle)
174
+
175
+
176
+ def draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=255):
177
+ radius = brushWidth // 2 - 1
178
+ for i in range(1, len(Xs)):
179
+ draw = ImageDraw.Draw(mask)
180
+ startX, startY = Xs[i - 1], Ys[i - 1]
181
+ nextX, nextY = Xs[i], Ys[i]
182
+ draw.line((startX, startY) + (nextX, nextY), fill=fill, width=brushWidth)
183
+ for x, y in zip(Xs, Ys):
184
+ draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=fill)
185
+ return mask
186
+
187
+
188
+ # modified from https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/generate_data.py
189
+ def get_random_walk_mask(imageWidth=320, imageHeight=180, length=None):
190
+ action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
191
+ canvas = np.zeros((imageHeight, imageWidth)).astype("i")
192
+ if length is None:
193
+ length = imageWidth * imageHeight
194
+ x = random.randint(0, imageHeight - 1)
195
+ y = random.randint(0, imageWidth - 1)
196
+ x_list = []
197
+ y_list = []
198
+ for i in range(length):
199
+ r = random.randint(0, len(action_list) - 1)
200
+ x = np.clip(x + action_list[r][0], a_min=0, a_max=imageHeight - 1)
201
+ y = np.clip(y + action_list[r][1], a_min=0, a_max=imageWidth - 1)
202
+ x_list.append(x)
203
+ y_list.append(y)
204
+ canvas[np.array(x_list), np.array(y_list)] = 1
205
+ return Image.fromarray(canvas * 255).convert('1')
206
+
207
+
208
+ def get_masked_ratio(mask):
209
+ """
210
+ Calculate the masked ratio.
211
+ mask: expected to be a binary PIL image, where 0 and 1 represent
212
+ masked (invalid) and valid pixel values, respectively.
213
+ """
214
+ hist = mask.histogram()
215
+ return hist[0] / np.prod(mask.size)
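For reference, here is a minimal sketch of calling this stroke generator directly. The import path simply mirrors this file's location, and the keyword values copy the `rand_curve` preset that the Scribble sampler below passes to `get_mask_by_input_strokes`; both are assumptions for illustration, not an official API.

```python
import numpy as np

# Assumed import path, based on this file living at datasets/visual_sampler/mask_generators.py
from datasets.visual_sampler.mask_generators import get_mask_by_input_strokes

h, w = 180, 320
seeds = np.array([[40, 60], [160, 90], [280, 120]])  # (x, y) start point for each stroke

stroke_mask = get_mask_by_input_strokes(
    init_points=seeds, imageWidth=w, imageHeight=h, nStroke=len(seeds),
    # values below copy the 'rand_curve' preset from scribble.py
    nVertexBound=[10, 30], maxHeadSpeed=20, maxHeadAcceleration=(15, 0.5),
    brushWidthBound=(3, 10), nMovePointRatio=0.5, maxPiontMove=3,
    maxLineAcceleration=(5, 0.5), boarderGap=None, maxInitSpeed=6)

# Strokes are drawn with fill=0 on a canvas of 1s, so the array is False on the strokes.
print(stroke_mask.shape, (~stroke_mask).sum())
```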
datasets/visual_sampler/point.py ADDED
@@ -0,0 +1,74 @@
1
+ import random
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from scipy import ndimage
6
+
7
+
8
+ class Point:
9
+ def __init__(self, cfg, is_train=True):
10
+ self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']
11
+ self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
12
+ self.is_train = is_train
13
+
14
+ def draw(self, mask=None, box=None):
15
+ if mask.sum() < 10:
16
+ return torch.zeros(mask.shape).bool() # if mask is empty
17
+ if not self.is_train:
18
+ return self.draw_eval(mask=mask, box=box)
19
+ max_points = min(self.max_points, mask.sum().item()) # max number of points no more than total mask number
20
+ num_points = random.randint(1, max_points) # get a random number of points
21
+ h,w = mask.shape
22
+ view_mask = mask.view(-1)
23
+ non_zero_idx = view_mask.nonzero()[:,0] # get non-zero index of mask
24
+ selected_idx = torch.randperm(len(non_zero_idx))[:num_points] # select id
25
+ non_zero_idx = non_zero_idx[selected_idx] # select non-zero index
26
+ rand_mask = torch.zeros(view_mask.shape).bool() # init rand mask
27
+ rand_mask[non_zero_idx] = True # get non zero place to zero
28
+ # dilate
29
+ # struct = ndimage.generate_binary_structure(2, 2)
30
+ # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
31
+ # return rand_mask
32
+ return rand_mask.reshape(h, w)
33
+
34
+ def draw_eval(self, mask=None, box=None):
35
+ background = ~mask
36
+ neg_num = min(self.max_eval // 2, background.sum().item())
37
+ pos_num = min(self.max_eval - neg_num, mask.sum().item()-1) + 1
38
+
39
+ h,w = mask.shape
40
+ view_mask = mask.view(-1)
41
+ non_zero_idx_pos = view_mask.nonzero()[:,0] # get non-zero index of mask
42
+ selected_idx_pos = torch.randperm(len(non_zero_idx_pos))[:pos_num] # select id
43
+ non_zero_idx_pos = non_zero_idx_pos[selected_idx_pos] # select non-zero index
44
+ pos_idx = torch.ones(non_zero_idx_pos.shape)
45
+
46
+ view_background = background.view(-1)
47
+ non_zero_idx_neg = view_background.nonzero()[:,0] # get non-zero index of mask
48
+ selected_idx_neg = torch.randperm(len(non_zero_idx_neg))[:neg_num] # select id
49
+ non_zero_idx_neg = non_zero_idx_neg[selected_idx_neg] # select non-zero index
50
+ neg_idx = torch.ones(non_zero_idx_neg.shape) * -1
51
+
52
+ non_zero_idx = torch.cat([non_zero_idx_pos, non_zero_idx_neg])
53
+ idx = torch.cat([pos_idx, neg_idx])
54
+ rand_idx = torch.cat([torch.zeros(1), torch.randperm(len(non_zero_idx)-1) + 1]).long()
55
+ non_zero_idx = non_zero_idx[rand_idx]
56
+ idx = idx[rand_idx]
57
+
58
+ rand_masks = []
59
+ for i in range(0, len(non_zero_idx)):
60
+ rand_mask = torch.zeros(view_mask.shape) # init rand mask
61
+ rand_mask[non_zero_idx[0:i+1]] = idx[0:i+1] # get non zero place to zero
62
+ # struct = ndimage.generate_binary_structure(2, 2)
63
+ # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
64
+ rand_masks += [rand_mask.reshape(h, w)]
65
+
66
+ # kernel_size = 3
67
+ rand_masks = torch.stack(rand_masks)
68
+ # rand_masks = F.conv2d(rand_masks[:,None], torch.ones(1,1,kernel_size,kernel_size), padding=kernel_size//2)[:,0]
69
+ # rand_masks[rand_masks>0] = 1
70
+ # rand_masks[rand_masks<0] = -1
71
+ return rand_masks
72
+
73
+ def __repr__(self,):
74
+ return 'point'
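A minimal usage sketch for the Point sampler above. The nested config keys are inferred from `__init__`, the values are illustrative, and the import assumes the repository root is on `PYTHONPATH`.

```python
import torch
from datasets.visual_sampler.point import Point

# Only the keys read by Point.__init__; real values live in the configs/*.yaml files.
cfg = {'STROKE_SAMPLER': {'POINT': {'NUM_POINTS': 20}, 'EVAL': {'MAX_ITER': 20}}}
sampler = Point(cfg, is_train=True)

mask = torch.zeros(64, 64, dtype=torch.bool)
mask[20:40, 20:40] = True            # toy ground-truth region

clicks = sampler.draw(mask=mask)     # (64, 64) bool map with up to NUM_POINTS sampled clicks
print(clicks.sum().item())
```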
datasets/visual_sampler/polygon.py ADDED
@@ -0,0 +1,137 @@
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+ from scipy.special import binom
6
+ from scipy import ndimage
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
9
+
10
+ bernstein = lambda n, k, t: binom(n,k)* t**k * (1.-t)**(n-k)
11
+
12
+ def bezier(points, num=200):
13
+ N = len(points)
14
+ t = np.linspace(0, 1, num=num)
15
+ curve = np.zeros((num, 2))
16
+ for i in range(N):
17
+ curve += np.outer(bernstein(N - 1, i, t), points[i])
18
+ return curve
19
+
20
+ class Segment():
21
+ def __init__(self, p1, p2, angle1, angle2, **kw):
22
+ self.p1 = p1; self.p2 = p2
23
+ self.angle1 = angle1; self.angle2 = angle2
24
+ self.numpoints = kw.get("numpoints", 100)
25
+ r = kw.get("r", 0.3)
26
+ d = np.sqrt(np.sum((self.p2-self.p1)**2))
27
+ self.r = r*d
28
+ self.p = np.zeros((4,2))
29
+ self.p[0,:] = self.p1[:]
30
+ self.p[3,:] = self.p2[:]
31
+ self.calc_intermediate_points(self.r)
32
+
33
+ def calc_intermediate_points(self,r):
34
+ self.p[1,:] = self.p1 + np.array([self.r*np.cos(self.angle1),
35
+ self.r*np.sin(self.angle1)])
36
+ self.p[2,:] = self.p2 + np.array([self.r*np.cos(self.angle2+np.pi),
37
+ self.r*np.sin(self.angle2+np.pi)])
38
+ self.curve = bezier(self.p,self.numpoints)
39
+
40
+ def get_curve(points, **kw):
41
+ segments = []
42
+ for i in range(len(points)-1):
43
+ seg = Segment(points[i,:2], points[i+1,:2], points[i,2],points[i+1,2],**kw)
44
+ segments.append(seg)
45
+ curve = np.concatenate([s.curve for s in segments])
46
+ return segments, curve
47
+
48
+ def ccw_sort(p):
49
+ d = p-np.mean(p,axis=0)
50
+ s = np.arctan2(d[:,0], d[:,1])
51
+ return p[np.argsort(s),:]
52
+
53
+ def get_bezier_curve(a, rad=0.2, edgy=0):
54
+ """ given an array of points *a*, create a curve through
55
+ those points.
56
+ *rad* is a number between 0 and 1 to steer the distance of
57
+ control points.
58
+ *edgy* is a parameter which controls how "edgy" the curve is,
59
+ edgy=0 is smoothest."""
60
+ p = np.arctan(edgy)/np.pi+.5
61
+ a = ccw_sort(a)
62
+ a = np.append(a, np.atleast_2d(a[0,:]), axis=0)
63
+ d = np.diff(a, axis=0)
64
+ ang = np.arctan2(d[:,1],d[:,0])
65
+ f = lambda ang : (ang>=0)*ang + (ang<0)*(ang+2*np.pi)
66
+ ang = f(ang)
67
+ ang1 = ang
68
+ ang2 = np.roll(ang,1)
69
+ ang = p*ang1 + (1-p)*ang2 + (np.abs(ang2-ang1) > np.pi )*np.pi
70
+ ang = np.append(ang, [ang[0]])
71
+ a = np.append(a, np.atleast_2d(ang).T, axis=1)
72
+ s, c = get_curve(a, r=rad, method="var")
73
+ x,y = c.T
74
+ return x,y,a
75
+
76
+ class Polygon:
77
+ def __init__(self, cfg, is_train):
78
+ self.max_points = cfg['STROKE_SAMPLER']['POLYGON']['MAX_POINTS']
79
+ self.eval_points = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
80
+ self.is_train = is_train
81
+
82
+ def get_random_points_from_mask(self, mask, n=3):
83
+ h,w = mask.shape
84
+ view_mask = mask.reshape(h*w)
85
+ non_zero_idx = view_mask.nonzero()[:,0]
86
+ selected_idx = torch.randperm(len(non_zero_idx))[:n]
87
+ non_zero_idx = non_zero_idx[selected_idx]
88
+ y = (non_zero_idx // w)*1.0/(h+1)
89
+ x = (non_zero_idx % w)*1.0/(w+1)
90
+ return torch.cat((x[:,None],y[:,None]), dim=1).numpy()
91
+
92
+ def draw(self, mask=None, box=None):
93
+ if mask.sum() < 10:
94
+ return torch.zeros(mask.shape).bool() # if mask is empty
95
+ if not self.is_train:
96
+ return self.draw_eval(mask=mask, box=box)
97
+ # box: x1,y1,x2,y2
98
+ x1,y1,x2,y2 = box.int().unbind()
99
+ rad = 0.2
100
+ edgy = 0.05
101
+ num_points = random.randint(1, min(self.max_points, mask.sum().item()))
102
+ a = self.get_random_points_from_mask(mask[y1:y2,x1:x2], n=num_points)
103
+ x,y, _ = get_bezier_curve(a,rad=rad, edgy=edgy)
104
+ x = x.clip(0.0, 1.0)
105
+ y = y.clip(0.0, 1.0)
106
+ points = torch.from_numpy(np.concatenate((y[None,]*(y2-y1-1).item(),x[None,]*(x2-x1-1).item()))).int()
107
+ canvas = torch.zeros((y2-y1, x2-x1))
108
+ canvas[points.long().tolist()] = 1
109
+ rand_mask = torch.zeros(mask.shape)
110
+ rand_mask[y1:y2,x1:x2] = canvas
111
+ return rand_mask.bool()
112
+
113
+ def draw_eval(self, mask=None, box=None):
114
+ # box: x1,y1,x2,y2
115
+ x1,y1,x2,y2 = box.int().unbind()
116
+ rad = 0.2
117
+ edgy = 0.05
118
+ num_points = min(self.eval_points, mask.sum().item())
119
+ a = self.get_random_points_from_mask(mask[y1:y2,x1:x2], n=num_points)
120
+ rand_masks = []
121
+ for i in range(len(a)):
122
+ x,y, _ = get_bezier_curve(a[:i+1],rad=rad, edgy=edgy)
123
+ x = x.clip(0.0, 1.0)
124
+ y = y.clip(0.0, 1.0)
125
+ points = torch.from_numpy(np.concatenate((y[None,]*(y2-y1-1).item(),x[None,]*(x2-x1-1).item()))).int()
126
+ canvas = torch.zeros((y2-y1, x2-x1))
127
+ canvas[points.long().tolist()] = 1
128
+ rand_mask = torch.zeros(mask.shape)
129
+ rand_mask[y1:y2,x1:x2] = canvas
130
+
131
+ struct = ndimage.generate_binary_structure(2, 2)
132
+ rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask, structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
133
+ rand_masks += [rand_mask.bool()]
134
+ return torch.stack(rand_masks)
135
+
136
+ def __repr__(self,):
137
+ return 'polygon'
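A hedged sketch of `Polygon.draw` on a toy mask. The config keys mirror `__init__`, the values are illustrative, and the box follows the `x1, y1, x2, y2` convention noted in `draw`.

```python
import torch
from datasets.visual_sampler.polygon import Polygon

# Only the keys read by Polygon.__init__; values are illustrative.
cfg = {'STROKE_SAMPLER': {'POLYGON': {'MAX_POINTS': 8}, 'EVAL': {'MAX_ITER': 8}}}
poly = Polygon(cfg, is_train=True)

mask = torch.zeros(96, 96, dtype=torch.bool)
mask[30:70, 25:75] = True
box = torch.tensor([25., 30., 75., 70.])     # x1, y1, x2, y2 of the region

curve = poly.draw(mask=mask, box=box)        # (96, 96) bool Bezier curve through random mask points
```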
datasets/visual_sampler/sampler.py ADDED
@@ -0,0 +1,77 @@
1
+ import sys
2
+ import random
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .point import Point
8
+ from .polygon import Polygon
9
+ from .scribble import Scribble
10
+ from .circle import Circle
11
+
12
+ from modeling.utils import configurable
13
+
14
+
15
+ class ShapeSampler(nn.Module):
16
+ @configurable
17
+ def __init__(self, max_candidate=1, shape_prob=[], shape_candidate=[], is_train=True):
18
+ super().__init__()
19
+ self.max_candidate = max_candidate
20
+ self.shape_prob = shape_prob
21
+ self.shape_candidate = shape_candidate
22
+ self.is_train = is_train
23
+
24
+ @classmethod
25
+ def from_config(cls, cfg, is_train=True, mode=None):
26
+ max_candidate = cfg['STROKE_SAMPLER']['MAX_CANDIDATE']
27
+ candidate_probs = cfg['STROKE_SAMPLER']['CANDIDATE_PROBS']
28
+ candidate_names = cfg['STROKE_SAMPLER']['CANDIDATE_NAMES']
29
+
30
+ if mode == 'hack_train':
31
+ candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, True) for class_name in candidate_names]
32
+ else:
33
+ # overwrite condidate_prob
34
+ if not is_train:
35
+ candidate_probs = [0.0 for x in range(len(candidate_names))]
36
+ candidate_probs[candidate_names.index(mode)] = 1.0
37
+ candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, is_train) for class_name in candidate_names]
38
+
39
+ # Build augmentation
40
+ return {
41
+ "max_candidate": max_candidate,
42
+ "shape_prob": candidate_probs,
43
+ "shape_candidate": candidate_classes,
44
+ "is_train": is_train,
45
+ }
46
+
47
+ def forward(self, instances):
48
+ masks = instances.gt_masks.tensor
49
+ boxes = instances.gt_boxes.tensor
50
+
51
+ if len(masks) == 0:
52
+ gt_masks = torch.zeros(masks.shape[-2:]).bool()
53
+ rand_masks = torch.zeros(masks.shape[-2:]).bool()
54
+ return {'gt_masks': gt_masks[None,:], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']}
55
+ indices = [x for x in range(len(masks))]
56
+
57
+ if self.is_train:
58
+ random.shuffle(indices)
59
+ candidate_mask = masks[indices[:self.max_candidate]]
60
+ candidate_box = boxes[indices[:self.max_candidate]]
61
+ else:
62
+ candidate_mask = masks
63
+ candidate_box = boxes
64
+
65
+ draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
66
+ rand_shapes = [d.draw(x,y) for d,x,y in zip(draw_funcs, candidate_mask, candidate_box)]
67
+ types = [repr(x) for x in draw_funcs]
68
+ for i in range(0, len(rand_shapes)):
69
+ if rand_shapes[i].sum() == 0:
70
+ candidate_mask[i] = candidate_mask[i] * 0
71
+ types[i] = 'none'
72
+
73
+ # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
74
+ return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
75
+
76
+ def build_shape_sampler(cfg, **kwargs):
77
+ return ShapeSampler(cfg, **kwargs)
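A hedged construction sketch for `ShapeSampler`. The fragment only covers the Point and Scribble candidates (the Circle keys are not shown in this commit), the values are illustrative, and the import assumes the repository root is on `PYTHONPATH`; in the repo the sampler is normally built from the YAML config via `build_shape_sampler`.

```python
from datasets.visual_sampler.sampler import ShapeSampler

cfg = {
    'STROKE_SAMPLER': {
        'MAX_CANDIDATE': 1,
        'CANDIDATE_NAMES': ['Point', 'Scribble'],   # Circle omitted: its config keys are not shown here
        'CANDIDATE_PROBS': [0.5, 0.5],
        'POINT': {'NUM_POINTS': 20},
        'SCRIBBLE': {'NUM_STROKES': 5,
                     'STROKE_PRESET': ['rand_curve', 'rand_curve_small'],
                     'STROKE_PROB': [0.5, 0.5]},
        'EVAL': {'MAX_ITER': 20},
    }
}

# Going through from_config explicitly keeps the sketch independent of how the
# @configurable decorator dispatches on plain dict configs.
sampler = ShapeSampler(**ShapeSampler.from_config(cfg, is_train=True))
# sampler(instances) then expects detectron2 Instances with gt_masks and gt_boxes.
```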
datasets/visual_sampler/scribble.py ADDED
@@ -0,0 +1,96 @@
1
+ import random
2
+
3
+ import torch
4
+
5
+ from .mask_generators import get_mask_by_input_strokes
6
+
7
+ class Scribble:
8
+ def __init__(self, cfg, is_train):
9
+ self.num_stroke = cfg['STROKE_SAMPLER']['SCRIBBLE']['NUM_STROKES']
10
+ self.stroke_preset = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PRESET']
11
+ self.stroke_prob = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PROB']
12
+ self.eval_stroke = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
13
+ self.is_train = is_train
14
+
15
+ @staticmethod
16
+ def get_stroke_preset(stroke_preset):
17
+ if stroke_preset == 'rand_curve':
18
+ return {
19
+ "nVertexBound": [10, 30],
20
+ "maxHeadSpeed": 20,
21
+ "maxHeadAcceleration": (15, 0.5),
22
+ "brushWidthBound": (3, 10),
23
+ "nMovePointRatio": 0.5,
24
+ "maxPiontMove": 3,
25
+ "maxLineAcceleration": (5, 0.5),
26
+ "boarderGap": None,
27
+ "maxInitSpeed": 6
28
+ }
29
+ elif stroke_preset == 'rand_curve_small':
30
+ return {
31
+ "nVertexBound": [6, 22],
32
+ "maxHeadSpeed": 12,
33
+ "maxHeadAcceleration": (8, 0.5),
34
+ "brushWidthBound": (2.5, 5),
35
+ "nMovePointRatio": 0.5,
36
+ "maxPiontMove": 1.5,
37
+ "maxLineAcceleration": (3, 0.5),
38
+ "boarderGap": None,
39
+ "maxInitSpeed": 3
40
+ }
41
+ else:
42
+ raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
43
+
44
+ def get_random_points_from_mask(self, mask, n=5):
45
+ h,w = mask.shape
46
+ view_mask = mask.reshape(h*w)
47
+ non_zero_idx = view_mask.nonzero()[:,0]
48
+ selected_idx = torch.randperm(len(non_zero_idx))[:n]
49
+ non_zero_idx = non_zero_idx[selected_idx]
50
+ y = (non_zero_idx // w)*1.0
51
+ x = (non_zero_idx % w)*1.0
52
+ return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
53
+
54
+ def draw(self, mask=None, box=None):
55
+ if mask.sum() < 10:
56
+ return torch.zeros(mask.shape).bool() # if mask is empty
57
+ if not self.is_train:
58
+ return self.draw_eval(mask=mask, box=box)
59
+ stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
60
+ preset = Scribble.get_stroke_preset(stroke_preset_name)
61
+ nStroke = random.randint(1, min(self.num_stroke, mask.sum().item()))
62
+ h,w = mask.shape
63
+ points = self.get_random_points_from_mask(mask, n=nStroke)
64
+ rand_mask = get_mask_by_input_strokes(
65
+ init_points=points,
66
+ imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
67
+ rand_mask = (~torch.from_numpy(rand_mask)) * mask
68
+ return rand_mask
69
+
70
+ def draw_eval(self, mask=None, box=None):
71
+ stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
72
+ preset = Scribble.get_stroke_preset(stroke_preset_name)
73
+ nStroke = min(self.eval_stroke, mask.sum().item())
74
+ h,w = mask.shape
75
+ points = self.get_random_points_from_mask(mask, n=nStroke)
76
+ rand_masks = []
77
+ for i in range(len(points)):
78
+ rand_mask = get_mask_by_input_strokes(
79
+ init_points=points[:i+1],
80
+ imageWidth=w, imageHeight=h, nStroke=min(i, len(points)), **preset)
81
+ rand_mask = (~torch.from_numpy(rand_mask)) * mask
82
+ rand_masks += [rand_mask]
83
+ return torch.stack(rand_masks)
84
+
85
+ @staticmethod
86
+ def draw_by_points(points, mask, h, w):
87
+ stroke_preset_name = random.choices(['rand_curve', 'rand_curve_small'], weights=[0.5, 0.5], k=1)[0]
88
+ preset = Scribble.get_stroke_preset(stroke_preset_name)
89
+ rand_mask = get_mask_by_input_strokes(
90
+ init_points=points,
91
+ imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
92
+ rand_masks = (~torch.from_numpy(rand_mask)) * mask
93
+ return rand_masks
94
+
95
+ def __repr__(self,):
96
+ return 'scribble'
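A minimal sketch of `Scribble.draw` on a toy mask; the config keys are the ones read in `__init__` and the values are illustrative.

```python
import torch
from datasets.visual_sampler.scribble import Scribble

cfg = {'STROKE_SAMPLER': {'SCRIBBLE': {'NUM_STROKES': 5,
                                       'STROKE_PRESET': ['rand_curve', 'rand_curve_small'],
                                       'STROKE_PROB': [0.5, 0.5]},
                          'EVAL': {'MAX_ITER': 20}}}
scribble = Scribble(cfg, is_train=True)

mask = torch.zeros(128, 128, dtype=torch.bool)
mask[40:90, 30:100] = True

strokes = scribble.draw(mask=mask)   # (128, 128) random strokes, clipped to the mask region
```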
datasets/visual_sampler/simpleclick_sampler.py ADDED
@@ -0,0 +1,252 @@
1
+ import sys
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from scipy import ndimage
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from kornia.contrib import distance_transform
11
+
12
+ from .point import Point
13
+ from .polygon import Polygon, get_bezier_curve
14
+ from .scribble import Scribble
15
+ from .circle import Circle
16
+
17
+ from modeling.utils import configurable
18
+
19
+
20
+ class SimpleClickSampler(nn.Module):
21
+ @configurable
22
+ def __init__(self, mask_mode='point', sample_negtive=False, is_train=True, dilation=None, dilation_kernel=None, max_points=None):
23
+ super().__init__()
24
+ self.mask_mode = mask_mode
25
+ self.sample_negtive = sample_negtive
26
+ self.is_train = is_train
27
+ self.dilation = dilation
28
+ self.register_buffer("dilation_kernel", dilation_kernel)
29
+ self.max_points = max_points
30
+
31
+ @classmethod
32
+ def from_config(cls, cfg, is_train=True, mode=None):
33
+ mask_mode = mode
34
+ sample_negtive = cfg['STROKE_SAMPLER']['EVAL']['NEGATIVE']
35
+
36
+ dilation = cfg['STROKE_SAMPLER']['DILATION']
37
+ dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
38
+
39
+ max_points = cfg['STROKE_SAMPLER']['POLYGON']['MAX_POINTS']
40
+
41
+ # Build augmentation
42
+ return {
43
+ "mask_mode": mask_mode,
44
+ "sample_negtive": sample_negtive,
45
+ "is_train": is_train,
46
+ "dilation": dilation,
47
+ "dilation_kernel": dilation_kernel,
48
+ "max_points": max_points,
49
+ }
50
+
51
+ def forward_point(self, instances, pred_masks=None, prev_masks=None):
52
+ gt_masks = instances.gt_masks.tensor
53
+ n,h,w = gt_masks.shape
54
+
55
+ # We only consider positive points
56
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
57
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
58
+
59
+ if not gt_masks.is_cuda:
60
+ gt_masks = gt_masks.to(pred_masks.device)
61
+
62
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
63
+
64
+ # conv implementation
65
+ mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
66
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
67
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
68
+ next_mask = next_mask.view(n,-1)
69
+
70
+ next_mask[max_xy_idx] = True
71
+ next_mask = next_mask.reshape((n,h,w)).float()
72
+ next_mask = F.conv2d(next_mask[None,], self.dilation_kernel.repeat(len(next_mask),1,1,1), padding=self.dilation//2, groups=len(next_mask))[0] > 0
73
+ # end conv implementation
74
+
75
+ # disk implementation
76
+ # mask_dt = distance_transform((~fp)[None,].float())[0].view(n,-1)
77
+ # max_xy = mask_dt.max(dim=-1)[1]
78
+ # max_y, max_x = max_xy//w, max_xy%w
79
+ # max_xy_idx = torch.stack([max_y, max_x]).transpose(0,1)[:,:,None,None]
80
+ # y_idx = torch.arange(start=0, end=h, step=1, dtype=torch.float32, device=torch.cuda.current_device())
81
+ # x_idx = torch.arange(start=0, end=w, step=1, dtype=torch.float32, device=torch.cuda.current_device())
82
+ # coord_y, coord_x = torch.meshgrid(y_idx, x_idx)
83
+ # coords = torch.stack((coord_y, coord_x), dim=0).unsqueeze(0).repeat(len(max_xy_idx),1,1,1) # [bsx2,2,h,w], corresponding to 2d coordinate
84
+ # coords.add_(-max_xy_idx)
85
+ # coords.mul_(coords)
86
+ # next_mask = coords[:, 0] + coords[:, 1]
87
+ # next_mask = (next_mask <= 5**2)
88
+ # end disk implementation
89
+
90
+ rand_shapes = prev_masks | next_mask
91
+
92
+ types = ['point' for i in range(len(gt_masks))]
93
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
94
+
95
+ def forward_circle(self, instances, pred_masks=None, prev_masks=None):
96
+ gt_masks = instances.gt_masks.tensor
97
+ n,h,w = gt_masks.shape
98
+
99
+ # We only consider positive points
100
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
101
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
102
+
103
+ if not gt_masks.is_cuda:
104
+ gt_masks = gt_masks.to(pred_masks.device)
105
+
106
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
107
+
108
+ # conv implementation
109
+ mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
110
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
111
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
112
+ next_mask = next_mask.view(n,-1)
113
+
114
+ next_mask[max_xy_idx] = True
115
+ next_mask = next_mask.reshape((n,h,w)).float()
116
+
117
+ _next_mask = []
118
+ for idx in range(len(next_mask)):
119
+ points = next_mask[idx].nonzero().flip(dims=[-1]).cpu().numpy()
120
+ _next_mask += [Circle.draw_by_points(points, gt_masks[idx:idx+1].cpu(), h, w)]
121
+ next_mask = torch.cat(_next_mask, dim=0).bool()
122
+ rand_shapes = prev_masks | next_mask
123
+
124
+ types = ['circle' for i in range(len(gt_masks))]
125
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
126
+
127
+ def forward_scribble(self, instances, pred_masks=None, prev_masks=None):
128
+ gt_masks = instances.gt_masks.tensor
129
+ n,h,w = gt_masks.shape
130
+
131
+ # We only consider positive points
132
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
133
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
134
+
135
+ if not gt_masks.is_cuda:
136
+ gt_masks = gt_masks.to(pred_masks.device)
137
+
138
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
139
+
140
+ # conv implementation
141
+ mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
142
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
143
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
144
+ next_mask = next_mask.view(n,-1)
145
+
146
+ next_mask[max_xy_idx] = True
147
+ next_mask = next_mask.reshape((n,h,w)).float()
148
+
149
+ _next_mask = []
150
+ for idx in range(len(next_mask)):
151
+ points = next_mask[idx].nonzero().flip(dims=[-1]).cpu().numpy()
152
+ _next_mask += [Scribble.draw_by_points(points, gt_masks[idx:idx+1].cpu(), h, w)]
153
+ next_mask = torch.cat(_next_mask, dim=0).bool()
154
+ rand_shapes = prev_masks | next_mask
155
+
156
+ types = ['scribble' for i in range(len(gt_masks))]
157
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
158
+
159
+ def forward_polygon(self, instances, pred_masks=None, prev_masks=None):
160
+ gt_masks = instances.gt_masks.tensor
161
+ gt_boxes = instances.gt_boxes.tensor
162
+ n,h,w = gt_masks.shape
163
+
164
+ # We only consider positive points
165
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
166
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
167
+
168
+ if not gt_masks.is_cuda:
169
+ gt_masks = gt_masks.to(pred_masks.device)
170
+
171
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
172
+
173
+ next_mask = []
174
+ for i in range(len(fp)):
175
+ rad = 0.2
176
+ edgy = 0.05
177
+ num_points = random.randint(1, min(self.max_points, fp[i].sum()))
178
+
179
+ h,w = fp[i].shape
180
+ view_mask = fp[i].reshape(h*w)
181
+ non_zero_idx = view_mask.nonzero()[:,0]
182
+ selected_idx = torch.randperm(len(non_zero_idx))[:num_points]
183
+ non_zero_idx = non_zero_idx[selected_idx]
184
+ y = (non_zero_idx // w)*1.0/(h+1)
185
+ x = (non_zero_idx % w)*1.0/(w+1)
186
+ coords = torch.cat((x[:,None],y[:,None]), dim=1).cpu().numpy()
187
+
188
+ x1,y1,x2,y2 = gt_boxes[i].int().unbind()
189
+ x,y, _ = get_bezier_curve(coords, rad=rad, edgy=edgy)
190
+ x = x.clip(0.0, 1.0)
191
+ y = y.clip(0.0, 1.0)
192
+ points = torch.from_numpy(np.concatenate((y[None,]*(y2-y1-1).item(),x[None,]*(x2-x1-1).item()))).int()
193
+ canvas = torch.zeros((y2-y1, x2-x1))
194
+ canvas[points.long().tolist()] = 1
195
+ rand_mask = torch.zeros(fp[i].shape)
196
+ rand_mask[y1:y2,x1:x2] = canvas
197
+ next_mask += [rand_mask]
198
+
199
+ next_mask = torch.stack(next_mask).to(pred_masks.device).bool()
200
+ rand_shapes = prev_masks | next_mask
201
+
202
+ types = ['polygon' for i in range(len(gt_masks))]
203
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
204
+
205
+ def forward_box(self, instances, pred_masks=None, prev_masks=None):
206
+ gt_masks = instances.gt_masks.tensor
207
+ gt_boxes = instances.gt_boxes.tensor
208
+ n,h,w = gt_masks.shape
209
+
210
+ for i in range(len(gt_masks)):
211
+ x1,y1,x2,y2 = gt_boxes[i].int().unbind()
212
+ gt_masks[i,y1:y2,x1:x2] = 1
213
+
214
+ # We only consider positive points
215
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
216
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
217
+
218
+ if not gt_masks.is_cuda:
219
+ gt_masks = gt_masks.to(pred_masks.device)
220
+
221
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
222
+
223
+ # conv implementation
224
+ mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
225
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
226
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
227
+ next_mask = next_mask.view(n,-1)
228
+
229
+ next_mask[max_xy_idx] = True
230
+ next_mask = next_mask.reshape((n,h,w)).float()
231
+ next_mask = F.conv2d(next_mask[None,], self.dilation_kernel.repeat(len(next_mask),1,1,1), padding=self.dilation//2, groups=len(next_mask))[0] > 0
232
+ # end conv implementation
233
+
234
+ rand_shapes = prev_masks | next_mask
235
+
236
+ types = ['box' for i in range(len(gt_masks))]
237
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
238
+
239
+ def forward(self, instances, *args, **kwargs):
240
+ if self.mask_mode == 'Point':
241
+ return self.forward_point(instances, *args, **kwargs)
242
+ elif self.mask_mode == 'Circle':
243
+ return self.forward_circle(instances, *args, **kwargs)
244
+ elif self.mask_mode == 'Scribble':
245
+ return self.forward_scribble(instances, *args, **kwargs)
246
+ elif self.mask_mode == 'Polygon':
247
+ return self.forward_polygon(instances, *args, **kwargs)
248
+ elif self.mask_mode == 'Box':
249
+ return self.forward_box(instances, *args, **kwargs)
250
+
251
+ def build_shape_sampler(cfg, **kwargs):
252
+ return SimpleClickSampler(cfg, **kwargs)  # ShapeSampler is not defined in this module
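A hedged construction sketch for evaluation mode. `from_config` allocates the dilation kernel on the current CUDA device, so this assumes a GPU is available; the config keys below are the ones read in `from_config`, with illustrative values.

```python
import torch
from datasets.visual_sampler.simpleclick_sampler import SimpleClickSampler

cfg = {'STROKE_SAMPLER': {'EVAL': {'NEGATIVE': False},
                          'DILATION': 3,
                          'POLYGON': {'MAX_POINTS': 8}}}

sampler = SimpleClickSampler(**SimpleClickSampler.from_config(cfg, is_train=False, mode='Point'))
# sampler(instances, pred_masks=..., prev_masks=...) routes to forward_point because
# mode='Point'; it expects detectron2 Instances whose gt_masks/gt_boxes are on the GPU.
```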
docker/Dockerfile ADDED
@@ -0,0 +1,32 @@
1
+ # FROM naotous/flash_attn:2.0.5-pytorch23.07
2
+ FROM wangkenpu/pytorch:1.8.0-py39-cuda11.1-cudnn8-ubuntu18.04
3
+
4
+ # RUN touch tensorboard_patcher.py && cp tensorboard_patcher.py $$USERSITE/usercustomize.py
5
+
6
+
7
+ # RUN pip install --upgrade pip
8
+
9
+ # RUN pip install -I torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
10
+ # RUN pip install -I torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --user
11
+ # RUN pip install kornia
12
+ # RUN pip install timm==0.4.12
13
+ # RUN python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
14
+ RUN pip install git+https://github.com/cocodataset/panopticapi.git
15
+ RUN pip install git+https://github.com/openai/CLIP.git
16
+
17
+ # RUN wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
18
+
19
+ COPY assets/requirements/requirements.txt /tmp/requirements.txt
20
+ RUN pip install -r /tmp/requirements.txt
21
+
22
+ COPY assets/requirements/requirements_custom.txt /tmp/requirements_custom.txt
23
+ RUN pip install -r /tmp/requirements_custom.txt
24
+
25
+ #RUN pip install -U protobuf
26
+
27
+ # Set environment variables
28
+ ENV MKL_THREADING_LAYER=GNU
29
+ ENV NCCL_DEBUG=INFO
30
+
31
+ # Set the working directory HERE!
32
+ WORKDIR /path/to/BiomedParse
docker/README.md ADDED
@@ -0,0 +1,9 @@
1
+ In the Dockerfile, set WORKDIR to the path of your BiomedParse repo.
2
+
3
+ From the project root directory:
4
+
5
+ Run `bash docker/docker_build.sh` to build the image.
6
+
7
+ Run `bash docker/docker_run.sh` to start the container.
8
+
9
+ Inside the Docker container, run `setup_inside_docker.sh`.
docker/data_env.sh ADDED
@@ -0,0 +1 @@
1
+ export HANOVER_DATASETS=biomedparse_datasets/ # Path to the datasets
docker/docker_build.sh ADDED
@@ -0,0 +1 @@
1
+ docker build -f docker/Dockerfile -t seem .
docker/docker_run.sh ADDED
@@ -0,0 +1 @@
1
+ docker run -it --gpus all --shm-size=128G -v /mnt:/mnt -v $(pwd):/workspace -w /workspace seem
docker/setup_inside_docker.sh ADDED
@@ -0,0 +1,10 @@
1
+ # Customer Operator [only need training deformable vision encoder]
2
+ cd modeling/vision/encoder/ops && sh make.sh && cd ../../../../
3
+
4
+ # System Package [only need for demo in SEEM]
5
+ sudo apt update
6
+ sudo apt install ffmpeg
7
+
8
+ #pip install gradio==3.44.4
9
+ #pip install openai-whisper
10
+ #pip install protobuf==3.20.*
entry.py ADDED
@@ -0,0 +1,92 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Modified by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import sys
10
+ import torch
11
+ import logging
12
+ #import wandb
13
+ import random
14
+ import numpy as np
15
+
16
+ from utilities.arguments import load_opt_command
17
+
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # def init_wandb(args, job_dir, entity='YOUR_USER_NAME', project='YOUR_PROJECT_NAME', job_name='tmp'):
22
+ # wandb_dir = os.path.join(job_dir, 'wandb')
23
+ # os.makedirs(wandb_dir, exist_ok=True)
24
+ # runid = None
25
+ # if os.path.exists(f"{wandb_dir}/runid.txt"):
26
+ # runid = open(f"{wandb_dir}/runid.txt").read()
27
+
28
+ # wandb.init(project=project,
29
+ # name=job_name,
30
+ # dir=wandb_dir,
31
+ # entity=entity,
32
+ # resume="allow",
33
+ # id=runid,
34
+ # config={"hierarchical": True},)
35
+
36
+ # open(f"{wandb_dir}/runid.txt", 'w').write(wandb.run.id)
37
+ # wandb.config.update({k: args[k] for k in args if k not in wandb.config})
38
+
39
+ def set_seed(seed: int = 42) -> None:
40
+ np.random.seed(seed)
41
+ random.seed(seed)
42
+ torch.manual_seed(seed)
43
+ torch.cuda.manual_seed(seed)
44
+ # When running on the CuDNN backend, two further options must be set
45
+ torch.backends.cudnn.deterministic = True
46
+ torch.backends.cudnn.benchmark = False
47
+ # Set a fixed value for the hash seed
48
+ os.environ["PYTHONHASHSEED"] = str(seed)
49
+ print(f"Random seed set as {seed}")
50
+
51
+ def main(args=None):
52
+ '''
53
+ [Main function for the entry point]
54
+ 1. Set environment variables for distributed training.
55
+ 2. Load the config file and set up the trainer.
56
+ '''
57
+
58
+ opt, cmdline_args = load_opt_command(args)
59
+ command = cmdline_args.command
60
+
61
+ if cmdline_args.user_dir:
62
+ absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
63
+ opt['base_path'] = absolute_user_dir
64
+
65
+ # update_opt(opt, command)
66
+ world_size = 1
67
+ if 'OMPI_COMM_WORLD_SIZE' in os.environ:
68
+ world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
69
+
70
+ if opt['TRAINER'] == 'xdecoder':
71
+ from trainer import XDecoder_Trainer as Trainer
72
+ else:
73
+ assert False, "The trainer type: {} is not defined!".format(opt['TRAINER'])
74
+
75
+ set_seed(opt['RANDOM_SEED'])
76
+
77
+ trainer = Trainer(opt)
78
+ os.environ['TORCH_DISTRIBUTED_DEBUG']='DETAIL'
79
+
80
+ if command == "train":
81
+ # if opt['rank'] == 0 and opt['WANDB']:
82
+ # wandb.login(key=os.environ['WANDB_KEY'])
83
+ # init_wandb(opt, trainer.save_folder, job_name=trainer.save_folder)
84
+ trainer.train()
85
+ elif command == "evaluate":
86
+ trainer.eval()
87
+ else:
88
+ raise ValueError(f"Unknown command: {command}")
89
+
90
+ if __name__ == "__main__":
91
+ main()
92
+ sys.exit(0)
environment.yml ADDED
@@ -0,0 +1,149 @@
1
+ # name: biomedparse
2
+ # channels:
3
+ # - pytorch
4
+ # - nvidia
5
+ # - defaults
6
+ # dependencies:
7
+ # - _libgcc_mutex=0.1=main
8
+ # - _openmp_mutex=5.1=1_gnu
9
+ # - blas=1.0=mkl
10
+ # - brotli-python=1.0.9=py39h6a678d5_8
11
+ # - bzip2=1.0.8=h5eee18b_6
12
+ # - ca-certificates=2024.7.2=h06a4308_0
13
+ # - certifi=2024.7.4=py39h06a4308_0
14
+ # - charset-normalizer=3.3.2=pyhd3eb1b0_0
15
+ # - cuda-cudart=12.4.127=0
16
+ # - cuda-cupti=12.4.127=0
17
+ # - cuda-libraries=12.4.0=0
18
+ # - cuda-nvrtc=12.4.127=0
19
+ # - cuda-nvtx=12.4.127=0
20
+ # - cuda-opencl=12.6.37=0
21
+ # - cuda-runtime=12.4.0=0
22
+ # - cuda-version=12.6=3
23
+ # - ffmpeg=4.3=hf484d3e_0
24
+ # - filelock=3.13.1=py39h06a4308_0
25
+ # - freetype=2.12.1=h4a9f257_0
26
+ # - gmp=6.2.1=h295c915_3
27
+ # - gmpy2=2.1.2=py39heeb90bb_0
28
+ # - gnutls=3.6.15=he1e5248_0
29
+ # - idna=3.7=py39h06a4308_0
30
+ # - intel-openmp=2023.1.0=hdb19cb5_46306
31
+ # - jinja2=3.1.4=py39h06a4308_0
32
+ # - jpeg=9e=h5eee18b_3
33
+ # - lame=3.100=h7b6447c_0
34
+ # - lcms2=2.12=h3be6417_0
35
+ # - ld_impl_linux-64=2.38=h1181459_1
36
+ # - lerc=3.0=h295c915_0
37
+ # - libcublas=12.4.2.65=0
38
+ # - libcufft=11.2.0.44=0
39
+ # - libcufile=1.11.0.15=0
40
+ # - libcurand=10.3.7.37=0
41
+ # - libcusolver=11.6.0.99=0
42
+ # - libcusparse=12.3.0.142=0
43
+ # - libdeflate=1.17=h5eee18b_1
44
+ # - libffi=3.4.4=h6a678d5_1
45
+ # - libgcc-ng=11.2.0=h1234567_1
46
+ # - libgomp=11.2.0=h1234567_1
47
+ # - libiconv=1.16=h5eee18b_3
48
+ # - libidn2=2.3.4=h5eee18b_0
49
+ # - libjpeg-turbo=2.0.0=h9bf148f_0
50
+ # - libnpp=12.2.5.2=0
51
+ # - libnvfatbin=12.6.20=0
52
+ # - libnvjitlink=12.4.99=0
53
+ # - libnvjpeg=12.3.1.89=0
54
+ # - libpng=1.6.39=h5eee18b_0
55
+ # - libstdcxx-ng=11.2.0=h1234567_1
56
+ # - libtasn1=4.19.0=h5eee18b_0
57
+ # - libtiff=4.5.1=h6a678d5_0
58
+ # - libunistring=0.9.10=h27cfd23_0
59
+ # - libwebp-base=1.3.2=h5eee18b_0
60
+ # - llvm-openmp=14.0.6=h9e868ea_0
61
+ # - lz4-c=1.9.4=h6a678d5_1
62
+ # - markupsafe=2.1.3=py39h5eee18b_0
63
+ # - mkl=2023.1.0=h213fc3f_46344
64
+ # - mkl-service=2.4.0=py39h5eee18b_1
65
+ # - mkl_fft=1.3.8=py39h5eee18b_0
66
+ # - mkl_random=1.2.4=py39hdb19cb5_0
67
+ # - mpc=1.1.0=h10f8cd9_1
68
+ # - mpfr=4.0.2=hb69a4c5_1
69
+ # - mpmath=1.3.0=py39h06a4308_0
70
+ # - ncurses=6.4=h6a678d5_0
71
+ # - nettle=3.7.3=hbbd107a_1
72
+ # - networkx=3.2.1=py39h06a4308_0
73
+ # - openh264=2.1.1=h4ff587b_0
74
+ # - openjpeg=2.5.2=he7f1fd0_0
75
+ # - openssl=3.0.14=h5eee18b_0
76
+ # - pip=24.2=py39h06a4308_0
77
+ # - pysocks=1.7.1=py39h06a4308_0
78
+ # - python=3.9.19=h955ad1f_1
79
+ # - pytorch=2.4.0=py3.9_cuda12.4_cudnn9.1.0_0
80
+ # - pytorch-cuda=12.4=hc786d27_6
81
+ # - pytorch-mutex=1.0=cuda
82
+ # - pyyaml=6.0.1=py39h5eee18b_0
83
+ # - readline=8.2=h5eee18b_0
84
+ # - requests=2.32.3=py39h06a4308_0
85
+ # - setuptools=72.1.0=py39h06a4308_0
86
+ # - sqlite=3.45.3=h5eee18b_0
87
+ # - sympy=1.12=py39h06a4308_0
88
+ # - tbb=2021.8.0=hdb19cb5_0
89
+ # - tk=8.6.14=h39e8969_0
90
+ # - torchaudio=2.4.0=py39_cu124
91
+ # - torchtriton=3.0.0=py39
92
+ # - torchvision=0.19.0=py39_cu124
93
+ # - typing_extensions=4.11.0=py39h06a4308_0
94
+ # - tzdata=2024a=h04d1e81_0
95
+ # - urllib3=2.2.2=py39h06a4308_0
96
+ # - wheel=0.43.0=py39h06a4308_0
97
+ # - xz=5.4.6=h5eee18b_1
98
+ # - yaml=0.2.5=h7b6447c_0
99
+ # - zlib=1.2.13=h5eee18b_1
100
+ # - zstd=1.5.5=hc292b87_2
101
+ # - pip:
102
+ # - accelerate==0.23.0
103
+ # - antlr4-python3-runtime==4.9.3
104
+ # - appdirs==1.4.4
105
+ # - black==21.4b2
106
+ # - open-clip-torch==2.26.1
107
+ # - cloudpickle==3.0.0
108
+ # - cython==3.0.2
109
+ # # - deepspeed==0.10.3
110
+ # - git+https://github.com/MaureenZOU/detectron2-xyz.git
111
+ # - diffdist==0.1
112
+ # - einops==0.8.0
113
+ # - ftfy==6.1.1
114
+ # - fvcore==0.1.5.post20221221
115
+ # - hjson==3.1.0
116
+ # - huggingface-hub==0.17.3
117
+ # - hydra-core==1.3.2
118
+ # - imageio==2.35.1
119
+ # - infinibatch==0.1.1
120
+ # - iopath==0.1.9
121
+ # - json-tricks==3.17.3
122
+ # - kornia==0.7.0
123
+ # - mpi4py==3.1.5
124
+ # - mup==1.0.0
125
+ # - mypy-extensions==1.0.0
126
+ # - ninja==1.11.1.1
127
+ # - nltk==3.8.1
128
+ # - numpy==1.23.1
129
+ # - omegaconf==2.3.0
130
+ # - opencv-python==4.8.1.78
131
+ # - pandas==2.0.3
132
+ # - pathspec==0.12.1
133
+ # - pillow==9.4.0
134
+ # - portalocker==2.10.1
135
+ # - py-cpuinfo==9.0.0
136
+ # - pycocotools==2.0.7
137
+ # - pydantic==1.10.18
138
+ # - pydot==3.0.1
139
+ # - regex==2023.10.3
140
+ # - scikit-image==0.21.0
141
+ # - scikit-learn==1.3.1
142
+ # - sentencepiece==0.1.99
143
+ # - tabulate==0.9.0
144
+ # - termcolor==2.4.0
145
+ # - timm==0.4.12
146
+ # - tokenizers==0.14.1
147
+ # - transformers==4.34.0
148
+ # - vision-datasets==0.2.2
149
+ # - yacs==0.1.8
figures/main_figure_1a.py ADDED
@@ -0,0 +1,99 @@
1
+ #%%
2
+ import os
3
+ import json
4
+ import numpy as np
5
+ import seaborn as sns
6
+ from scipy.stats import boxcox
7
+ from pycirclize import Circos
8
+ import matplotlib.pyplot as plt
9
+
10
+ base_dir = 'metadata'
11
+ with open(os.path.join(base_dir,'hierarchy.json'), 'r') as f:
12
+ hierarchy_data = json.load(f)
13
+
14
+ with open(os.path.join(base_dir,'target_counts.json'), 'r') as f:
15
+ target_counts = json.load(f)
16
+
17
+ with open(os.path.join(base_dir,'modality_counts.json'), 'r') as f:
18
+ modality_counts = json.load(f)
19
+
20
+ # color scheme
21
+ sectors = {k: 0 for k in hierarchy_data.keys()}
22
+ for sector_name in hierarchy_data:
23
+ for k,v in hierarchy_data[sector_name]['child'].items():
24
+ sectors[sector_name] += len(v['child'])
25
+ sectors[sector_name] += 1
26
+
27
+ name2color = {"organ": "#E41A1C", "abnormality": "#377EB8", "histology": "#4DAF4A"}
28
+
29
+ def generate_shades(base_color, n):
30
+ return sns.light_palette(base_color, n + 2)[1:-1]
31
+
32
+ color_schemes = {}
33
+ for sector in sectors:
34
+ child_colors = generate_shades(name2color[sector], len(hierarchy_data[sector]['child']))
35
+ color_schemes[sector] = child_colors
36
+
37
+ parent_track_ratio = (72, 85)
38
+ middle_track_ratio = (85, 100)
39
+ bar_track_ratio = (45, 70)
40
+ parent_track_font_size = 7
41
+ middle_track_font_size = 5.5
42
+ bar_track_font_size = 7
43
+ outer_track_font_size = 9
44
+
45
+ circos = Circos(sectors, space=8.8)
46
+ for sector in circos.sectors:
47
+ idx2label = {}
48
+ idx = 1
49
+ for k,v in hierarchy_data[sector.name.lower()]['child'].items():
50
+ for k1,v1 in v['child'].items():
51
+ idx2label[idx] = k1
52
+ idx += 1
53
+ idx2label[idx] = ''
54
+ idx2label[0] = ''
55
+
56
+ track_outer = sector.add_track((100, 101))
57
+ track_outer.xticks_by_interval(
58
+ 1,
59
+ tick_length=0,
60
+ outer=True,
61
+ show_bottom_line=False,
62
+ label_orientation="vertical",
63
+ label_formatter=lambda v: idx2label[int(v)],
64
+ label_size=outer_track_font_size,
65
+ show_endlabel=True
66
+ )
67
+
68
+ track = sector.add_track(parent_track_ratio)
69
+ track.axis(fc=name2color[sector.name], lw=0)
70
+ track.text(sector.name.capitalize().replace('Mri', 'MRI').replace('Ct', 'CT').replace('Oct', 'OCT').replace('Dermoscopy', "DS"), color="white", size=parent_track_font_size)
71
+
72
+ track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
73
+ sect_start = 0
74
+ color_idx = 0
75
+ for i, (k,v) in enumerate(hierarchy_data[sector.name.lower()]['child'].items()):
76
+ sect_size = len(v['child']) if i != len(hierarchy_data[sector.name.lower()]['child'])-1 else len(v['child'])+1
77
+ if i == 0:
78
+ sect_size += 0.5
79
+ if i == len(hierarchy_data[sector.name.lower()]['child'])-1:
80
+ sect_size -= 0.5
81
+ track1.rect(sect_start, sect_start+sect_size, r_lim=(middle_track_ratio[0], middle_track_ratio[1]-1), ec="black", lw=0,fc=color_schemes[sector.name][color_idx])
82
+ color_idx += 1
83
+ track1.text(k.replace('abnormality', 'abn.').replace(' anatomies', '').replace(' disturbance', '').replace('other abn.', 'Other').replace('liver', '').replace('pancreas', '').capitalize(), sect_start+sect_size/2, color="black", size=middle_track_font_size)
84
+ sect_start += sect_size
85
+
86
+ x = np.linspace(sector.start+1 , sector.end-1 , int(sector.size)-1)
87
+ y = [target_counts[idx2label[i+1]] for i in range(0,len(x))]
88
+ y_box = boxcox(y, 0.35)
89
+
90
+ track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
91
+ track2.axis()
92
+ track2.yticks([1.14, 2.29, 3.43, 4.58], ["10$^2$", "10$^3$", "10$^4$", "10$^5$"], label_size=bar_track_font_size-1)
93
+ track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)
94
+
95
+ fig = circos.plotfig()
96
+ fig.savefig('plots/figure_1a.pdf')
97
+ plt.show()
98
+
99
+ # %%
figures/main_figure_1b.py ADDED
@@ -0,0 +1,101 @@
1
+ #%%
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import pandas as pd
5
+ import json, os
6
+ import seaborn as sns
7
+
8
+ plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
9
+
10
+ # Load data
11
+ def load_data(file_path):
12
+ with open(file_path, 'r') as f:
13
+ return json.load(f)
14
+ base_dir = 'metadata'
15
+ data = load_data(os.path.join(base_dir, 'modality_counts.json'))
16
+ separate_submodality = False
17
+
18
+ # Transform data for plotting
19
+ def transform_data(data):
20
+ df = pd.DataFrame([(modality, subcat, count) for modality, subcats in data.items() for subcat, count in subcats.items()], columns=['Modality', 'Sub-category', 'Count'])
21
+ return df
22
+
23
+ df = transform_data(data)
24
+
25
+ # Calculate total counts by modality and sort
26
+ def calculate_totals(df):
27
+ total_counts_by_modality = df.groupby("Modality")["Count"].sum().sort_values(ascending=True)
28
+ sorted_modalities = total_counts_by_modality.index.tolist()
29
+ return total_counts_by_modality, sorted_modalities
30
+
31
+ total_counts_by_modality, sorted_modalities = calculate_totals(df)
32
+
33
+ # Generate color map
34
+ def generate_color_map(total_counts_by_modality):
35
+ base_colors = plt.cm.cool(np.linspace(0, 1, len(total_counts_by_modality)))
36
+ modality_color_map = {modality: base_colors[i] for i, modality in enumerate(total_counts_by_modality.index)}
37
+ return modality_color_map
38
+
39
+ modality_color_map = generate_color_map(total_counts_by_modality)
40
+
41
+ # Format total count for display
42
+ def format_total_count(total_count):
43
+ if total_count >= 1000:
44
+ exponent = int(np.floor(np.log10(total_count)))
45
+ mantissa = total_count / 10**exponent
46
+ formatted_total = f'{mantissa:.2f} x 10$^{exponent}$'
47
+ else:
48
+ exponent = 0
49
+ formatted_total = str(total_count)
50
+ return formatted_total, exponent
51
+
52
+ # Plotting function
53
+ def plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality):
54
+ fig, ax = plt.subplots(figsize=(10, 12))
55
+ current_bottom = np.zeros(len(sorted_modalities))
56
+ gap = 0.005 if separate_submodality else 0
57
+ shades = np.power(np.linspace(0.75, 1, df.groupby("Sub-category").ngroups), 2)
58
+
59
+ if separate_submodality:
60
+ for i, modality in enumerate(sorted_modalities):
61
+ subdf = df[df["Modality"] == modality].sort_values(by='Count', ascending=False)
62
+ for j, (index, row) in enumerate(subdf.iterrows()):
63
+ count = row['Count']
64
+ if count > 0:
65
+ color = np.array(modality_color_map[modality]) * shades[j % len(shades)]
66
+ ax.barh(modality, count, left=current_bottom[i], color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
67
+ current_bottom[i] += count + gap
68
+ current_bottom[i] -= gap
69
+ total_count = total_counts_by_modality[modality]
70
+ formatted_total, exponent = format_total_count(total_count)
71
+ ax.text(current_bottom[i] + (10**exponent)*0.05, i, formatted_total, va='center', fontsize=20, ha='left')
72
+ else:
73
+ for i, modality in enumerate(sorted_modalities):
74
+ total_count = total_counts_by_modality[modality]
75
+ color = np.array(modality_color_map[modality] * shades[0])
76
+ if modality.islower():
77
+ modality = modality.capitalize()
78
+ ax.barh(modality, total_count, color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
79
+ formatted_total, exponent = format_total_count(total_count)
80
+ ax.text(total_count + (10**exponent)*0.05, i, formatted_total, va='center', fontsize=20, ha='left')
81
+
82
+ configure_plot(ax, sorted_modalities)
83
+
84
+ plt.tight_layout()
85
+ plt.savefig("plots/data_dist_modality_bar_subbar.pdf" if separate_submodality else "plots/data_dist_modality_bar.pdf", bbox_inches="tight", pad_inches=0)
86
+ plt.show()
87
+
88
+ # Configure plot aesthetics
89
+ def configure_plot(ax, sorted_modalities):
90
+ ax.set_xscale('log')
91
+ ax.set_title("Number of images per modality", fontsize=28)
92
+ plt.yticks(rotation=0, fontsize=24, va='center')
93
+ ax.tick_params(axis='x', which='major', length=8)
94
+ ax.tick_params(axis='x', which='minor', length=5)
95
+ plt.xticks(fontsize=24)
96
+ sns.despine()
97
+
98
+ # Main script execution
99
+ plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality)
100
+
101
+ # %%
figures/main_figure_2a.py ADDED
@@ -0,0 +1,93 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ import json, os
6
+
7
+ from statannot import add_stat_annotation
8
+ from statannotations.Annotator import Annotator
9
+
10
+ df = pd.read_csv('results/all_eval/all_metrics_median.csv')
11
+
12
+
13
+ metric = 'dice'
14
+
15
+ model_names = {metric: 'BiomedParse', f'medsam_{metric}': 'MedSAM (oracle box)', f'sam_{metric}': 'SAM (oracle box)',
16
+ f'dino_medsam_{metric}': 'MedSAM (Grounding DINO)', f'dino_sam_{metric}': 'SAM (Grounding DINO)'}
17
+ df = df.rename(columns=model_names)
18
+
19
+ score_vars = list(model_names.values())
20
+
21
+
22
+ modality_list = ['CT', 'MRI', 'X-Ray', 'Pathology', 'Ultrasound', 'Fundus', 'Endoscope', 'Dermoscopy', 'OCT']
23
+ # modify modality names
24
+ mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
25
+ 'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
26
+ df['modality'] = df['modality'].apply(lambda x: mod_names[x])
27
+
28
+ # add an "All" modality
29
+ all_df = df.copy()
30
+ all_df['modality'] = 'All'
31
+ df = pd.concat([df, all_df])
32
+
33
+ df_long = df[['modality', 'task']+score_vars].melt(id_vars=['modality', 'task'], var_name='Model', value_name='Performance')
34
+
35
+
36
+
37
+ # add statistical annotations
38
+ fig, ax = plt.subplots(figsize=(9, 6))
39
+ ax = sns.boxplot(data=df_long, x='modality', y='Performance', hue='Model', ax=ax, palette='Set2',
40
+ order=['All']+modality_list,
41
+ whis=2, saturation=0.6, linewidth=0.8, fliersize=0.5) # whiskers at 5th and 95th percentile)
42
+ #errorbar='sd', capsize=0.1, errwidth=1.5)
43
+
44
+ # no frame
45
+ ax.spines['top'].set_visible(False)
46
+ ax.spines['right'].set_visible(False)
47
+ ax.spines['left'].set_visible(False)
48
+ # add arrow on y axis
49
+ ax.annotate('', xy=(0, 1.05), xytext=(0, -0.01), arrowprops=dict(arrowstyle='->', lw=1, color='black'), xycoords='axes fraction')
50
+
51
+
52
+ plt.title('')
53
+ if metric == 'dice':
54
+ plt.ylabel('Dice score', fontsize=18)
55
+ elif metric == 'assd':
56
+ plt.ylabel('ASSD', fontsize=18)
57
+ plt.xlabel('')
58
+ plt.xticks(rotation=45, fontsize=16)
59
+ plt.yticks(fontsize=14)
60
+
61
+ # axis thickness
62
+ ax.spines['bottom'].set_linewidth(1)
63
+ ax.spines['left'].set_linewidth(1)
64
+
65
+
66
+ # change to log scale
67
+ if metric == 'assd':
68
+ plt.yscale('log')
69
+
70
+ # set legend names
71
+ ax.legend(score_vars, fontsize=14)
72
+
73
+ # legend on top in a row, without frame
74
+ plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.4), ncol=2, fontsize=14, frameon=False)
75
+
76
+ # Define pairs between models for each modality
77
+ box_pairs = []
78
+
79
+ # Add statistical annotations for each modality
80
+ for modality in ['All']+modality_list:
81
+ # Define pairs between models within the same modality
82
+ box_pairs += [((modality, 'BiomedParse'), (modality, 'MedSAM (oracle box)'))]
83
+ annotator = Annotator(ax, box_pairs, data=df_long, x='modality', y='Performance', hue='Model',
84
+ order=['All']+modality_list)
85
+ annotator.configure(test='t-test_paired', text_format='star', loc='inside', hide_non_significant=True)
86
+ annotator.apply_test(alternative='less')
87
+ annotator.annotate()
88
+
89
+ plt.tight_layout()
90
+
91
+ # save the plot
92
+ ax.get_figure().savefig(f'plots/{metric}_comparison.png')
93
+ ax.get_figure().savefig(f'plots/{metric}_comparison.pdf', bbox_inches='tight')
figures/main_figure_3b.py ADDED
@@ -0,0 +1,83 @@
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import json, os
+
+ from statannot import add_stat_annotation
+ from statannotations.Annotator import Annotator
+
+ df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
+
+ # modify modality names
+ mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+              'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+ df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+ # MedSAM reported tasks
+ reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
+
+ # find overlap between the dfs by dataset and target
+ overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
+                       suffixes=('_biomedparse', '_baseline'))
+ # non-overlapping datasets
+ non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
+
+
+
+ baseline = 'medsam'
+ metric = 'box_ratio'
+
+ baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
+ metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
+                 'IRI': 'Inversed Rotational Inertia'}
+
+ non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
+ # scatter plot
+ fig, ax = plt.subplots(figsize=(7,5))
+ sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
+
+ # add linear regression line
+ sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
+             color='k', line_kws={'linestyle':'--', 'linewidth':1})
+
+ # remove all spines
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['left'].set_visible(False)
+ ax.spines['bottom'].set_visible(False)
+
+
+ # add arrow on x-axis and y-axis
+ xlim = [0, 0.85]
+ ylim = [-0.18, 0.75]
+ ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ ax.set_xlim(xlim)
+ ax.set_ylim(ylim)
+
+ ax.xaxis.set_tick_params(width=1.5)
+ ax.yaxis.set_tick_params(width=1.5)
+
+ # set x-ticks and y-ticks
+ plt.xticks(fontsize=18)
+ plt.yticks(fontsize=18)
+
+ # show R^2 value, p value, and equation of the line
+ from scipy.stats import linregress
+ slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
+ x_text = 0.4
+ plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
+ plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
+ plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
+
+ plt.title('')
+ plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
+ plt.xlabel(f'{metric_names[metric]}', fontsize=22)
+
+ plt.tight_layout()
+
+ # save the plot
+ ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
+ ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
+
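The x-axis of this panel is the box ratio column read from results/all_eval/all_metrics_mean.csv; the metric itself is computed upstream and is not shown in this script. Purely as an illustration — assuming "box ratio" means foreground area divided by the area of the tight axis-aligned bounding box, which may not match the definition used to produce the CSV — it could be computed from a binary mask like this:

```python
import numpy as np

def box_ratio(mask: np.ndarray) -> float:
    """Illustrative box ratio: foreground area / tight bounding-box area.
    Assumed definition; the values in all_metrics_mean.csv may be computed differently."""
    mask = mask.astype(bool)
    if not mask.any():
        return 0.0
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    box_area = (rmax - rmin + 1) * (cmax - cmin + 1)
    return float(mask.sum()) / float(box_area)

# A filled circle inside its bounding square scores close to pi/4 ~= 0.785.
yy, xx = np.mgrid[:101, :101]
circle = (yy - 50) ** 2 + (xx - 50) ** 2 <= 40 ** 2
print(round(box_ratio(circle), 3))
```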
figures/main_figure_3c.py ADDED
@@ -0,0 +1,83 @@
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import json, os
+
+ from statannot import add_stat_annotation
+ from statannotations.Annotator import Annotator
+
+ df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
+
+ # modify modality names
+ mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+              'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+ df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+ # MedSAM reported tasks
+ reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
+
+ # find overlap between the dfs by dataset and target
+ overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
+                       suffixes=('_biomedparse', '_baseline'))
+ # non-overlapping datasets
+ non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
+
+
+
+ baseline = 'medsam'
+ metric = 'convex_ratio'
+
+ baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
+ metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
+                 'IRI': 'Inversed Rotational Inertia'}
+
+ non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
+ # scatter plot
+ fig, ax = plt.subplots(figsize=(7,5))
+ sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
+
+ # add linear regression line
+ sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
+             color='k', line_kws={'linestyle':'--', 'linewidth':1})
+
+ # remove all spines
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['left'].set_visible(False)
+ ax.spines['bottom'].set_visible(False)
+
+
+ # add arrow on x-axis and y-axis
+ xlim = [0, 1.05]
+ ylim = [-0.18, 0.75]
+ ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ ax.set_xlim(xlim)
+ ax.set_ylim(ylim)
+
+ ax.xaxis.set_tick_params(width=1.5)
+ ax.yaxis.set_tick_params(width=1.5)
+
+ # set x-ticks and y-ticks
+ plt.xticks(fontsize=18)
+ plt.yticks(fontsize=18)
+
+ # show R^2 value, p value, and equation of the line
+ from scipy.stats import linregress
+ slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
+ x_text = 0.4
+ plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
+ plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
+ plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
+
+ plt.title('')
+ plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
+ plt.xlabel(f'{metric_names[metric]}', fontsize=22)
+
+ plt.tight_layout()
+ plt.show()
+
+ # save the plot
+ ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
+ ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
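This panel repeats main_figure_3b.py with metric = 'convex_ratio' and a wider x-range. Again the metric comes precomputed from the CSV; one plausible, assumed definition is foreground area over convex-hull area, sketched below with scikit-image — the actual computation behind the CSV may differ:

```python
import numpy as np
from skimage.morphology import convex_hull_image

def convex_ratio(mask: np.ndarray) -> float:
    """Illustrative convex ratio: foreground area / convex-hull area.
    Assumed definition, not necessarily the one used for all_metrics_mean.csv."""
    mask = mask.astype(bool)
    if not mask.any():
        return 0.0
    hull = convex_hull_image(mask)
    return float(mask.sum()) / float(hull.sum())

# Convex shapes score ~1.0; concave or branching masks score well below 1.
```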
figures/main_figure_3d.py ADDED
@@ -0,0 +1,83 @@
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import json, os
+
+ from statannot import add_stat_annotation
+ from statannotations.Annotator import Annotator
+
+ df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
+
+ # modify modality names
+ mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+              'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+ df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+ # MedSAM reported tasks
+ reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
+
+ # find overlap between the dfs by dataset and target
+ overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
+                       suffixes=('_biomedparse', '_baseline'))
+ # non-overlapping datasets
+ non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
+
+
+
+ baseline = 'medsam'
+ metric = 'IRI'
+
+ baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
+ metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
+                 'IRI': 'Inversed Rotational Inertia'}
+
+ non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
+ # scatter plot
+ fig, ax = plt.subplots(figsize=(7,5))
+ sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
+
+ # add linear regression line
+ sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
+             color='k', line_kws={'linestyle':'--', 'linewidth':1})
+
+ # remove all spines
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['left'].set_visible(False)
+ ax.spines['bottom'].set_visible(False)
+
+
+ # add arrow on x-axis and y-axis
+ xlim = [0, 1.05]
+ ylim = [-0.18, 0.75]
+ ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ ax.set_xlim(xlim)
+ ax.set_ylim(ylim)
+
+ ax.xaxis.set_tick_params(width=1.5)
+ ax.yaxis.set_tick_params(width=1.5)
+
+ # set x-ticks and y-ticks
+ plt.xticks(fontsize=18)
+ plt.yticks(fontsize=18)
+
+ # show R^2 value, p value, and equation of the line
+ from scipy.stats import linregress
+ slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
+ x_text = 0.4
+ plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
+ plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
+ plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
+
+ plt.title('')
+ plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
+ plt.xlabel(f'{metric_names[metric]}', fontsize=22)
+
+ plt.tight_layout()
+ plt.show()
+
+ # save the plot
+ ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
+ ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
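main_figure_3b.py, main_figure_3c.py and main_figure_3d.py are identical except for the metric variable ('box_ratio', 'convex_ratio', 'IRI') and the x-axis range. If that duplication becomes a maintenance burden, the shared body could be folded into one helper; a sketch only, assuming the CSV columns are named exactly as in the scripts above:

```python
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import linregress

BASELINE_NAMES = {'medsam': 'MedSAM', 'sam': 'SAM'}
METRIC_NAMES = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
                'IRI': 'Inversed Rotational Inertia'}

def plot_metric_vs_improvement(df, metric, baseline='medsam', out_dir='plots'):
    """Scatter of (dice - baseline dice) against a shape metric with a fitted
    regression line, mirroring main_figure_3b/3c/3d. Sketch only."""
    df = df.copy()
    df['diff'] = df['dice'] - df[f'{baseline}_dice']

    fig, ax = plt.subplots(figsize=(7, 5))
    sns.scatterplot(data=df, x=metric, y='diff', ax=ax, s=80)
    sns.regplot(data=df, x=metric, y='diff', ax=ax, scatter=False,
                color='k', line_kws={'linestyle': '--', 'linewidth': 1})

    slope, intercept, r_value, p_value, _ = linregress(df[metric], df['diff'])
    ax.text(0.4, 0.84, f'$R^2={r_value**2:.2f}$', transform=ax.transAxes, fontsize=20)
    ax.text(0.4, 0.77, f'$p={p_value:.2e}$', transform=ax.transAxes, fontsize=20)

    ax.set_xlabel(METRIC_NAMES[metric], fontsize=22)
    ax.set_ylabel(f'Improvement over {BASELINE_NAMES[baseline]}', fontsize=20)
    fig.tight_layout()
    fig.savefig(f'{out_dir}/{metric}_mean_improvement_{baseline}.png')
    return fig

# Usage (with the non_overlap_df built as in the scripts above):
# plot_metric_vs_improvement(non_overlap_df, 'IRI', baseline='medsam')
```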
figures/plots/IRI_mean_improvement_medsam.pdf ADDED
Binary file (21.2 kB)
figures/plots/IRI_mean_improvement_medsam.png ADDED
figures/plots/IRI_mean_improvement_sam.pdf ADDED
Binary file (21.2 kB)