not working version
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +37 -7
- README.pdf +0 -0
- __init__.py +0 -0
- app.py +242 -0
- configs/biomed_seg_lang_v1.yaml +330 -0
- configs/biomedparse_inference.yaml +198 -0
- datasets/__init__.py +2 -0
- datasets/build.py +630 -0
- datasets/dataset_mappers/__init__.py +1 -0
- datasets/dataset_mappers/biomed_dataset_mapper.py +378 -0
- datasets/evaluation/__init__.py +8 -0
- datasets/evaluation/captioning_evaluation.py +129 -0
- datasets/evaluation/classification_evaluation.py +76 -0
- datasets/evaluation/grounding_evaluation.py +173 -0
- datasets/evaluation/instance_evaluation.py +107 -0
- datasets/evaluation/interactive_evaluation.py +122 -0
- datasets/evaluation/panoptic_evaluation.py +199 -0
- datasets/evaluation/retrieval_evaluation.py +260 -0
- datasets/evaluation/segmentation_evaluation.py +195 -0
- datasets/refer.py +371 -0
- datasets/registration/__init__.py +3 -0
- datasets/registration/register_biomed_datasets.py +123 -0
- datasets/semseg_loader.py +10 -0
- datasets/utils/refcoco2json.py +41 -0
- datasets/utils/refer.py +372 -0
- datasets/visual_sampler/__init__.py +12 -0
- datasets/visual_sampler/circle.py +106 -0
- datasets/visual_sampler/mask_generators.py +215 -0
- datasets/visual_sampler/point.py +74 -0
- datasets/visual_sampler/polygon.py +137 -0
- datasets/visual_sampler/sampler.py +77 -0
- datasets/visual_sampler/scribble.py +96 -0
- datasets/visual_sampler/simpleclick_sampler.py +252 -0
- docker/Dockerfile +32 -0
- docker/README.md +9 -0
- docker/data_env.sh +1 -0
- docker/docker_build.sh +1 -0
- docker/docker_run.sh +1 -0
- docker/setup_inside_docker.sh +10 -0
- entry.py +92 -0
- environment.yml +149 -0
- figures/main_figure_1a.py +99 -0
- figures/main_figure_1b.py +101 -0
- figures/main_figure_2a.py +93 -0
- figures/main_figure_3b.py +83 -0
- figures/main_figure_3c.py +83 -0
- figures/main_figure_3d.py +83 -0
- figures/plots/IRI_mean_improvement_medsam.pdf +0 -0
- figures/plots/IRI_mean_improvement_medsam.png +0 -0
- figures/plots/IRI_mean_improvement_sam.pdf +0 -0
README.md
CHANGED
@@ -1,14 +1,44 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.9.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: cc-by-nc-4.0
|
11 |
-
short_description: hakim ai by cbai
|
12 |
---
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: HakimAi
|
3 |
+
emoji: 🏥
|
4 |
+
colorFrom: blue
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: "5.9.0"
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
|
|
10 |
---
|
11 |
|
12 |
+
# HakimAi
|
13 |
+
|
14 |
+
A medical imaging analysis platform powered by BiomedParse, offering comprehensive biomedical image analysis across multiple modalities.
|
15 |
+
|
16 |
+
## Features
|
17 |
+
|
18 |
+
- **Multi-modal Analysis**: Support for various medical imaging types including X-ray, CT, MRI, pathology, and more
|
19 |
+
- **Advanced Detection**: Automated identification and segmentation of medical objects and conditions
|
20 |
+
- **Interactive Interface**: User-friendly Gradio interface for easy image upload and analysis
|
21 |
+
- **Powered by BiomedParse**: Utilizes Microsoft's BiomedParse foundation model for accurate medical image analysis
|
22 |
+
|
23 |
+
## Usage
|
24 |
+
|
25 |
+
1. Upload your medical image
|
26 |
+
2. Select the analysis type
|
27 |
+
3. View the results including segmentation masks and detection results
|
28 |
+
|
29 |
+
## Technical Details
|
30 |
+
|
31 |
+
This space uses:
|
32 |
+
- BiomedParse foundation model for medical image analysis
|
33 |
+
- Gradio for the web interface
|
34 |
+
- Git LFS for handling large files
|
35 |
+
- Python 3.9+ environment
|
36 |
+
|
37 |
+
## Model Information
|
38 |
+
|
39 |
+
Based on BiomedParse, capable of:
|
40 |
+
- Segmentation
|
41 |
+
- Detection
|
42 |
+
- Recognition across nine biomedical modalities
|
43 |
+
|
44 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
README.pdf
ADDED
Binary file (44.6 kB). View file
|
|
__init__.py
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
from typing import Tuple, Optional
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import sys
|
7 |
+
from pathlib import Path
|
8 |
+
import cv2
|
9 |
+
import gradio as gr
|
10 |
+
import numpy as np
|
11 |
+
import spaces
|
12 |
+
# import supervision as sv
|
13 |
+
import torch
|
14 |
+
from PIL import Image
|
15 |
+
from tqdm import tqdm
|
16 |
+
import sys
|
17 |
+
from pathlib import Path
|
18 |
+
from huggingface_hub import login
|
19 |
+
# from dotenv import load_dotenv
|
20 |
+
|
21 |
+
# For Hugging Face Spaces, secrets are automatically loaded as environment variables
|
22 |
+
token = os.getenv("HF_TOKEN")
|
23 |
+
if token:
|
24 |
+
login(token=token)
|
25 |
+
# Clear Hugging Face cache
|
26 |
+
# cache_dirs = [
|
27 |
+
# "/home/user/.cache/huggingface/",
|
28 |
+
# "/home/user/.cache/torch/",
|
29 |
+
# "/home/user/.cache/pip/"
|
30 |
+
# ]
|
31 |
+
|
32 |
+
# for cache_dir in cache_dirs:
|
33 |
+
# if os.path.exists(cache_dir):
|
34 |
+
# print(f"Clearing cache: {cache_dir}")
|
35 |
+
# shutil.rmtree(cache_dir, ignore_errors=True)
|
36 |
+
# Add the current directory to Python path
|
37 |
+
current_dir = Path(__file__).parent
|
38 |
+
sys.path.append(str(current_dir))
|
39 |
+
# sys.path.append("./BiomedParse/")
|
40 |
+
# BIOMEDPARSE_PATH = Path(__file__).parent / "BiomedParse"
|
41 |
+
# sys.path.append(str(BIOMEDPARSE_PATH))
|
42 |
+
# sys.path.append(str(BIOMEDPARSE_PATH / "BiomedParse")) # Add the inner BiomedParse directory
|
43 |
+
from modeling.BaseModel import BaseModel
|
44 |
+
from modeling import build_model
|
45 |
+
from utilities.arguments import load_opt_from_config_files
|
46 |
+
from utilities.constants import BIOMED_CLASSES
|
47 |
+
from inference_utils.inference import interactive_infer_image
|
48 |
+
from inference_utils.output_processing import check_mask_stats
|
49 |
+
from inference_utils.processing_utils import read_rgb
|
50 |
+
|
51 |
+
import spaces
|
52 |
+
|
53 |
+
# breakpoint()
|
54 |
+
MARKDOWN = """
|
55 |
+
# <div style="text-align: center; font-size: 2.5em;">ሀ<span style="color: #32CD32;">A</span>ኪ<span style="color: #FFD700;">i</span>ም <sup>AI</sup></div>
|
56 |
+
|
57 |
+
<div>
|
58 |
+
<a href="https://cyberbrainai.com/">
|
59 |
+
<img src="https://cyberbrainai.com/assets/logo.svg" alt="CyberBrain AI" style="display:inline-block; width:50px; height:50px;">
|
60 |
+
</a>
|
61 |
+
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
|
62 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="ድinቅneሽ" style="display:inline-block;">
|
63 |
+
</a>
|
64 |
+
<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
|
65 |
+
<img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
|
66 |
+
</a>
|
67 |
+
</div>
|
68 |
+
|
69 |
+
This demo integrates BiomedParse, a foundation model for joint segmentation, detection, and recognition across 9 biomedical imaging modalities. The model supports:
|
70 |
+
|
71 |
+
- Segmentation/Detection/Recognition across multiple modalities (CT, MRI, X-Ray, etc.)
|
72 |
+
- Text-prompted object detection
|
73 |
+
- Recognition of anatomical structures and abnormalities
|
74 |
+
|
75 |
+
|
76 |
+
"""
|
77 |
+
|
78 |
+
IMAGE_PROCESSING_EXAMPLES = [
|
79 |
+
["BiomedParse Segmentation",
|
80 |
+
"https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/T0011.jpg",
|
81 |
+
"Optic disc in retinal Fundus"],
|
82 |
+
["BiomedParse Segmentation",
|
83 |
+
"https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/Part_3_226_pathology_breast.png",
|
84 |
+
"optic disc, optic cup"],
|
85 |
+
["BiomedParse Segmentation",
|
86 |
+
"https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/covid_1585.png",
|
87 |
+
"COVID-19 infection in chest X-Ray"],
|
88 |
+
["BiomedParse Segmentation",
|
89 |
+
"https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/TCGA_HT_7856_19950831_8_MRI-FLAIR_brain.png",
|
90 |
+
"Lower-grade glioma in brain MRI"],
|
91 |
+
["BiomedParse Segmentation",
|
92 |
+
"https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/LIDC-IDRI-0140_143_280_CT_lung.png",
|
93 |
+
"COVID-19 infection in chest CT"],
|
94 |
+
["BiomedParse Segmentation",
|
95 |
+
"https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/144DME_as_F.jpeg",
|
96 |
+
"Cystoid macular edema in retinal OCT"],
|
97 |
+
["BiomedParse Segmentation",
|
98 |
+
"https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/Part_1_516_pathology_breast.png",
|
99 |
+
"Glandular structure in colon Pathology"],
|
100 |
+
["BiomedParse Segmentation",
|
101 |
+
"https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/ISIC_0015551.jpg",
|
102 |
+
"Melanoma in skin Dermoscopy"],
|
103 |
+
["BiomedParse Segmentation",
|
104 |
+
"https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/C3_EndoCV2021_00462.jpg",
|
105 |
+
"Neoplastic polyp in colon Endoscope"]
|
106 |
+
]
|
107 |
+
|
108 |
+
BIOMEDPARSE_MODES = {
|
109 |
+
"CT": ["abdomen", "colon", "liver", "lung", "pelvis"],
|
110 |
+
"MRI": ["brain", "heart", "prostate", "abdomen"],
|
111 |
+
"MRI-FLAIR": ["brain"],
|
112 |
+
"MRI-T1-Gd": ["brain"],
|
113 |
+
"MRI-T2": ["prostate"],
|
114 |
+
"OCT": ["retinal"],
|
115 |
+
"X-Ray": ["chest"],
|
116 |
+
"Dermoscopy": ["skin"],
|
117 |
+
"Endoscope": ["colon"],
|
118 |
+
"Fundus": ["retinal"],
|
119 |
+
"Pathology": ["bladder", "breast", "cervix", "colon", "esophagus", "kidney",
|
120 |
+
"liver", "ovarian", "prostate", "stomach", "testis", "thyroid", "uterus"],
|
121 |
+
"Ultrasound": ["breast", "heart", "transperineal"]
|
122 |
+
}
|
123 |
+
|
124 |
+
IMAGE_INFERENCE_MODES = [
|
125 |
+
"BIOMED SEGMENTATION",
|
126 |
+
"BIOMED DETECTION",
|
127 |
+
"BIOMED RECOGNITION",
|
128 |
+
"BIOMED SEGMENTATION + DETECTION",
|
129 |
+
"BIOMED SEGMENTATION + RECOGNITION",
|
130 |
+
"BIOMED DETECTION + RECOGNITION",
|
131 |
+
"BIOMED SEGMENTATION + DETECTION + RECOGNITION"
|
132 |
+
]
|
133 |
+
|
134 |
+
|
135 |
+
def on_mode_dropdown_change(selected_mode):
|
136 |
+
if selected_mode in IMAGE_INFERENCE_MODES:
|
137 |
+
# Show modality dropdown and hide other inputs initially
|
138 |
+
return [
|
139 |
+
gr.Dropdown(visible=True, choices=list(BIOMEDPARSE_MODES.keys()), label="Modality"),
|
140 |
+
gr.Dropdown(visible=True, label="Anatomical Site"),
|
141 |
+
gr.Textbox(visible=False),
|
142 |
+
gr.Textbox(visible=False)
|
143 |
+
]
|
144 |
+
else:
|
145 |
+
# Original behavior for other modes
|
146 |
+
return [
|
147 |
+
gr.Dropdown(visible=False),
|
148 |
+
gr.Dropdown(visible=False),
|
149 |
+
gr.Textbox(visible=True),
|
150 |
+
gr.Textbox(visible=(selected_mode == None))
|
151 |
+
]
|
152 |
+
|
153 |
+
def on_modality_change(modality):
|
154 |
+
if modality:
|
155 |
+
return gr.Dropdown(choices=BIOMEDPARSE_MODES[modality], visible=True)
|
156 |
+
return gr.Dropdown(visible=False)
|
157 |
+
|
158 |
+
|
159 |
+
def initialize_model():
|
160 |
+
opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
|
161 |
+
pretrained_pth = 'hf_hub:microsoft/BiomedParse'
|
162 |
+
opt['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
|
163 |
+
model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval()
|
164 |
+
with torch.no_grad():
|
165 |
+
model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
|
166 |
+
BIOMED_CLASSES + ["background"], is_eval=True
|
167 |
+
)
|
168 |
+
return model
|
169 |
+
|
170 |
+
|
171 |
+
model = initialize_model()
|
172 |
+
|
173 |
+
|
174 |
+
# Utility functions
|
175 |
+
@spaces.GPU
|
176 |
+
@torch.inference_mode()
|
177 |
+
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
178 |
+
def process_image(image_path, text_prompts, modality):
|
179 |
+
image = read_rgb(image_path)
|
180 |
+
text_prompts = [prompt.strip() for prompt in text_prompts.split(',')]
|
181 |
+
|
182 |
+
# Run inference
|
183 |
+
pred_masks = interactive_infer_image(model, Image.fromarray(image), text_prompts)
|
184 |
+
|
185 |
+
# Prepare outputs
|
186 |
+
results = []
|
187 |
+
dice_scores = []
|
188 |
+
p_values = []
|
189 |
+
|
190 |
+
for i, prompt in enumerate(text_prompts):
|
191 |
+
# Calculate p-value for the selected modality
|
192 |
+
print("PROMPT: ", prompt, flush=True)
|
193 |
+
p_value = check_mask_stats(image, pred_masks[i] * 255, modality, prompt)
|
194 |
+
p_values.append(f"P-value for '{prompt}' ({modality}): {p_value:.4f}")
|
195 |
+
|
196 |
+
# Overlay predictions on the image
|
197 |
+
overlay_image = image.copy()
|
198 |
+
overlay_image[pred_masks[i] > 0.5] = [255, 0, 0] # Highlight predictions in red
|
199 |
+
results.append(overlay_image)
|
200 |
+
|
201 |
+
return results, p_values
|
202 |
+
|
203 |
+
# Define Gradio interface
|
204 |
+
with gr.Blocks() as demo:
|
205 |
+
gr.Markdown(MARKDOWN)
|
206 |
+
with gr.Row():
|
207 |
+
with gr.Column():
|
208 |
+
image_input = gr.Image(type="filepath", label="Input Image")
|
209 |
+
prompts_input = gr.Textbox(lines=2, placeholder="Enter prompts separated by commas...", label="Prompts")
|
210 |
+
modality_dropdown = gr.Dropdown(
|
211 |
+
choices=BIOMEDPARSE_MODES.keys(),
|
212 |
+
value=BIOMEDPARSE_MODES.keys()[0],
|
213 |
+
label="Modality"
|
214 |
+
)
|
215 |
+
submit_btn = gr.Button("Submit")
|
216 |
+
with gr.Column():
|
217 |
+
output_gallery = gr.Gallery(label="Predicted Masks")
|
218 |
+
pvalue_output = gr.Textbox(label="P-values", interactive=False)
|
219 |
+
|
220 |
+
submit_btn.click(
|
221 |
+
process_image,
|
222 |
+
inputs=[image_input, prompts_input, modality_dropdown],
|
223 |
+
outputs=[output_gallery, pvalue_output]
|
224 |
+
)
|
225 |
+
with gr.Row():
|
226 |
+
gr.Examples(
|
227 |
+
fn=process_image,
|
228 |
+
examples=IMAGE_PROCESSING_EXAMPLES,
|
229 |
+
inputs=[
|
230 |
+
image_processing_mode_dropdown_component,
|
231 |
+
image_processing_image_input_component,
|
232 |
+
image_processing_text_input_component
|
233 |
+
],
|
234 |
+
outputs=[
|
235 |
+
image_processing_image_output_component,
|
236 |
+
image_processing_text_output_component
|
237 |
+
],
|
238 |
+
run_on_click=True
|
239 |
+
)
|
240 |
+
|
241 |
+
# Launch the app
|
242 |
+
demo.launch()
|
configs/biomed_seg_lang_v1.yaml
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Written by Xueyan Zou ([email protected])
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
# Define Test/Trainer/Saving
|
9 |
+
PIPELINE: XDecoderPipeline
|
10 |
+
TRAINER: xdecoder
|
11 |
+
SAVE_DIR: './output'
|
12 |
+
base_path: "./"
|
13 |
+
|
14 |
+
# Resume Logistic
|
15 |
+
RESUME: false
|
16 |
+
WEIGHT: false
|
17 |
+
RESUME_FROM: ''
|
18 |
+
EVAL_AT_START: false
|
19 |
+
SAVE_CHECKPOINT: True
|
20 |
+
|
21 |
+
# Logging and Debug
|
22 |
+
WANDB: False
|
23 |
+
LOG_EVERY: 100
|
24 |
+
FIND_UNUSED_PARAMETERS: false
|
25 |
+
|
26 |
+
# Speed up training
|
27 |
+
FP16: false
|
28 |
+
PORT: '36873'
|
29 |
+
|
30 |
+
# misc
|
31 |
+
LOADER:
|
32 |
+
JOINT: True
|
33 |
+
KEY_DATASET: ""
|
34 |
+
SAMPLE_PROB: "prop" # sampling probability proportional to data size. Use "equal" for each bach from all datasets
|
35 |
+
MIXING_LEVEL: 1 # num of different datasets for batch mixing on each GPU
|
36 |
+
|
37 |
+
RANDOM_SEED: 2024
|
38 |
+
|
39 |
+
STANDARD_TEXT_FOR_EVAL: False
|
40 |
+
|
41 |
+
##################
|
42 |
+
# Task settings
|
43 |
+
##################
|
44 |
+
VERBOSE: true
|
45 |
+
MODEL:
|
46 |
+
DEVICE: "cuda" # or "cpu" if no GPU available
|
47 |
+
NAME: seem_model_v1
|
48 |
+
HEAD: xdecoder_head
|
49 |
+
MASK_ON: false
|
50 |
+
KEYPOINT_ON: false
|
51 |
+
LOAD_PROPOSALS: false
|
52 |
+
DIM_PROJ: 512
|
53 |
+
TEXT:
|
54 |
+
ARCH: vlpencoder
|
55 |
+
NAME: transformer
|
56 |
+
TOKENIZER: clip
|
57 |
+
CONTEXT_LENGTH: 77 #256 # 77
|
58 |
+
WIDTH: 512 # 768 # 512
|
59 |
+
HEADS: 8
|
60 |
+
LAYERS: 12 # 6
|
61 |
+
AUTOGRESSIVE: True
|
62 |
+
BACKBONE:
|
63 |
+
NAME: focal # focal_dw # focal
|
64 |
+
PRETRAINED: ''
|
65 |
+
LOAD_PRETRAINED: false
|
66 |
+
FOCAL:
|
67 |
+
PRETRAIN_IMG_SIZE: 224
|
68 |
+
PATCH_SIZE: 4
|
69 |
+
EMBED_DIM: 192 # 96 # 192
|
70 |
+
DEPTHS: [2, 2, 18, 2] # [2, 2, 6, 2] # [2, 2, 18, 2]
|
71 |
+
FOCAL_LEVELS: [4, 4, 4, 4] # [3, 3, 3, 3] # [4, 4, 4, 4]
|
72 |
+
FOCAL_WINDOWS: [3, 3, 3, 3]
|
73 |
+
DROP_PATH_RATE: 0.3
|
74 |
+
MLP_RATIO: 4.0
|
75 |
+
DROP_RATE: 0.0
|
76 |
+
PATCH_NORM: True
|
77 |
+
USE_CONV_EMBED: True
|
78 |
+
SCALING_MODULATOR: True
|
79 |
+
USE_CHECKPOINT: False
|
80 |
+
USE_POSTLN: true
|
81 |
+
USE_POSTLN_IN_MODULATION: false
|
82 |
+
USE_LAYERSCALE: True
|
83 |
+
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
|
84 |
+
OUT_INDICES: [0, 1, 2, 3]
|
85 |
+
ENCODER:
|
86 |
+
NAME: transformer_encoder_fpn
|
87 |
+
IGNORE_VALUE: 255
|
88 |
+
NUM_CLASSES: 16
|
89 |
+
BINARY_CLASSES: False
|
90 |
+
LOSS_WEIGHT: 1.0
|
91 |
+
CONVS_DIM: 512
|
92 |
+
MASK_DIM: 512
|
93 |
+
NORM: "GN"
|
94 |
+
IN_FEATURES: ["res2", "res3", "res4", "res5"]
|
95 |
+
DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
|
96 |
+
COMMON_STRIDE: 4
|
97 |
+
TRANSFORMER_ENC_LAYERS: 6
|
98 |
+
DECODER:
|
99 |
+
NAME: seem_v1
|
100 |
+
TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
|
101 |
+
MASK:
|
102 |
+
ENABLED: True
|
103 |
+
DETECTION: False
|
104 |
+
SPATIAL:
|
105 |
+
ENABLED: True
|
106 |
+
MAX_ITER: 1
|
107 |
+
GROUNDING:
|
108 |
+
ENABLED: True
|
109 |
+
MAX_LEN: 10
|
110 |
+
TEXT_WEIGHT: 2.0
|
111 |
+
CLASS_WEIGHT: 0.5
|
112 |
+
RETRIEVAL:
|
113 |
+
ENABLED: False
|
114 |
+
LVIS:
|
115 |
+
ENABLED: False
|
116 |
+
THRES: 0.7
|
117 |
+
OPENIMAGE:
|
118 |
+
ENABLED: False
|
119 |
+
NEGATIVE_SAMPLES: 5
|
120 |
+
GROUNDING:
|
121 |
+
ENABLED: False
|
122 |
+
MAX_LEN: 5
|
123 |
+
CAPTION:
|
124 |
+
ENABLED: False
|
125 |
+
PHRASE_PROB: 0.5
|
126 |
+
SIM_THRES: 0.95
|
127 |
+
DEEP_SUPERVISION: True
|
128 |
+
NO_OBJECT_WEIGHT: 0.1
|
129 |
+
GCLASS_WEIGHT: 0.4
|
130 |
+
GMASK_WEIGHT: 1.0
|
131 |
+
GDICE_WEIGHT: 1.0
|
132 |
+
SCLASS_WEIGHT: 0.4
|
133 |
+
SMASK_WEIGHT: 1.0
|
134 |
+
SDICE_WEIGHT: 1.0
|
135 |
+
OCLASS_WEIGHT: 0.4
|
136 |
+
OMASK_WEIGHT: 1.0
|
137 |
+
ODICE_WEIGHT: 1.0
|
138 |
+
CLASS_WEIGHT: 2.0
|
139 |
+
MASK_WEIGHT: 5.0
|
140 |
+
DICE_WEIGHT: 5.0
|
141 |
+
BBOX_WEIGHT: 5.0
|
142 |
+
GIOU_WEIGHT: 2.0
|
143 |
+
CAPTION_WEIGHT: 2.0
|
144 |
+
COST_SPATIAL:
|
145 |
+
CLASS_WEIGHT: 5.0
|
146 |
+
MASK_WEIGHT: 2.0
|
147 |
+
DICE_WEIGHT: 2.0
|
148 |
+
HIDDEN_DIM: 512
|
149 |
+
NUM_OBJECT_QUERIES: 101
|
150 |
+
NHEADS: 8
|
151 |
+
DROPOUT: 0.0
|
152 |
+
DIM_FEEDFORWARD: 2048
|
153 |
+
MAX_SPATIAL_LEN: [512, 512, 512, 512]
|
154 |
+
# ENC_LAYERS: 0
|
155 |
+
PRE_NORM: False
|
156 |
+
ENFORCE_INPUT_PROJ: False
|
157 |
+
SIZE_DIVISIBILITY: 32
|
158 |
+
TRAIN_NUM_POINTS: 12544
|
159 |
+
OVERSAMPLE_RATIO: 3.0
|
160 |
+
IMPORTANCE_SAMPLE_RATIO: 0.75
|
161 |
+
DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
|
162 |
+
TOP_GROUNDING_LAYERS: 10
|
163 |
+
TOP_CAPTION_LAYERS: 10
|
164 |
+
TOP_SPATIAL_LAYERS: 10
|
165 |
+
TOP_OPENIMAGE_LAYERS: 10
|
166 |
+
TEST:
|
167 |
+
SEMANTIC_ON: False
|
168 |
+
INSTANCE_ON: False
|
169 |
+
PANOPTIC_ON: False
|
170 |
+
OVERLAP_THRESHOLD: 0.8
|
171 |
+
OBJECT_MASK_THRESHOLD: 0.8
|
172 |
+
SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: true
|
173 |
+
|
174 |
+
# Spatial sampler
|
175 |
+
STROKE_SAMPLER:
|
176 |
+
MAX_CANDIDATE: 1
|
177 |
+
CANDIDATE_PROBS: [0.25, 0.25, 0.25, 0.25] # for training only
|
178 |
+
CANDIDATE_NAMES: ["Point", "Polygon", "Scribble", "Circle"]
|
179 |
+
DILATION: 3
|
180 |
+
CIRCLE:
|
181 |
+
NUM_STROKES: 5
|
182 |
+
STROKE_PRESET: ['object_like', 'object_like_middle', 'object_like_small']
|
183 |
+
STROKE_PROB: [0.33, 0.33, 0.33]
|
184 |
+
SCRIBBLE:
|
185 |
+
NUM_STROKES: 5
|
186 |
+
STROKE_PRESET: ['rand_curve', 'rand_curve_small']
|
187 |
+
STROKE_PROB: [0.5, 0.5]
|
188 |
+
POINT:
|
189 |
+
NUM_POINTS: 20
|
190 |
+
POLYGON:
|
191 |
+
MAX_POINTS: 9
|
192 |
+
EVAL:
|
193 |
+
MODE: 'best' # best/random/best_random
|
194 |
+
NEGATIVE: False
|
195 |
+
MAX_ITER: 1
|
196 |
+
IOU_ITER: 1
|
197 |
+
GROUNDING: True
|
198 |
+
|
199 |
+
# Multi-modal Architecture, order matters
|
200 |
+
ATTENTION_ARCH:
|
201 |
+
VARIABLE:
|
202 |
+
queries: ['object', 'grounding', 'spatial']
|
203 |
+
tokens: ['grounding', 'spatial']
|
204 |
+
memories: ['spatial']
|
205 |
+
SELF_ATTENTION:
|
206 |
+
queries:
|
207 |
+
object: ['queries_object']
|
208 |
+
grounding: ['queries_grounding', 'tokens_grounding']
|
209 |
+
spatial: ['queries_spatial', 'tokens_spatial', 'memories_spatial']
|
210 |
+
tokens:
|
211 |
+
grounding: ['queries_grounding', 'tokens_grounding']
|
212 |
+
spatial: ['tokens_spatial']
|
213 |
+
memories:
|
214 |
+
spatial: ['memories_spatial']
|
215 |
+
CROSS_ATTENTION:
|
216 |
+
queries:
|
217 |
+
object: True
|
218 |
+
grounding: True
|
219 |
+
spatial: True
|
220 |
+
memories:
|
221 |
+
spatial: True
|
222 |
+
tokens:
|
223 |
+
grounding: False
|
224 |
+
spatial: False
|
225 |
+
MASKING: ['tokens_spatial', 'tokens_grounding']
|
226 |
+
DUPLICATION:
|
227 |
+
queries:
|
228 |
+
grounding: 'queries_object'
|
229 |
+
spatial: 'queries_object'
|
230 |
+
SPATIAL_MEMORIES: 32
|
231 |
+
QUERY_NUMBER: 3
|
232 |
+
|
233 |
+
DATASETS:
|
234 |
+
TRAIN: [
|
235 |
+
'biomed_BiomedParseData-Demo_demo' # Add your registered training datasets here
|
236 |
+
]
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
TEST: [
|
241 |
+
'biomed_BiomedParseData-Demo_demo' # Add your registered test datasets here
|
242 |
+
]
|
243 |
+
|
244 |
+
CLASS_CONCAT: false
|
245 |
+
SIZE_DIVISIBILITY: 32
|
246 |
+
PROPOSAL_FILES_TRAIN: []
|
247 |
+
|
248 |
+
INPUT:
|
249 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
250 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
251 |
+
|
252 |
+
TRAIN:
|
253 |
+
ASPECT_RATIO_GROUPING: true
|
254 |
+
BATCH_SIZE_TOTAL: 4
|
255 |
+
BATCH_SIZE_PER_GPU: 4
|
256 |
+
SHUFFLE: true
|
257 |
+
|
258 |
+
TEST:
|
259 |
+
DETECTIONS_PER_IMAGE: 100
|
260 |
+
NAME: coco_eval
|
261 |
+
IOU_TYPE: ['bbox', 'segm']
|
262 |
+
USE_MULTISCALE: false
|
263 |
+
BATCH_SIZE_TOTAL: 4
|
264 |
+
MODEL_FILE: ''
|
265 |
+
AUG:
|
266 |
+
ENABLED: False
|
267 |
+
|
268 |
+
DATALOADER:
|
269 |
+
FILTER_EMPTY_ANNOTATIONS: False
|
270 |
+
NUM_WORKERS: 8
|
271 |
+
LOAD_PROPOSALS: False
|
272 |
+
SAMPLER_TRAIN: "TrainingSampler"
|
273 |
+
ASPECT_RATIO_GROUPING: True
|
274 |
+
|
275 |
+
|
276 |
+
BioMed:
|
277 |
+
INPUT:
|
278 |
+
PIXEL_MEAN: [64.284, 59.293, 59.962]
|
279 |
+
PIXEL_STD: [62.484, 60.865, 59.835]
|
280 |
+
DATASET_MAPPER_NAME: "biomed_interactive"
|
281 |
+
MIN_SIZE_TRAIN: 900
|
282 |
+
MAX_SIZE_TRAIN: 1100
|
283 |
+
MIN_SIZE_TRAIN_SAMPLING: 'choice'
|
284 |
+
MIN_SIZE_TEST: 900
|
285 |
+
MAX_SIZE_TEST: 1100
|
286 |
+
IMAGE_SIZE: 1024
|
287 |
+
MIN_SCALE: 0.9
|
288 |
+
MAX_SCALE: 1.1
|
289 |
+
IGNORE_VALUE: 255
|
290 |
+
COLOR_AUG_SSD: False
|
291 |
+
SIZE_DIVISIBILITY: 32
|
292 |
+
RANDOM_FLIP: "none"
|
293 |
+
RANDOM_ROTATE: False
|
294 |
+
MASK_FORMAT: "polygon"
|
295 |
+
MIN_AREA: 30
|
296 |
+
FORMAT: "RGB"
|
297 |
+
SPATIAL: True
|
298 |
+
CROP:
|
299 |
+
ENABLED: True
|
300 |
+
DATASET:
|
301 |
+
DATASET: "biomed"
|
302 |
+
|
303 |
+
|
304 |
+
# Detectron2 training config for optimizer and lr scheduler
|
305 |
+
SOLVER:
|
306 |
+
BASE_LR: 0.0001
|
307 |
+
STEPS: [0.88889, 0.96296]
|
308 |
+
MAX_ITER: 1
|
309 |
+
GAMMA: 0.1
|
310 |
+
WARMUP_FACTOR: 1.0
|
311 |
+
WARMUP_ITERS: 10
|
312 |
+
WARMUP_METHOD: "linear"
|
313 |
+
WEIGHT_DECAY: 0.05
|
314 |
+
OPTIMIZER: "ADAMW"
|
315 |
+
LR_SCHEDULER_NAME: "WarmupMultiStepLR"
|
316 |
+
LR_MULTIPLIER:
|
317 |
+
backbone: 0.1
|
318 |
+
lang_encoder: 0.1
|
319 |
+
FIX_PARAM:
|
320 |
+
backbone: True
|
321 |
+
lang_encoder: True
|
322 |
+
pixel_decoder: True
|
323 |
+
WEIGHT_DECAY_NORM: 0.0
|
324 |
+
WEIGHT_DECAY_EMBED: 0.0
|
325 |
+
CLIP_GRADIENTS:
|
326 |
+
ENABLED: True
|
327 |
+
CLIP_TYPE: "full_model"
|
328 |
+
CLIP_VALUE: 5.0 # 0.01
|
329 |
+
NORM_TYPE: 2.0
|
330 |
+
MAX_NUM_EPOCHS: 50
|
configs/biomedparse_inference.yaml
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Define Test/Trainer/Saving
|
2 |
+
PIPELINE: XDecoderPipeline
|
3 |
+
TRAINER: xdecoder
|
4 |
+
SAVE_DIR: '../../data/output/test'
|
5 |
+
base_path: "./"
|
6 |
+
|
7 |
+
# Resume Logistic
|
8 |
+
RESUME: false
|
9 |
+
WEIGHT: false
|
10 |
+
RESUME_FROM: ''
|
11 |
+
EVAL_AT_START: false
|
12 |
+
|
13 |
+
# Logging and Debug
|
14 |
+
WANDB: False
|
15 |
+
LOG_EVERY: 100
|
16 |
+
FIND_UNUSED_PARAMETERS: false
|
17 |
+
|
18 |
+
# Speed up training
|
19 |
+
FP16: false
|
20 |
+
PORT: '36873'
|
21 |
+
|
22 |
+
# misc
|
23 |
+
LOADER:
|
24 |
+
JOINT: False
|
25 |
+
KEY_DATASET: 'coco'
|
26 |
+
|
27 |
+
STANDARD_TEXT_FOR_EVAL: False
|
28 |
+
|
29 |
+
##################
|
30 |
+
# Task settings
|
31 |
+
##################
|
32 |
+
VERBOSE: true
|
33 |
+
MODEL:
|
34 |
+
device: "cuda" # or "cpu" if no GPU available
|
35 |
+
DEVICE: "cuda" # or "cpu" if no GPU available
|
36 |
+
NAME: seem_model_demo
|
37 |
+
HEAD: xdecoder_head
|
38 |
+
DIM_PROJ: 512
|
39 |
+
TEXT:
|
40 |
+
ARCH: vlpencoder
|
41 |
+
NAME: transformer
|
42 |
+
TOKENIZER: clip
|
43 |
+
CONTEXT_LENGTH: 77 # 77
|
44 |
+
WIDTH: 512
|
45 |
+
HEADS: 8
|
46 |
+
LAYERS: 12 # 6
|
47 |
+
AUTOGRESSIVE: True
|
48 |
+
BACKBONE:
|
49 |
+
NAME: focal
|
50 |
+
PRETRAINED: ''
|
51 |
+
LOAD_PRETRAINED: false
|
52 |
+
FOCAL:
|
53 |
+
PRETRAIN_IMG_SIZE: 224
|
54 |
+
PATCH_SIZE: 4
|
55 |
+
EMBED_DIM: 192
|
56 |
+
DEPTHS: [2, 2, 18, 2]
|
57 |
+
FOCAL_LEVELS: [4, 4, 4, 4]
|
58 |
+
FOCAL_WINDOWS: [3, 3, 3, 3]
|
59 |
+
DROP_PATH_RATE: 0.3
|
60 |
+
MLP_RATIO: 4.0
|
61 |
+
DROP_RATE: 0.0
|
62 |
+
PATCH_NORM: True
|
63 |
+
USE_CONV_EMBED: True
|
64 |
+
SCALING_MODULATOR: True
|
65 |
+
USE_CHECKPOINT: False
|
66 |
+
USE_POSTLN: true
|
67 |
+
USE_POSTLN_IN_MODULATION: false
|
68 |
+
USE_LAYERSCALE: True
|
69 |
+
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
|
70 |
+
OUT_INDICES: [0, 1, 2, 3]
|
71 |
+
ENCODER:
|
72 |
+
NAME: transformer_encoder_fpn
|
73 |
+
IGNORE_VALUE: 255
|
74 |
+
NUM_CLASSES: 16
|
75 |
+
BINARY_CLASSES: False
|
76 |
+
LOSS_WEIGHT: 1.0
|
77 |
+
CONVS_DIM: 512
|
78 |
+
MASK_DIM: 512
|
79 |
+
NORM: "GN"
|
80 |
+
IN_FEATURES: ["res2", "res3", "res4", "res5"]
|
81 |
+
DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
|
82 |
+
COMMON_STRIDE: 4
|
83 |
+
TRANSFORMER_ENC_LAYERS: 6
|
84 |
+
DECODER:
|
85 |
+
NAME: seem_demo
|
86 |
+
TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
|
87 |
+
MASK:
|
88 |
+
ENABLED: False
|
89 |
+
DETECTION: False
|
90 |
+
SPATIAL:
|
91 |
+
ENABLED: True
|
92 |
+
MAX_ITER: 1
|
93 |
+
GROUNDING:
|
94 |
+
ENABLED: True
|
95 |
+
MAX_LEN: 5
|
96 |
+
TEXT_WEIGHT: 2.0
|
97 |
+
CLASS_WEIGHT: 0.5
|
98 |
+
VISUAL:
|
99 |
+
ENABLED: False
|
100 |
+
AUDIO:
|
101 |
+
ENABLED: False
|
102 |
+
RETRIEVAL:
|
103 |
+
ENABLED: False
|
104 |
+
LVIS:
|
105 |
+
ENABLED: True
|
106 |
+
THRES: 0.7
|
107 |
+
OPENIMAGE:
|
108 |
+
ENABLED: False
|
109 |
+
NEGATIVE_SAMPLES: 5
|
110 |
+
GROUNDING:
|
111 |
+
ENABLED: False
|
112 |
+
MAX_LEN: 5
|
113 |
+
CAPTION:
|
114 |
+
ENABLED: False
|
115 |
+
PHRASE_PROB: 0.5
|
116 |
+
SIM_THRES: 0.95
|
117 |
+
DEEP_SUPERVISION: True
|
118 |
+
NO_OBJECT_WEIGHT: 0.1
|
119 |
+
GCLASS_WEIGHT: 0.4
|
120 |
+
GMASK_WEIGHT: 1.0
|
121 |
+
GDICE_WEIGHT: 1.0
|
122 |
+
SCLASS_WEIGHT: 0.4
|
123 |
+
SMASK_WEIGHT: 1.0
|
124 |
+
SDICE_WEIGHT: 1.0
|
125 |
+
OCLASS_WEIGHT: 0.4
|
126 |
+
OMASK_WEIGHT: 1.0
|
127 |
+
ODICE_WEIGHT: 1.0
|
128 |
+
CLASS_WEIGHT: 2.0
|
129 |
+
MASK_WEIGHT: 5.0
|
130 |
+
DICE_WEIGHT: 5.0
|
131 |
+
BBOX_WEIGHT: 5.0
|
132 |
+
GIOU_WEIGHT: 2.0
|
133 |
+
CAPTION_WEIGHT: 2.0
|
134 |
+
COST_SPATIAL:
|
135 |
+
CLASS_WEIGHT: 5.0
|
136 |
+
MASK_WEIGHT: 2.0
|
137 |
+
DICE_WEIGHT: 2.0
|
138 |
+
HIDDEN_DIM: 512
|
139 |
+
NUM_OBJECT_QUERIES: 101
|
140 |
+
NHEADS: 8
|
141 |
+
DROPOUT: 0.0
|
142 |
+
DIM_FEEDFORWARD: 2048
|
143 |
+
MAX_SPATIAL_LEN: [512, 512, 512, 512]
|
144 |
+
# ENC_LAYERS: 0
|
145 |
+
PRE_NORM: False
|
146 |
+
ENFORCE_INPUT_PROJ: False
|
147 |
+
SIZE_DIVISIBILITY: 32
|
148 |
+
TRAIN_NUM_POINTS: 12544
|
149 |
+
OVERSAMPLE_RATIO: 3.0
|
150 |
+
IMPORTANCE_SAMPLE_RATIO: 0.75
|
151 |
+
DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
|
152 |
+
TOP_GROUNDING_LAYERS: 10
|
153 |
+
TOP_CAPTION_LAYERS: 10
|
154 |
+
TOP_SPATIAL_LAYERS: 10
|
155 |
+
TOP_OPENIMAGE_LAYERS: 10
|
156 |
+
TEST:
|
157 |
+
SEMANTIC_ON: True
|
158 |
+
INSTANCE_ON: True
|
159 |
+
PANOPTIC_ON: True
|
160 |
+
OVERLAP_THRESHOLD: 0.8
|
161 |
+
OBJECT_MASK_THRESHOLD: 0.4
|
162 |
+
SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false
|
163 |
+
DETECTIONS_PER_IMAGE: 100
|
164 |
+
|
165 |
+
# Multi-modal Architecture, order matters
|
166 |
+
ATTENTION_ARCH:
|
167 |
+
VARIABLE:
|
168 |
+
queries: ['object']
|
169 |
+
tokens: ['grounding', 'spatial', 'visual', 'audio']
|
170 |
+
SELF_ATTENTION:
|
171 |
+
queries:
|
172 |
+
object: ['queries_object', 'tokens_grounding', 'tokens_spatial', 'tokens_visual', 'tokens_audio']
|
173 |
+
tokens:
|
174 |
+
grounding: ['queries_object', 'tokens_grounding']
|
175 |
+
spatial: ['tokens_spatial']
|
176 |
+
visual: ['tokens_visual']
|
177 |
+
audio: ['queries_object', 'tokens_audio']
|
178 |
+
CROSS_ATTENTION:
|
179 |
+
queries:
|
180 |
+
object: True
|
181 |
+
tokens:
|
182 |
+
grounding: False
|
183 |
+
spatial: False
|
184 |
+
visual: False
|
185 |
+
audio: False
|
186 |
+
MASKING: ['tokens_spatial', 'tokens_grounding', 'tokens_visual', 'tokens_audio']
|
187 |
+
DUPLICATION:
|
188 |
+
queries:
|
189 |
+
grounding: 'queries_object'
|
190 |
+
spatial: 'queries_object'
|
191 |
+
SPATIAL_MEMORIES: 32
|
192 |
+
|
193 |
+
INPUT:
|
194 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
195 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
196 |
+
# INPUT:
|
197 |
+
# PIXEL_MEAN: [64.284, 59.293, 59.962]
|
198 |
+
# PIXEL_STD: [62.484, 60.865, 59.835]
|
datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from . import registration
|
2 |
+
from .build import build_train_dataloader, build_eval_dataloader, build_evaluator
|
datasets/build.py
ADDED
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Modified by Xueyan Zou ([email protected])
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
8 |
+
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
import itertools
|
12 |
+
import logging
|
13 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.utils.data
|
17 |
+
import torch.utils.data as torchdata
|
18 |
+
|
19 |
+
import detectron2.utils.comm as comm
|
20 |
+
from detectron2.data.build import (
|
21 |
+
build_batch_data_loader,
|
22 |
+
load_proposals_into_dataset,
|
23 |
+
trivial_batch_collator,
|
24 |
+
)
|
25 |
+
from detectron2.data import MetadataCatalog
|
26 |
+
from detectron2.data.catalog import DatasetCatalog
|
27 |
+
from detectron2.data.common import DatasetFromList, MapDataset
|
28 |
+
from detectron2.data.dataset_mapper import DatasetMapper
|
29 |
+
from detectron2.data.samplers import InferenceSampler, TrainingSampler
|
30 |
+
from detectron2.evaluation import (
|
31 |
+
CityscapesInstanceEvaluator,
|
32 |
+
CityscapesSemSegEvaluator,
|
33 |
+
COCOEvaluator,
|
34 |
+
DatasetEvaluators,
|
35 |
+
LVISEvaluator,
|
36 |
+
verify_results,
|
37 |
+
)
|
38 |
+
from fvcore.common.config import CfgNode
|
39 |
+
|
40 |
+
from .dataset_mappers import *
|
41 |
+
from .evaluation import (InstanceSegEvaluator,
|
42 |
+
ClassificationEvaluator,
|
43 |
+
SemSegEvaluator,
|
44 |
+
RetrievalEvaluator,
|
45 |
+
#CaptioningEvaluator,
|
46 |
+
COCOPanopticEvaluator,
|
47 |
+
GroundingEvaluator,
|
48 |
+
InteractiveEvaluator,
|
49 |
+
)
|
50 |
+
from modeling.utils import configurable
|
51 |
+
from utilities.distributed import get_world_size
|
52 |
+
|
53 |
+
class JointLoader(torchdata.IterableDataset):
|
54 |
+
"""
|
55 |
+
Randomly sampple from one of the dataloaders per worker in each iteration.
|
56 |
+
The sampling probability is determined by the size of each dataset.
|
57 |
+
All examples from one worker (GPU) are from the same dataset in the iteration.
|
58 |
+
Mixing is achieved through multiple workers (GPUs).
|
59 |
+
"""
|
60 |
+
def __init__(self, loaders, key_dataset, sample_prob, mixing_level):
|
61 |
+
dataset_names = []
|
62 |
+
for key, loader in loaders.items():
|
63 |
+
name = "{}".format(key.split('_')[0])
|
64 |
+
setattr(self, name, loader)
|
65 |
+
dataset_names += [name]
|
66 |
+
self.dataset_names = dataset_names
|
67 |
+
self.key_dataset = key_dataset
|
68 |
+
if sample_prob == 'prop':
|
69 |
+
self.sample_prob = [len(getattr(self, key)) for key in self.dataset_names]
|
70 |
+
elif sample_prob == 'equal':
|
71 |
+
self.sample_prob = [1 for key in self.dataset_names]
|
72 |
+
elif sample_prob == 'sqrt':
|
73 |
+
self.sample_prob = [np.sqrt(len(getattr(self, key))) for key in self.dataset_names]
|
74 |
+
self.sample_prob = [p/sum(self.sample_prob) for p in self.sample_prob]
|
75 |
+
self.mixing_level = mixing_level
|
76 |
+
|
77 |
+
# Not sure how expensive `len(getattr(self, name))` is. computing this once and cache.
|
78 |
+
# this assumes the len of the underlying data loaders do not change.
|
79 |
+
self._len = sum(len(getattr(self, name)) for name in self.dataset_names)
|
80 |
+
|
81 |
+
def __iter__(self):
|
82 |
+
# Reset iterators at the start of each new epoch
|
83 |
+
self.iterators = {name: iter(getattr(self, name)) for name in self.dataset_names}
|
84 |
+
self._count = 0
|
85 |
+
return self
|
86 |
+
|
87 |
+
def __next__(self):
|
88 |
+
while self._count < self._len:
|
89 |
+
# Randomly select a dataloader
|
90 |
+
name = np.random.choice(self.dataset_names, size=None, replace=False, p=self.sample_prob)
|
91 |
+
iterator = self.iterators[name]
|
92 |
+
|
93 |
+
try:
|
94 |
+
# Get next batch from the selected dataloader
|
95 |
+
self._count += 1
|
96 |
+
return next(iterator)
|
97 |
+
except StopIteration:
|
98 |
+
# If the selected dataloader is exhausted, reinitialize it
|
99 |
+
self.iterators[name] = iter(getattr(self, name))
|
100 |
+
raise StopIteration
|
101 |
+
|
102 |
+
def __len__(self):
|
103 |
+
return self._len
|
104 |
+
|
105 |
+
def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
|
106 |
+
"""
|
107 |
+
Filter out images with none annotations or only crowd annotations
|
108 |
+
(i.e., images without non-crowd annotations).
|
109 |
+
A common training-time preprocessing on COCO dataset.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
list[dict]: the same format, but filtered.
|
116 |
+
"""
|
117 |
+
num_before = len(dataset_dicts)
|
118 |
+
|
119 |
+
def valid(anns):
|
120 |
+
for ann in anns:
|
121 |
+
if isinstance(ann, list):
|
122 |
+
for instance in ann:
|
123 |
+
if instance.get("iscrowd", 0) == 0:
|
124 |
+
return True
|
125 |
+
else:
|
126 |
+
if ann.get("iscrowd", 0) == 0:
|
127 |
+
return True
|
128 |
+
return False
|
129 |
+
|
130 |
+
dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
|
131 |
+
num_after = len(dataset_dicts)
|
132 |
+
logger = logging.getLogger(__name__)
|
133 |
+
logger.info(
|
134 |
+
"Removed {} images with no usable annotations. {} images left.".format(
|
135 |
+
num_before - num_after, num_after
|
136 |
+
)
|
137 |
+
)
|
138 |
+
return dataset_dicts
|
139 |
+
|
140 |
+
|
141 |
+
def get_detection_dataset_dicts(
|
142 |
+
dataset_names, filter_empty=True, proposal_files=None
|
143 |
+
):
|
144 |
+
"""
|
145 |
+
Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
dataset_names (str or list[str]): a dataset name or a list of dataset names
|
149 |
+
filter_empty (bool): whether to filter out images without instance annotations
|
150 |
+
proposal_files (list[str]): if given, a list of object proposal files
|
151 |
+
that match each dataset in `dataset_names`.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
list[dict]: a list of dicts following the standard dataset dict format.
|
155 |
+
"""
|
156 |
+
if isinstance(dataset_names, str):
|
157 |
+
dataset_names = [dataset_names]
|
158 |
+
assert len(dataset_names)
|
159 |
+
|
160 |
+
dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
|
161 |
+
for dataset_name, dicts in zip(dataset_names, dataset_dicts):
|
162 |
+
assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
|
163 |
+
|
164 |
+
if proposal_files is not None:
|
165 |
+
assert len(dataset_names) == len(proposal_files)
|
166 |
+
# load precomputed proposals from proposal files
|
167 |
+
dataset_dicts = [
|
168 |
+
load_proposals_into_dataset(dataset_i_dicts, proposal_file)
|
169 |
+
for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
|
170 |
+
]
|
171 |
+
|
172 |
+
dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
|
173 |
+
|
174 |
+
has_instances = "annotations" in dataset_dicts[0]
|
175 |
+
if filter_empty and has_instances:
|
176 |
+
dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)
|
177 |
+
|
178 |
+
assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
|
179 |
+
return dataset_dicts
|
180 |
+
|
181 |
+
|
182 |
+
def _test_loader_from_config(cfg, dataset_name, mapper=None):
|
183 |
+
"""
|
184 |
+
Uses the given `dataset_name` argument (instead of the names in cfg), because the
|
185 |
+
standard practice is to evaluate each test set individually (not combining them).
|
186 |
+
"""
|
187 |
+
if isinstance(dataset_name, str):
|
188 |
+
dataset_name = [dataset_name]
|
189 |
+
|
190 |
+
dataset = get_detection_dataset_dicts(
|
191 |
+
dataset_name,
|
192 |
+
filter_empty=False,
|
193 |
+
proposal_files=None,
|
194 |
+
)
|
195 |
+
if mapper is None:
|
196 |
+
mapper_cfg = CfgNode({'INPUT': cfg['INPUT'], 'MODEL': cfg['MODEL'], 'DATASETS': cfg['DATASETS']})
|
197 |
+
mapper = DatasetMapper(mapper_cfg, False)
|
198 |
+
assert cfg['TEST']['BATCH_SIZE_TOTAL'] % get_world_size() == 0, "Evaluation total batchsize is not divisible by gpu number"
|
199 |
+
#batch_size = cfg['TEST']['BATCH_SIZE_TOTAL'] // get_world_size()
|
200 |
+
batch_size = 1
|
201 |
+
|
202 |
+
return {
|
203 |
+
"dataset": dataset,
|
204 |
+
"mapper": mapper,
|
205 |
+
"num_workers": cfg['DATALOADER']['NUM_WORKERS'],
|
206 |
+
"sampler": InferenceSampler(len(dataset)),
|
207 |
+
"batch_size": batch_size,
|
208 |
+
}
|
209 |
+
|
210 |
+
|
211 |
+
@configurable(from_config=_test_loader_from_config)
|
212 |
+
def build_detection_test_loader(
|
213 |
+
dataset: Union[List[Any], torchdata.Dataset],
|
214 |
+
*,
|
215 |
+
mapper: Callable[[Dict[str, Any]], Any],
|
216 |
+
sampler: Optional[torchdata.Sampler] = None,
|
217 |
+
batch_size: int = 1,
|
218 |
+
num_workers: int = 0,
|
219 |
+
collate_fn: Optional[Callable[[List[Any]], Any]] = None,
|
220 |
+
) -> torchdata.DataLoader:
|
221 |
+
"""
|
222 |
+
Similar to `build_detection_train_loader`, with default batch size = 1,
|
223 |
+
and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
|
224 |
+
to produce the exact set of all samples.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
dataset: a list of dataset dicts,
|
228 |
+
or a pytorch dataset (either map-style or iterable). They can be obtained
|
229 |
+
by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
|
230 |
+
mapper: a callable which takes a sample (dict) from dataset
|
231 |
+
and returns the format to be consumed by the model.
|
232 |
+
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
|
233 |
+
sampler: a sampler that produces
|
234 |
+
indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
|
235 |
+
which splits the dataset across all workers. Sampler must be None
|
236 |
+
if `dataset` is iterable.
|
237 |
+
batch_size: the batch size of the data loader to be created.
|
238 |
+
Default to 1 image per worker since this is the standard when reporting
|
239 |
+
inference time in papers.
|
240 |
+
num_workers: number of parallel data loading workers
|
241 |
+
collate_fn: same as the argument of `torch.utils.data.DataLoader`.
|
242 |
+
Defaults to do no collation and return a list of data.
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
DataLoader: a torch DataLoader, that loads the given detection
|
246 |
+
dataset, with test-time transformation and batching.
|
247 |
+
|
248 |
+
Examples:
|
249 |
+
::
|
250 |
+
data_loader = build_detection_test_loader(
|
251 |
+
DatasetRegistry.get("my_test"),
|
252 |
+
mapper=DatasetMapper(...))
|
253 |
+
|
254 |
+
# or, instantiate with a CfgNode:
|
255 |
+
data_loader = build_detection_test_loader(cfg, "my_test")
|
256 |
+
"""
|
257 |
+
|
258 |
+
if isinstance(dataset, list):
|
259 |
+
dataset = DatasetFromList(dataset, copy=False)
|
260 |
+
if mapper is not None:
|
261 |
+
dataset = MapDataset(dataset, mapper)
|
262 |
+
if isinstance(dataset, torchdata.IterableDataset):
|
263 |
+
assert sampler is None, "sampler must be None if dataset is IterableDataset"
|
264 |
+
else:
|
265 |
+
if sampler is None:
|
266 |
+
sampler = InferenceSampler(len(dataset))
|
267 |
+
return torchdata.DataLoader(
|
268 |
+
dataset,
|
269 |
+
batch_size=batch_size,
|
270 |
+
sampler=sampler,
|
271 |
+
drop_last=False,
|
272 |
+
num_workers=num_workers,
|
273 |
+
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
|
274 |
+
)
|
275 |
+
|
276 |
+
|
277 |
+
def _train_loader_from_config(cfg, dataset_name, mapper, *, dataset=None, sampler=None):
|
278 |
+
cfg_datasets = cfg['DATASETS']
|
279 |
+
cfg_dataloader = cfg['DATALOADER']
|
280 |
+
|
281 |
+
if dataset is None:
|
282 |
+
dataset = get_detection_dataset_dicts(
|
283 |
+
dataset_name,
|
284 |
+
filter_empty=cfg_dataloader['FILTER_EMPTY_ANNOTATIONS'],
|
285 |
+
proposal_files=cfg_datasets['PROPOSAL_FILES_TRAIN'] if cfg_dataloader['LOAD_PROPOSALS'] else None,
|
286 |
+
)
|
287 |
+
|
288 |
+
if mapper is None:
|
289 |
+
mapper = DatasetMapper(cfg, True)
|
290 |
+
|
291 |
+
if sampler is None:
|
292 |
+
sampler_name = cfg_dataloader['SAMPLER_TRAIN']
|
293 |
+
logger = logging.getLogger(__name__)
|
294 |
+
logger.info("Using training sampler {}".format(sampler_name))
|
295 |
+
sampler = TrainingSampler(len(dataset))
|
296 |
+
|
297 |
+
return {
|
298 |
+
"dataset": dataset,
|
299 |
+
"sampler": sampler,
|
300 |
+
"mapper": mapper,
|
301 |
+
"total_batch_size": cfg['TRAIN']['BATCH_SIZE_TOTAL'],
|
302 |
+
"aspect_ratio_grouping": cfg_dataloader['ASPECT_RATIO_GROUPING'],
|
303 |
+
"num_workers": cfg_dataloader['NUM_WORKERS'],
|
304 |
+
}
|
305 |
+
|
306 |
+
|
307 |
+
@configurable(from_config=_train_loader_from_config)
|
308 |
+
def build_detection_train_loader(
|
309 |
+
dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
|
310 |
+
):
|
311 |
+
"""
|
312 |
+
Build a dataloader for object detection with some default features.
|
313 |
+
This interface is experimental.
|
314 |
+
|
315 |
+
Args:
|
316 |
+
dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
|
317 |
+
or a map-style pytorch dataset. They can be obtained by using
|
318 |
+
:func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
|
319 |
+
mapper (callable): a callable which takes a sample (dict) from dataset and
|
320 |
+
returns the format to be consumed by the model.
|
321 |
+
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
|
322 |
+
sampler (torch.utils.data.sampler.Sampler or None): a sampler that
|
323 |
+
produces indices to be applied on ``dataset``.
|
324 |
+
Default to :class:`TrainingSampler`, which coordinates a random shuffle
|
325 |
+
sequence across all workers.
|
326 |
+
total_batch_size (int): total batch size across all workers. Batching
|
327 |
+
simply puts data into a list.
|
328 |
+
aspect_ratio_grouping (bool): whether to group images with similar
|
329 |
+
aspect ratio for efficiency. When enabled, it requires each
|
330 |
+
element in dataset be a dict with keys "width" and "height".
|
331 |
+
num_workers (int): number of parallel data loading workers
|
332 |
+
|
333 |
+
Returns:
|
334 |
+
torch.utils.data.DataLoader: a dataloader. Each output from it is a
|
335 |
+
``list[mapped_element]`` of length ``total_batch_size / num_workers``,
|
336 |
+
where ``mapped_element`` is produced by the ``mapper``.
|
337 |
+
"""
|
338 |
+
if isinstance(dataset, list):
|
339 |
+
dataset = DatasetFromList(dataset, copy=False)
|
340 |
+
if mapper is not None:
|
341 |
+
dataset = MapDataset(dataset, mapper)
|
342 |
+
if sampler is None:
|
343 |
+
sampler = TrainingSampler(len(dataset))
|
344 |
+
assert isinstance(sampler, torch.utils.data.sampler.Sampler)
|
345 |
+
return build_batch_data_loader(
|
346 |
+
dataset,
|
347 |
+
sampler,
|
348 |
+
total_batch_size,
|
349 |
+
aspect_ratio_grouping=aspect_ratio_grouping,
|
350 |
+
num_workers=num_workers,
|
351 |
+
)
|
352 |
+
|
353 |
+
|
354 |
+
def get_config_from_name(cfg, dataset_name):
|
355 |
+
# adjust config according to dataset
|
356 |
+
if 'refcoco' in dataset_name:
|
357 |
+
cfg.update(cfg['REF'])
|
358 |
+
return cfg
|
359 |
+
elif 'cocomini' in dataset_name:
|
360 |
+
cfg.update(cfg['DAVIS'])
|
361 |
+
return cfg
|
362 |
+
elif 'ytvos' in dataset_name:
|
363 |
+
cfg.update(cfg['VOS'])
|
364 |
+
return cfg
|
365 |
+
elif 'ade600' in dataset_name:
|
366 |
+
cfg.update(cfg['DAVIS'])
|
367 |
+
return cfg
|
368 |
+
elif 'openimage600' in dataset_name:
|
369 |
+
cfg.update(cfg['DAVIS'])
|
370 |
+
return cfg
|
371 |
+
elif 'ade' in dataset_name:
|
372 |
+
if 'ADE20K' in cfg.keys():
|
373 |
+
cfg.update(cfg['ADE20K'])
|
374 |
+
return cfg
|
375 |
+
elif 'imagenet' in dataset_name:
|
376 |
+
if 'IMAGENET' in cfg.keys():
|
377 |
+
cfg.update(cfg['IMAGENET'])
|
378 |
+
return cfg
|
379 |
+
elif 'vlp' in dataset_name:
|
380 |
+
cfg.update(cfg['VLP'])
|
381 |
+
return cfg
|
382 |
+
elif 'coco' in dataset_name:
|
383 |
+
if 'COCO' in cfg.keys():
|
384 |
+
cfg.update(cfg['COCO'])
|
385 |
+
return cfg
|
386 |
+
elif 'voc' in dataset_name:
|
387 |
+
cfg.update(cfg['VOC'])
|
388 |
+
return cfg
|
389 |
+
elif 'context' in dataset_name:
|
390 |
+
cfg.update(cfg['CONTEXT'])
|
391 |
+
return cfg
|
392 |
+
elif 'sun' in dataset_name:
|
393 |
+
cfg.update(cfg['SUN'])
|
394 |
+
return cfg
|
395 |
+
elif 'scan' in dataset_name:
|
396 |
+
cfg.update(cfg['SCAN'])
|
397 |
+
return cfg
|
398 |
+
elif 'cityscape' in dataset_name:
|
399 |
+
cfg.update(cfg['CITY'])
|
400 |
+
return cfg
|
401 |
+
elif 'bdd' in dataset_name:
|
402 |
+
cfg.update(cfg['BDD'])
|
403 |
+
return cfg
|
404 |
+
elif 'tsv' in dataset_name:
|
405 |
+
cfg.update(cfg['TSV'])
|
406 |
+
return cfg
|
407 |
+
elif 'phrasecut' in dataset_name:
|
408 |
+
cfg.update(cfg['PHRASE'])
|
409 |
+
return cfg
|
410 |
+
elif 'object365' in dataset_name:
|
411 |
+
cfg.update(cfg['OBJECT365'])
|
412 |
+
return cfg
|
413 |
+
elif 'openimage' in dataset_name:
|
414 |
+
cfg.update(cfg['OPENIMAGE'])
|
415 |
+
return cfg
|
416 |
+
elif 'lvis' in dataset_name:
|
417 |
+
cfg.update(cfg['LVIS'])
|
418 |
+
return cfg
|
419 |
+
elif 'seginw' in dataset_name:
|
420 |
+
cfg.update(cfg['SEGINW'])
|
421 |
+
return cfg
|
422 |
+
elif 'sbd' in dataset_name:
|
423 |
+
cfg.update(cfg['SBD'])
|
424 |
+
return cfg
|
425 |
+
elif 'davis' in dataset_name:
|
426 |
+
cfg.update(cfg['DAVIS'])
|
427 |
+
return cfg
|
428 |
+
elif 'med_sam' in dataset_name:
|
429 |
+
cfg.update(cfg['MedSAM'])
|
430 |
+
return cfg
|
431 |
+
elif 'biomed' in dataset_name:
|
432 |
+
cfg.update(cfg['BioMed'])
|
433 |
+
return cfg
|
434 |
+
elif 'sam' in dataset_name:
|
435 |
+
cfg.update(cfg['SAM'])
|
436 |
+
return cfg
|
437 |
+
else:
|
438 |
+
assert False, "dataset not support."
|
439 |
+
|
440 |
+
|
441 |
+
def build_eval_dataloader(cfg, ):
|
442 |
+
dataloaders = []
|
443 |
+
for dataset_name in cfg['DATASETS']['TEST']:
|
444 |
+
cfg = get_config_from_name(cfg, dataset_name)
|
445 |
+
# adjust mapper according to dataset
|
446 |
+
if dataset_name == 'imagenet_val':
|
447 |
+
mapper = ImageNetDatasetMapper(cfg, False)
|
448 |
+
elif dataset_name == 'bdd10k_val_sem_seg':
|
449 |
+
mapper = BDDSemDatasetMapper(cfg, False)
|
450 |
+
elif dataset_name in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017"]:
|
451 |
+
mapper = VLPreDatasetMapper(cfg, False, dataset_name)
|
452 |
+
elif dataset_name in ["scannet_21_val_seg", "scannet_38_val_seg", "scannet_41_val_seg"]:
|
453 |
+
mapper = ScanNetSegDatasetMapper(cfg, False)
|
454 |
+
elif dataset_name in ["scannet_21_panoptic_val", 'bdd10k_40_panoptic_val']:
|
455 |
+
mapper = ScanNetPanoDatasetMapper(cfg, False)
|
456 |
+
elif "pascalvoc_val" in dataset_name:
|
457 |
+
mapper = PascalVOCSegDatasetMapperIX(cfg, False, dataset_name)
|
458 |
+
elif 'sun' in dataset_name:
|
459 |
+
mapper = SunRGBDSegDatasetMapper(cfg, False)
|
460 |
+
elif 'refcoco' in dataset_name:
|
461 |
+
mapper = RefCOCODatasetMapper(cfg, False)
|
462 |
+
elif 'med_sam' in dataset_name:
|
463 |
+
mapper = MedSAMDatasetMapper(cfg, False)
|
464 |
+
elif 'biomed' in dataset_name:
|
465 |
+
mapper = BioMedDatasetMapper(cfg, False)
|
466 |
+
else:
|
467 |
+
mapper = None
|
468 |
+
dataloaders += [build_detection_test_loader(cfg, dataset_name, mapper=mapper)]
|
469 |
+
return dataloaders
|
470 |
+
|
471 |
+
|
472 |
+
def build_train_dataloader(cfg, ):
|
473 |
+
dataset_names = cfg['DATASETS']['TRAIN']
|
474 |
+
|
475 |
+
loaders = {}
|
476 |
+
for dataset_name in dataset_names:
|
477 |
+
cfg = get_config_from_name(cfg, dataset_name)
|
478 |
+
mapper_name = cfg['INPUT']['DATASET_MAPPER_NAME']
|
479 |
+
# Semantic segmentation dataset mapper
|
480 |
+
if mapper_name == "mask_former_semantic":
|
481 |
+
mapper = MaskFormerSemanticDatasetMapper(cfg, True)
|
482 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
483 |
+
# Panoptic segmentation dataset mapper
|
484 |
+
elif mapper_name == "mask_former_panoptic":
|
485 |
+
mapper = MaskFormerPanopticDatasetMapper(cfg, True)
|
486 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
487 |
+
# Instance segmentation dataset mapper
|
488 |
+
elif mapper_name == "mask_former_instance":
|
489 |
+
mapper = MaskFormerInstanceDatasetMapper(cfg, True)
|
490 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
491 |
+
# coco instance segmentation lsj new baseline
|
492 |
+
elif mapper_name == "coco_instance_lsj":
|
493 |
+
mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)
|
494 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
495 |
+
# coco panoptic segmentation lsj new baseline
|
496 |
+
elif mapper_name == "coco_panoptic_lsj":
|
497 |
+
mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True)
|
498 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
499 |
+
elif mapper_name == "vlpretrain":
|
500 |
+
mapper = VLPreDatasetMapper(cfg, True, dataset_name)
|
501 |
+
loaders['vlp'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
502 |
+
elif mapper_name == "refcoco":
|
503 |
+
mapper = RefCOCODatasetMapper(cfg, True)
|
504 |
+
loaders['ref'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
505 |
+
elif mapper_name == "coco_interactive":
|
506 |
+
mapper = COCOPanopticInteractiveDatasetMapper(cfg, True)
|
507 |
+
loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
508 |
+
elif mapper_name == "medsam_interactive":
|
509 |
+
mapper = MedSAMDatasetMapper(cfg, True)
|
510 |
+
loaders['med_sam'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
511 |
+
elif mapper_name == "biomed_interactive":
|
512 |
+
mapper = BioMedDatasetMapper(cfg, True)
|
513 |
+
name_key = dataset_name.split("_")[1]
|
514 |
+
loaders[name_key] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
515 |
+
else:
|
516 |
+
mapper = None
|
517 |
+
loaders[dataset_name] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
|
518 |
+
|
519 |
+
if len(loaders) == 1 or not cfg['LOADER'].get('JOINT', False):
|
520 |
+
return list(loaders.values())[0]
|
521 |
+
else:
|
522 |
+
sample_prob = cfg['LOADER'].get('SAMPLE_PROB', 'prop')
|
523 |
+
mixing_level = cfg['LOADER'].get('MIXING_LEVEL', 1)
|
524 |
+
return JointLoader(loaders, key_dataset=cfg['LOADER'].get('KEY_DATASET', 'coco'), sample_prob=sample_prob, mixing_level=mixing_level)
|
525 |
+
|
526 |
+
|
527 |
+
def build_evaluator(cfg, dataset_name, output_folder=None):
|
528 |
+
"""
|
529 |
+
Create evaluator(s) for a given dataset.
|
530 |
+
This uses the special metadata "evaluator_type" associated with each
|
531 |
+
builtin dataset. For your own dataset, you can simply create an
|
532 |
+
evaluator manually in your script and do not have to worry about the
|
533 |
+
hacky if-else logic here.
|
534 |
+
"""
|
535 |
+
if output_folder is None:
|
536 |
+
output_folder = os.path.join(cfg["SAVE_DIR"], "inference")
|
537 |
+
evaluator_list = []
|
538 |
+
evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
|
539 |
+
|
540 |
+
# semantic segmentation
|
541 |
+
if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
|
542 |
+
evaluator_list.append(
|
543 |
+
SemSegEvaluator(
|
544 |
+
dataset_name,
|
545 |
+
distributed=True,
|
546 |
+
output_dir=output_folder,
|
547 |
+
)
|
548 |
+
)
|
549 |
+
# instance segmentation
|
550 |
+
if evaluator_type == "coco":
|
551 |
+
evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
|
552 |
+
|
553 |
+
cfg_model_decoder_test = cfg["MODEL"]["DECODER"]["TEST"]
|
554 |
+
# panoptic segmentation
|
555 |
+
if evaluator_type in [
|
556 |
+
"coco_panoptic_seg",
|
557 |
+
"ade20k_panoptic_seg",
|
558 |
+
"cityscapes_panoptic_seg",
|
559 |
+
"mapillary_vistas_panoptic_seg",
|
560 |
+
"scannet_panoptic_seg",
|
561 |
+
"bdd_panoptic_pano"
|
562 |
+
]:
|
563 |
+
if cfg_model_decoder_test["PANOPTIC_ON"]:
|
564 |
+
evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
|
565 |
+
# COCO
|
566 |
+
if (evaluator_type == "coco_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]) or evaluator_type == "object365_od":
|
567 |
+
evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
|
568 |
+
if (evaluator_type == "coco_panoptic_seg" and cfg_model_decoder_test["SEMANTIC_ON"]) or evaluator_type == "coco_sem_seg":
|
569 |
+
evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
|
570 |
+
# Mapillary Vistas
|
571 |
+
if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]:
|
572 |
+
evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
|
573 |
+
if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg_model_decoder_test["SEMANTIC_ON"]:
|
574 |
+
evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
|
575 |
+
# Cityscapes
|
576 |
+
if evaluator_type == "cityscapes_instance":
|
577 |
+
assert (
|
578 |
+
torch.cuda.device_count() > comm.get_rank()
|
579 |
+
), "CityscapesEvaluator currently do not work with multiple machines."
|
580 |
+
return CityscapesInstanceEvaluator(dataset_name)
|
581 |
+
if evaluator_type == "cityscapes_sem_seg":
|
582 |
+
assert (
|
583 |
+
torch.cuda.device_count() > comm.get_rank()
|
584 |
+
), "CityscapesEvaluator currently do not work with multiple machines."
|
585 |
+
return CityscapesSemSegEvaluator(dataset_name)
|
586 |
+
if evaluator_type == "cityscapes_panoptic_seg":
|
587 |
+
if cfg_model_decoder_test["SEMANTIC_ON"]:
|
588 |
+
assert (
|
589 |
+
torch.cuda.device_count() > comm.get_rank()
|
590 |
+
), "CityscapesEvaluator currently do not work with multiple machines."
|
591 |
+
evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
|
592 |
+
if cfg_model_decoder_test["INSTANCE_ON"]:
|
593 |
+
assert (
|
594 |
+
torch.cuda.device_count() > comm.get_rank()
|
595 |
+
), "CityscapesEvaluator currently do not work with multiple machines."
|
596 |
+
evaluator_list.append(CityscapesInstanceEvaluator(dataset_name))
|
597 |
+
# ADE20K
|
598 |
+
if evaluator_type == "ade20k_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]:
|
599 |
+
evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
|
600 |
+
# SEGINW
|
601 |
+
if evaluator_type == "seginw" and cfg_model_decoder_test["INSTANCE_ON"]:
|
602 |
+
evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
|
603 |
+
# LVIS
|
604 |
+
if evaluator_type == "lvis":
|
605 |
+
return LVISEvaluator(dataset_name, output_dir=output_folder)
|
606 |
+
# Classification
|
607 |
+
if evaluator_type == "classification":
|
608 |
+
evaluator_list.append(ClassificationEvaluator(dataset_name, output_folder))
|
609 |
+
# Retrieval
|
610 |
+
if evaluator_type in ["retrieval"]:
|
611 |
+
evaluator_list.append(RetrievalEvaluator(dataset_name, output_folder, cfg['MODEL']['DECODER']['RETRIEVAL']['ENSEMBLE']))
|
612 |
+
if evaluator_type == "captioning":
|
613 |
+
evaluator_list.append(CaptioningEvaluator(dataset_name, output_folder, MetadataCatalog.get(dataset_name).gt_json))
|
614 |
+
if evaluator_type in ["grounding_refcoco", "grounding_phrasecut", "grounding_spatial", "grounding_entity"]:
|
615 |
+
evaluator_list.append(GroundingEvaluator(dataset_name))
|
616 |
+
# Interactive
|
617 |
+
if evaluator_type in ["interactive", "interactive_grounding"]:
|
618 |
+
evaluator_list.append(InteractiveEvaluator(dataset_name, output_dir=output_folder, max_clicks=cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'], iou_iter=cfg['STROKE_SAMPLER']['EVAL']['IOU_ITER']))
|
619 |
+
|
620 |
+
if len(evaluator_list) == 0:
|
621 |
+
raise NotImplementedError(
|
622 |
+
"no Evaluator for the dataset {} with the type {}".format(
|
623 |
+
dataset_name, evaluator_type
|
624 |
+
)
|
625 |
+
)
|
626 |
+
elif len(evaluator_list) == 1:
|
627 |
+
return evaluator_list[0]
|
628 |
+
|
629 |
+
|
630 |
+
return DatasetEvaluators(evaluator_list)
|
datasets/dataset_mappers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .biomed_dataset_mapper import BioMedDatasetMapper
|
datasets/dataset_mappers/biomed_dataset_mapper.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
|
3 |
+
import copy
|
4 |
+
import logging
|
5 |
+
import random
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from transformers import AutoTokenizer, LlamaForCausalLM
|
11 |
+
|
12 |
+
from detectron2.data import detection_utils as utils
|
13 |
+
from detectron2.data import transforms as T
|
14 |
+
from detectron2.data.transforms import TransformGen
|
15 |
+
from detectron2.structures import BitMasks, Boxes, Instances, BoxMode
|
16 |
+
from detectron2.structures.boxes import pairwise_iou
|
17 |
+
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
|
18 |
+
from detectron2.data import MetadataCatalog
|
19 |
+
from pycocotools import mask as coco_mask
|
20 |
+
|
21 |
+
from utilities import prompt_engineering
|
22 |
+
from modeling.language import build_tokenizer
|
23 |
+
from modeling.language.misc import text_noun_with_prompt_all
|
24 |
+
from modeling.utils import configurable
|
25 |
+
|
26 |
+
from ..visual_sampler.sampler import build_shape_sampler
|
27 |
+
|
28 |
+
__all__ = ["BioMedDatasetMapper"]
|
29 |
+
|
30 |
+
|
31 |
+
def build_transform_gen(cfg, is_train):
|
32 |
+
"""
|
33 |
+
Create a list of default :class:`Augmentation` from config.
|
34 |
+
Now it includes resizing and flipping.
|
35 |
+
Returns:
|
36 |
+
list[Augmentation]
|
37 |
+
"""
|
38 |
+
assert is_train, "Only support training augmentation"
|
39 |
+
cfg_input = cfg['INPUT']
|
40 |
+
image_size = cfg_input['IMAGE_SIZE']
|
41 |
+
min_scale = cfg_input['MIN_SCALE']
|
42 |
+
max_scale = cfg_input['MAX_SCALE']
|
43 |
+
|
44 |
+
augmentation = []
|
45 |
+
|
46 |
+
if cfg_input['RANDOM_FLIP'] != "none":
|
47 |
+
augmentation.append(
|
48 |
+
T.RandomFlip(
|
49 |
+
horizontal=cfg_input['RANDOM_FLIP'] == "horizontal",
|
50 |
+
vertical=cfg_input['RANDOM_FLIP'] == "vertical",
|
51 |
+
)
|
52 |
+
)
|
53 |
+
|
54 |
+
augmentation.extend([
|
55 |
+
T.ResizeScale(
|
56 |
+
min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
|
57 |
+
),
|
58 |
+
T.FixedSizeCrop(crop_size=(image_size, image_size)),
|
59 |
+
])
|
60 |
+
|
61 |
+
return augmentation
|
62 |
+
|
63 |
+
def build_transform_gen_se(cfg, is_train):
|
64 |
+
# min_scale = cfg['INPUT']['MIN_SIZE_TEST']
|
65 |
+
# max_scale = cfg['INPUT']['MAX_SIZE_TEST']
|
66 |
+
|
67 |
+
augmentation = []
|
68 |
+
# augmentation.extend([
|
69 |
+
# T.ResizeShortestEdge(
|
70 |
+
# min_scale, max_size=max_scale
|
71 |
+
# ),
|
72 |
+
# ])
|
73 |
+
return augmentation
|
74 |
+
|
75 |
+
def convert_coco_poly_to_mask(segmentations, height, width):
|
76 |
+
masks = []
|
77 |
+
for polygons in segmentations:
|
78 |
+
rles = coco_mask.frPyObjects(polygons, height, width)
|
79 |
+
mask = coco_mask.decode(rles)
|
80 |
+
if len(mask.shape) < 3:
|
81 |
+
mask = mask[..., None]
|
82 |
+
mask = torch.as_tensor(mask, dtype=torch.uint8)
|
83 |
+
mask = mask.any(dim=2)
|
84 |
+
masks.append(mask)
|
85 |
+
if masks:
|
86 |
+
masks = torch.stack(masks, dim=0)
|
87 |
+
else:
|
88 |
+
masks = torch.zeros((0, height, width), dtype=torch.uint8)
|
89 |
+
return masks
|
90 |
+
|
91 |
+
# This is specifically designed for the COCO dataset.
|
92 |
+
class BioMedDatasetMapper:
|
93 |
+
"""
|
94 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
95 |
+
and map it into a format used by MaskFormer.
|
96 |
+
|
97 |
+
This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
|
98 |
+
|
99 |
+
The callable currently does the following:
|
100 |
+
|
101 |
+
1. Read the image from "file_name"
|
102 |
+
2. Applies geometric transforms to the image and annotation
|
103 |
+
3. Find and applies suitable cropping to the image and annotation
|
104 |
+
4. Prepare image and annotation to Tensors
|
105 |
+
"""
|
106 |
+
|
107 |
+
@configurable
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
is_train=True,
|
111 |
+
*,
|
112 |
+
tfm_gens,
|
113 |
+
image_format,
|
114 |
+
caption_thres,
|
115 |
+
grounding,
|
116 |
+
lvis,
|
117 |
+
lvis_thres,
|
118 |
+
max_grounding_num,
|
119 |
+
shape_sampler,
|
120 |
+
retrieval,
|
121 |
+
max_token_num,
|
122 |
+
tokenizer,
|
123 |
+
binary_classes: bool,
|
124 |
+
rotate: bool,
|
125 |
+
):
|
126 |
+
"""
|
127 |
+
NOTE: this interface is experimental.
|
128 |
+
Args:
|
129 |
+
is_train: for training or inference
|
130 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
131 |
+
crop_gen: crop augmentation
|
132 |
+
tfm_gens: data augmentation
|
133 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
134 |
+
"""
|
135 |
+
self.tfm_gens = tfm_gens
|
136 |
+
logging.getLogger(__name__).info(
|
137 |
+
"[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
|
138 |
+
str(self.tfm_gens)
|
139 |
+
)
|
140 |
+
)
|
141 |
+
|
142 |
+
self.img_format = image_format
|
143 |
+
self.is_train = is_train
|
144 |
+
self.caption_thres = caption_thres
|
145 |
+
self.grounding = grounding
|
146 |
+
self.lvis = lvis
|
147 |
+
self.lvis_thres = lvis_thres
|
148 |
+
self.max_grounding_num = max_grounding_num
|
149 |
+
|
150 |
+
self.shape_sampler = shape_sampler
|
151 |
+
|
152 |
+
self.retrieval = retrieval
|
153 |
+
self.tokenizer = tokenizer
|
154 |
+
self.max_token_num = max_token_num
|
155 |
+
|
156 |
+
self.binary_classes = binary_classes
|
157 |
+
self.rotate = rotate
|
158 |
+
|
159 |
+
@classmethod
|
160 |
+
def from_config(cls, cfg, is_train=True):
|
161 |
+
# Build augmentation
|
162 |
+
if is_train:
|
163 |
+
tfm_gens = build_transform_gen(cfg, is_train)
|
164 |
+
else:
|
165 |
+
tfm_gens = build_transform_gen_se(cfg, is_train)
|
166 |
+
|
167 |
+
shape_sampler = build_shape_sampler(cfg)
|
168 |
+
|
169 |
+
retrieval = cfg['MODEL']['DECODER']['RETRIEVAL']['ENABLED']
|
170 |
+
tokenizer, max_token_num = None, None
|
171 |
+
if retrieval:
|
172 |
+
lang_model = cfg['MODEL']['TEXT']['NAME']
|
173 |
+
max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
|
174 |
+
if 'llama' in lang_model:
|
175 |
+
tokenizer = AutoTokenizer.from_pretrained(lang_model, padding_side='right')
|
176 |
+
tokenizer.pad_token = tokenizer.eos_token
|
177 |
+
else:
|
178 |
+
tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
|
179 |
+
|
180 |
+
ret = {
|
181 |
+
"is_train": is_train,
|
182 |
+
"tfm_gens": tfm_gens,
|
183 |
+
"image_format": cfg['INPUT']['FORMAT'],
|
184 |
+
"caption_thres": cfg['MODEL']['DECODER']['CAPTION']['SIM_THRES'],
|
185 |
+
"grounding": cfg['MODEL']['DECODER']['GROUNDING']['ENABLED'],
|
186 |
+
"lvis": cfg['MODEL']['DECODER']['LVIS']['ENABLED'],
|
187 |
+
"lvis_thres": cfg['MODEL']['DECODER']['LVIS']['THRES'],
|
188 |
+
"max_grounding_num": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'],
|
189 |
+
"shape_sampler": shape_sampler,
|
190 |
+
"retrieval": retrieval,
|
191 |
+
"max_token_num": max_token_num,
|
192 |
+
"tokenizer": tokenizer,
|
193 |
+
"binary_classes": cfg['MODEL']['ENCODER']['BINARY_CLASSES'],
|
194 |
+
"rotate": cfg['INPUT']['RANDOM_ROTATE'],
|
195 |
+
}
|
196 |
+
return ret
|
197 |
+
|
198 |
+
def __call__(self, dataset_dict):
|
199 |
+
"""
|
200 |
+
Args:
|
201 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
dict: a format that builtin models in detectron2 accept
|
205 |
+
"""
|
206 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
207 |
+
while True:
|
208 |
+
try:
|
209 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
210 |
+
break
|
211 |
+
except:
|
212 |
+
print('Image loading error:', dataset_dict["file_name"])
|
213 |
+
|
214 |
+
utils.check_image_size(dataset_dict, image)
|
215 |
+
|
216 |
+
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
|
217 |
+
image_shape = image.shape[:2] # h, w
|
218 |
+
|
219 |
+
rotate_time = 0
|
220 |
+
if self.is_train and self.rotate and random.random() < 0.5:
|
221 |
+
rotate_time = random.randint(1, 3)
|
222 |
+
if rotate_time > 0:
|
223 |
+
image = np.rot90(image, rotate_time)
|
224 |
+
|
225 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
226 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
227 |
+
# Therefore it's important to use torch.Tensor.
|
228 |
+
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
229 |
+
|
230 |
+
|
231 |
+
grounding_anno = dataset_dict['grounding_info']
|
232 |
+
if len(grounding_anno) == 0:
|
233 |
+
print(dataset_dict['file_name'])
|
234 |
+
assert len(grounding_anno) > 0
|
235 |
+
masks_grd = []
|
236 |
+
texts_grd = []
|
237 |
+
boxes_grd = []
|
238 |
+
hash_grd = []
|
239 |
+
classes = []
|
240 |
+
masks_orig = []
|
241 |
+
for ann in grounding_anno:
|
242 |
+
if 'segmentation' in ann:
|
243 |
+
if len(ann['segmentation']) == 0:
|
244 |
+
print('Empty segmentation!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
245 |
+
continue
|
246 |
+
rle = coco_mask.frPyObjects(
|
247 |
+
ann['segmentation'], dataset_dict['height'], dataset_dict['width'])
|
248 |
+
m = coco_mask.decode(rle)
|
249 |
+
masks_orig.append(m)
|
250 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
251 |
+
m = np.sum(m, axis=2)
|
252 |
+
else:
|
253 |
+
# directly read from mask file
|
254 |
+
while True:
|
255 |
+
try:
|
256 |
+
m = utils.read_image(ann["mask_file"], format=self.img_format)
|
257 |
+
break
|
258 |
+
except:
|
259 |
+
print('Image loading error:', ann["mask_file"])
|
260 |
+
m = np.sum(m, axis=2)
|
261 |
+
m = 1 * (m > 0)
|
262 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
263 |
+
m = transforms.apply_segmentation(255*m[:,:,None])[:,:,0]
|
264 |
+
if rotate_time > 0:
|
265 |
+
m = np.rot90(m, rotate_time)
|
266 |
+
masks_grd += [m]
|
267 |
+
rand_id = random.randint(0, len(ann['sentences'])-1)
|
268 |
+
texts_grd.append(ann['sentences'][rand_id]['raw'].lower())
|
269 |
+
hash_grd.append(hash(ann['sentences'][rand_id]['raw'].lower()))
|
270 |
+
if self.binary_classes:
|
271 |
+
ann["category_id"] = 1 * (ann["category_id"] > 0)
|
272 |
+
classes.append(ann["category_id"])
|
273 |
+
#masks_grd = torch.from_numpy(np.stack(masks_grd))
|
274 |
+
boxes_grd = torch.tensor(boxes_grd)
|
275 |
+
groundings = {'masks': masks_grd, 'texts': texts_grd, 'hash': hash_grd, 'mode': 'text'}
|
276 |
+
dataset_dict["groundings"] = groundings
|
277 |
+
|
278 |
+
masks_grd = torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks_grd])
|
279 |
+
|
280 |
+
instances = Instances(image_shape)
|
281 |
+
|
282 |
+
instances.gt_masks = BitMasks(masks_grd)
|
283 |
+
instances.gt_boxes = BitMasks(masks_grd).get_bounding_boxes()
|
284 |
+
|
285 |
+
classes = np.array(classes)
|
286 |
+
is_things = np.array([1 for _ in classes])
|
287 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
288 |
+
instances.is_things = torch.tensor(is_things, dtype=torch.int64)
|
289 |
+
|
290 |
+
dataset_dict["instances"] = instances
|
291 |
+
|
292 |
+
|
293 |
+
spatial_query_utils = self.shape_sampler(instances)
|
294 |
+
dataset_dict['spatial_query'] = spatial_query_utils
|
295 |
+
|
296 |
+
if self.retrieval:
|
297 |
+
captions = dataset_dict['captions']
|
298 |
+
tokens = self.tokenizer(
|
299 |
+
captions, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
|
300 |
+
)
|
301 |
+
dataset_dict['tokens'] = {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}
|
302 |
+
|
303 |
+
if self.grounding:
|
304 |
+
grounding_anno = dataset_dict['grounding_info']
|
305 |
+
grounding_len = random.randint(1, self.max_grounding_num-1)
|
306 |
+
if len(grounding_anno) > 0:
|
307 |
+
masks_grd = []
|
308 |
+
texts_grd = []
|
309 |
+
mode = 'text'
|
310 |
+
random.shuffle(grounding_anno)
|
311 |
+
for ann in grounding_anno:
|
312 |
+
if 'segmentation' in ann:
|
313 |
+
if len(ann['segmentation']) == 0:
|
314 |
+
print('Empty segmentation!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
315 |
+
continue
|
316 |
+
rle = coco_mask.frPyObjects(
|
317 |
+
ann['segmentation'], dataset_dict['height'], dataset_dict['width'])
|
318 |
+
m = coco_mask.decode(rle)
|
319 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
320 |
+
m = np.sum(m, axis=2)
|
321 |
+
else:
|
322 |
+
# directly read from mask file
|
323 |
+
while True:
|
324 |
+
try:
|
325 |
+
m = utils.read_image(ann["mask_file"], format=self.img_format)
|
326 |
+
break
|
327 |
+
except:
|
328 |
+
print('Image loading error:', ann["mask_file"])
|
329 |
+
m = np.sum(m, axis=2)
|
330 |
+
m = 1 * (m > 0)
|
331 |
+
|
332 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
333 |
+
m = transforms.apply_segmentation(m[:,:,None])[:,:,0]
|
334 |
+
if rotate_time > 0:
|
335 |
+
m = np.rot90(m, rotate_time)
|
336 |
+
masks_grd += [m]
|
337 |
+
# random select a sentence of a single annotation.
|
338 |
+
rand_index = random.randint(0, len(ann['sentences'])-1)
|
339 |
+
texts_grd += [ann['sentences'][rand_index]['raw'].lower()]
|
340 |
+
# max_len = min(grounding_len, len(texts_grd))
|
341 |
+
max_len = len(masks_grd)
|
342 |
+
indices = np.random.permutation(max_len)
|
343 |
+
texts_grd = list(np.array(texts_grd)[indices])
|
344 |
+
masks_grd = torch.tensor(np.stack(masks_grd)[indices])
|
345 |
+
hash_grd = np.array([hash(txt) for txt in texts_grd])
|
346 |
+
else:
|
347 |
+
masks_grd = instances.gt_masks.tensor
|
348 |
+
mode = 'class'
|
349 |
+
if len(masks_grd) == 0:
|
350 |
+
masks_grd = torch.tensor([])
|
351 |
+
texts_grd = ['none']
|
352 |
+
hash_grd = np.array([hash(txt) for txt in texts_grd])
|
353 |
+
else:
|
354 |
+
biomed_classes = ['liver', 'lung', 'kidney', 'pancreas', 'heart anatomies', 'brain anatomies',
|
355 |
+
'eye anatomies', 'vessel', 'other organ', 'tumor', 'infection', 'other lesion',
|
356 |
+
'fluid disturbance', 'other abnormality', 'histology structure', 'other']
|
357 |
+
if self.binary_classes:
|
358 |
+
biomed_classes = ['target']
|
359 |
+
texts_grd = np.array(biomed_classes)
|
360 |
+
hash_grd = np.array([hash(txt) for txt in texts_grd])
|
361 |
+
unique_hash_grd = np.unique(hash_grd)
|
362 |
+
np.random.shuffle(unique_hash_grd)
|
363 |
+
max_len = min(grounding_len, len(unique_hash_grd))
|
364 |
+
indices = np.random.permutation(max_len)
|
365 |
+
selected_unique_hash_grd = unique_hash_grd[indices]
|
366 |
+
selected_mask = np.in1d(hash_grd, selected_unique_hash_grd)
|
367 |
+
texts_grd = texts_grd[selected_mask]
|
368 |
+
hash_grd = hash_grd[selected_mask]
|
369 |
+
masks_grd = masks_grd[selected_mask]
|
370 |
+
texts_grd = [prompt_engineering(text.replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
|
371 |
+
for text in texts_grd]
|
372 |
+
groundings = {'masks': masks_grd, 'texts': texts_grd, 'mode': mode, 'hash': hash_grd}
|
373 |
+
dataset_dict["groundings"] = groundings
|
374 |
+
assert len(masks_grd) == len(dataset_dict['grounding_info']), f"len(masks_grd)={len(masks_grd)}, len(dataset_dict['grounding_info'])={len(dataset_dict['grounding_info'])}, mask shape={masks_grd.shape}, max_len={max_len}, grounding_len={grounding_len}, len(texts_grd)={len(texts_grd)}, len(hash_grd)={len(hash_grd)}"
|
375 |
+
# gt_masks_orisize = torch.stack([torch.from_numpy(m.squeeze(-1)) for m in masks_orig])
|
376 |
+
# dataset_dict['gt_masks_orisize'] = gt_masks_orisize # (nm,h,w)
|
377 |
+
|
378 |
+
return dataset_dict
|
datasets/evaluation/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .instance_evaluation import *
|
2 |
+
from .classification_evaluation import *
|
3 |
+
from .segmentation_evaluation import *
|
4 |
+
from .retrieval_evaluation import *
|
5 |
+
#from .captioning_evaluation import *
|
6 |
+
from .panoptic_evaluation import *
|
7 |
+
from .grounding_evaluation import *
|
8 |
+
from .interactive_evaluation import *
|
datasets/evaluation/captioning_evaluation.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# --------------------------------------------------------
|
3 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Modified by Xueyan Zou ([email protected])
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import os
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
import itertools
|
13 |
+
|
14 |
+
import detectron2.utils.comm as comm
|
15 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
16 |
+
|
17 |
+
from caption_pycocotools.coco import COCO
|
18 |
+
from pycocoevalcap.eval import COCOEvalCap
|
19 |
+
|
20 |
+
|
21 |
+
class CaptioningEvaluator(DatasetEvaluator):
|
22 |
+
"""
|
23 |
+
Evaluate AR for object proposals, AP for instance detection/segmentation, AP
|
24 |
+
for keypoint detection outputs using COCO's metrics.
|
25 |
+
See http://cocodataset.org/#detection-eval and
|
26 |
+
http://cocodataset.org/#keypoints-eval to understand its metrics.
|
27 |
+
The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
|
28 |
+
the metric cannot be computed (e.g. due to no predictions made).
|
29 |
+
In addition to COCO, this evaluator is able to support any bounding box detection,
|
30 |
+
instance segmentation, or keypoint detection dataset.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
distributed=True,
|
36 |
+
output_dir=None,
|
37 |
+
gt_json=None,
|
38 |
+
):
|
39 |
+
"""
|
40 |
+
Args:
|
41 |
+
dataset_name (str): name of the dataset to be evaluated.
|
42 |
+
It must have either the following corresponding metadata:
|
43 |
+
"json_file": the path to the COCO format annotation
|
44 |
+
Or it must be in detectron2's standard dataset format
|
45 |
+
so it can be converted to COCO format automatically.
|
46 |
+
tasks (tuple[str]): tasks that can be evaluated under the given
|
47 |
+
configuration. A task is one of "bbox", "segm", "keypoints".
|
48 |
+
By default, will infer this automatically from predictions.
|
49 |
+
distributed (True): if True, will collect results from all ranks and run evaluation
|
50 |
+
in the main process.
|
51 |
+
Otherwise, will only evaluate the results in the current process.
|
52 |
+
output_dir (str): optional, an output directory to dump all
|
53 |
+
results predicted on the dataset. The dump contains two files:
|
54 |
+
1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
|
55 |
+
contains all the results in the format they are produced by the model.
|
56 |
+
2. "coco_instances_results.json" a json file in COCO's result format.
|
57 |
+
max_dets_per_image (int): limit on the maximum number of detections per image.
|
58 |
+
By default in COCO, this limit is to 100, but this can be customized
|
59 |
+
to be greater, as is needed in evaluation metrics AP fixed and AP pool
|
60 |
+
(see https://arxiv.org/pdf/2102.01066.pdf)
|
61 |
+
This doesn't affect keypoint evaluation.
|
62 |
+
use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
|
63 |
+
Although the results should be very close to the official implementation in COCO
|
64 |
+
API, it is still recommended to compute results with the official API for use in
|
65 |
+
papers. The faster implementation also uses more RAM.
|
66 |
+
kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
|
67 |
+
See http://cocodataset.org/#keypoints-eval
|
68 |
+
When empty, it will use the defaults in COCO.
|
69 |
+
Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
|
70 |
+
allow_cached_coco (bool): Whether to use cached coco json from previous validation
|
71 |
+
runs. You should set this to False if you need to use different validation data.
|
72 |
+
Defaults to True.
|
73 |
+
"""
|
74 |
+
self._logger = logging.getLogger(__name__)
|
75 |
+
self._distributed = distributed
|
76 |
+
self._output_dir = output_dir
|
77 |
+
self._gt_json = COCO(gt_json)
|
78 |
+
|
79 |
+
def reset(self):
|
80 |
+
self._gen_captions = []
|
81 |
+
self._image_ids = []
|
82 |
+
|
83 |
+
def process(self, inputs, outputs):
|
84 |
+
"""
|
85 |
+
Args:
|
86 |
+
inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
|
87 |
+
It is a list of dict. Each dict corresponds to an image and
|
88 |
+
contains keys like "height", "width", "file_name", "image_id".
|
89 |
+
outputs: the outputs of a COCO model. It is a list of dicts with key
|
90 |
+
"instances" that contains :class:`Instances`.
|
91 |
+
"""
|
92 |
+
for output in outputs:
|
93 |
+
self._image_ids.append(output['image_id'])
|
94 |
+
self._gen_captions.append(output['captioning_text'])
|
95 |
+
|
96 |
+
def evaluate(self, img_ids=None):
|
97 |
+
"""
|
98 |
+
Args:
|
99 |
+
img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
|
100 |
+
"""
|
101 |
+
|
102 |
+
if self._distributed:
|
103 |
+
comm.synchronize()
|
104 |
+
def gather(x, move=False):
|
105 |
+
x = comm.gather(x)
|
106 |
+
x = list(itertools.chain(*x))
|
107 |
+
if move:
|
108 |
+
x = [xx.to(self._gen_captions[0].device) for xx in x]
|
109 |
+
return x
|
110 |
+
gen_captions = gather(self._gen_captions)
|
111 |
+
image_ids = gather(self._image_ids)
|
112 |
+
if not comm.is_main_process():
|
113 |
+
return {}
|
114 |
+
else:
|
115 |
+
gen_captions = self._gen_captions
|
116 |
+
image_ids = self._image_ids
|
117 |
+
|
118 |
+
assert len(gen_captions) == len(image_ids)
|
119 |
+
pred_captions = [{"image_id": image_id, "caption": gen_caption} for image_id, gen_caption in zip(image_ids, gen_captions)]
|
120 |
+
pred_pth = os.path.join(self._output_dir, 'results.json')
|
121 |
+
json.dump(pred_captions, open(pred_pth, "w"))
|
122 |
+
|
123 |
+
gt_captions = self._gt_json
|
124 |
+
pred_captions = gt_captions.loadRes(pred_pth)
|
125 |
+
|
126 |
+
cocoEval = COCOEvalCap(gt_captions, pred_captions)
|
127 |
+
cocoEval.params['image_id'] = pred_captions.getImgIds()
|
128 |
+
cocoEval.evaluate()
|
129 |
+
return cocoEval.eval
|
datasets/evaluation/classification_evaluation.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# --------------------------------------------------------
|
3 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Modified by Xueyan Zou ([email protected])
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import logging
|
11 |
+
|
12 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
13 |
+
|
14 |
+
from utilities.misc import AverageMeter
|
15 |
+
from utilities.distributed import get_world_size
|
16 |
+
|
17 |
+
|
18 |
+
@torch.no_grad()
|
19 |
+
def accuracy(output, target, topk=(1,)):
|
20 |
+
"""Computes the precision@k for the specified values of k"""
|
21 |
+
if isinstance(output, list):
|
22 |
+
output = output[-1]
|
23 |
+
|
24 |
+
n_classes = output.size()[1]
|
25 |
+
maxk = min(max(topk), n_classes)
|
26 |
+
batch_size = target.size(0)
|
27 |
+
_, pred = output.topk(maxk, 1, True, True)
|
28 |
+
pred = pred.t()
|
29 |
+
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
30 |
+
|
31 |
+
res = []
|
32 |
+
for k in topk:
|
33 |
+
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
34 |
+
res.append(correct_k.mul_(100.0 / batch_size).item())
|
35 |
+
return res
|
36 |
+
|
37 |
+
class ClassificationEvaluator(DatasetEvaluator):
|
38 |
+
def __init__(self, *args):
|
39 |
+
self.top1 = AverageMeter()
|
40 |
+
self.top5 = AverageMeter()
|
41 |
+
self._logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
def reset(self):
|
44 |
+
self.top1.reset()
|
45 |
+
self.top5.reset()
|
46 |
+
|
47 |
+
def process(self, inputs, outputs):
|
48 |
+
logits = torch.stack([o['pred_class'] for o in outputs])
|
49 |
+
y = torch.tensor([t['class_id'] for t in inputs], device=logits.device)
|
50 |
+
prec1, prec5 = accuracy(logits, y, (1, 5))
|
51 |
+
self.top1.update(prec1, y.size(0))
|
52 |
+
self.top5.update(prec5, y.size(0))
|
53 |
+
|
54 |
+
def evaluate(self):
|
55 |
+
if get_world_size() > 1:
|
56 |
+
tmp_tensor = torch.tensor(
|
57 |
+
[self.top1.sum, self.top5.sum, self.top1.count],
|
58 |
+
device=torch.cuda.current_device()
|
59 |
+
)
|
60 |
+
torch.distributed.all_reduce(
|
61 |
+
tmp_tensor, torch.distributed.ReduceOp.SUM
|
62 |
+
)
|
63 |
+
top1_sum, top5_sum, count = tmp_tensor.tolist()
|
64 |
+
else:
|
65 |
+
top1_sum = self.top1.sum
|
66 |
+
top5_sum = self.top5.sum
|
67 |
+
count = self.top1.count
|
68 |
+
|
69 |
+
results = {}
|
70 |
+
scores = {
|
71 |
+
'top1': top1_sum / count,
|
72 |
+
"top5": top5_sum / count
|
73 |
+
}
|
74 |
+
results['class'] = scores
|
75 |
+
self._logger.info(results)
|
76 |
+
return results
|
datasets/evaluation/grounding_evaluation.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Modified by Xueyan Zou ([email protected])
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import logging
|
8 |
+
import torch
|
9 |
+
from torchvision.ops import box_iou
|
10 |
+
|
11 |
+
from detectron2.structures import BoxMode
|
12 |
+
from detectron2.data import MetadataCatalog
|
13 |
+
from detectron2.utils.comm import all_gather, is_main_process, synchronize
|
14 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
15 |
+
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
import numpy as np
|
18 |
+
import os
|
19 |
+
|
20 |
+
import copy
|
21 |
+
|
22 |
+
class GroundingEvaluator(DatasetEvaluator):
|
23 |
+
"""
|
24 |
+
Evaluate grounding segmentation metrics.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
dataset_name,
|
30 |
+
compute_box=False,
|
31 |
+
distributed=True,
|
32 |
+
):
|
33 |
+
self._logger = logging.getLogger(__name__)
|
34 |
+
self._dataset_name = dataset_name
|
35 |
+
self._distributed = distributed
|
36 |
+
self._cpu_device = torch.device("cpu")
|
37 |
+
self._compute_box = compute_box
|
38 |
+
meta = MetadataCatalog.get(dataset_name)
|
39 |
+
|
40 |
+
def reset(self):
|
41 |
+
self.cum_I = 0
|
42 |
+
self.cum_U = 0
|
43 |
+
self.mIoU = 0
|
44 |
+
self.mDice = 0
|
45 |
+
self.cum_mean_area = 0
|
46 |
+
self.eval_seg_iou_list = [.5, .6, .7, .8, .9]
|
47 |
+
self.seg_correct = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
|
48 |
+
self.seg_total = 0
|
49 |
+
self.instance_results = []
|
50 |
+
if self._compute_box:
|
51 |
+
self.mIoU_box = 0
|
52 |
+
self.seg_correct_box = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def computeIoU(pred_seg, gd_seg):
|
56 |
+
I = (pred_seg & gd_seg)
|
57 |
+
U = (pred_seg | gd_seg)
|
58 |
+
return I, U
|
59 |
+
|
60 |
+
def get_metadata(self, _input):
|
61 |
+
"""
|
62 |
+
Extracts and returns specific metadata from the input dictionary.
|
63 |
+
|
64 |
+
Parameters:
|
65 |
+
_input (dict): A dictionary containing keys like 'file_name', 'image_id', and 'grounding_info'.
|
66 |
+
The 'grounding_info' is a list of dictionaries with keys like 'area', 'iscrowd', etc.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
dict: A dictionary containing filtered metadata.
|
70 |
+
"""
|
71 |
+
|
72 |
+
_input = copy.deepcopy(_input)
|
73 |
+
|
74 |
+
selected_input_keys = ['file_name', 'image_id', 'grounding_info']
|
75 |
+
selected_grounding_info_keys = ['area', 'mask_file', 'iscrowd', 'image_id', 'category_id', 'id', 'file_name', 'split', 'ann_id', 'ref_id']
|
76 |
+
|
77 |
+
filtered_input = {key: _input[key] for key in selected_input_keys if key in _input}
|
78 |
+
|
79 |
+
# Check if grounding_info is present and is a list
|
80 |
+
if 'grounding_info' in filtered_input and isinstance(filtered_input['grounding_info'], list):
|
81 |
+
# Filter each grounding_info dictionary
|
82 |
+
filtered_input['grounding_info'] = [
|
83 |
+
{key: info[key] for key in selected_grounding_info_keys if key in info}
|
84 |
+
for info in filtered_input['grounding_info']
|
85 |
+
]
|
86 |
+
|
87 |
+
return filtered_input
|
88 |
+
|
89 |
+
def process(self, inputs, outputs):
|
90 |
+
for input, output in zip(inputs, outputs):
|
91 |
+
pred = output['grounding_mask'].sigmoid() > 0.5
|
92 |
+
# # save pixel probability
|
93 |
+
# prob = output['grounding_mask'].sigmoid().cpu().numpy()[0] * 255
|
94 |
+
# pred_file = input['file_name'].split('.')[0].replace('test/', 'test_pred/') + '_' + input['groundings']['texts'][0].replace(' ', '+') + '.png'
|
95 |
+
# if not os.path.exists('/'.join(pred_file.split('/')[:-1])):
|
96 |
+
# os.makedirs('/'.join(pred_file.split('/')[:-1]), exist_ok=True)
|
97 |
+
# plt.imsave(pred_file,
|
98 |
+
# prob.astype(np.uint8), cmap='gray')
|
99 |
+
|
100 |
+
gt = input['groundings']['masks'].bool()
|
101 |
+
bsi = len(pred)
|
102 |
+
I, U = self.computeIoU(pred, gt)
|
103 |
+
self.cum_I += I.sum().cpu()
|
104 |
+
self.cum_U += U.sum().cpu()
|
105 |
+
IoU = I.reshape(bsi,-1).sum(-1)*1.0 / (U.reshape(bsi,-1).sum(-1) + 1e-6)
|
106 |
+
self.mIoU += IoU.sum().cpu()
|
107 |
+
# Add Dice score in eval
|
108 |
+
Dice = I.reshape(bsi,-1).sum(-1)*2.0 / (gt.reshape(bsi,-1).sum(-1) + pred.reshape(bsi,-1).sum(-1) + 1e-6)
|
109 |
+
self.mDice += Dice.sum().cpu()
|
110 |
+
self.cum_mean_area += ((gt.reshape(bsi,-1).sum(-1) + pred.reshape(bsi,-1).sum(-1)) / 2.0).sum().cpu()
|
111 |
+
|
112 |
+
if self._compute_box:
|
113 |
+
pred_box = BoxMode.convert(output['grounding_box'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
|
114 |
+
gt_box = BoxMode.convert(input['groundings']['boxes'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS).cpu()
|
115 |
+
IoU_box = box_iou(pred_box, gt_box).diagonal()
|
116 |
+
self.mIoU_box += IoU_box.sum()
|
117 |
+
|
118 |
+
for idx in range(len(self.eval_seg_iou_list)):
|
119 |
+
eval_seg_iou = self.eval_seg_iou_list[idx]
|
120 |
+
self.seg_correct[idx] += (IoU >= eval_seg_iou).sum().cpu()
|
121 |
+
if self._compute_box:
|
122 |
+
self.seg_correct_box[idx] += (IoU_box >= eval_seg_iou).sum().cpu()
|
123 |
+
self.seg_total += bsi
|
124 |
+
|
125 |
+
instance_result = {
|
126 |
+
'metadata': self.get_metadata(input),
|
127 |
+
'IoU': IoU.cpu().numpy().tolist(),
|
128 |
+
'Dice': Dice.cpu().numpy().tolist(),
|
129 |
+
'I': I.sum(dim=(1, 2)).cpu().numpy().tolist(),
|
130 |
+
'U': U.sum(dim=(1, 2)).cpu().numpy().tolist(),
|
131 |
+
'IoU_box': IoU_box.cpu().numpy().tolist() if self._compute_box else '',
|
132 |
+
'pred_area': pred.reshape(bsi,-1).sum(-1).cpu().numpy().tolist(),
|
133 |
+
}
|
134 |
+
|
135 |
+
iou_len = IoU.shape[0]
|
136 |
+
grounding_info_len = len(self.get_metadata(input)['grounding_info'])
|
137 |
+
assert iou_len == grounding_info_len, f'Number of IoU scores ({iou_len}) and grounding info ({grounding_info_len}) do not match.'
|
138 |
+
self.instance_results.append(instance_result)
|
139 |
+
|
140 |
+
def evaluate(self):
|
141 |
+
if self._distributed:
|
142 |
+
synchronize()
|
143 |
+
self.cum_I = torch.stack(all_gather(self.cum_I)).sum()
|
144 |
+
self.cum_U = torch.stack(all_gather(self.cum_U)).sum()
|
145 |
+
self.mIoU = torch.stack(all_gather(self.mIoU)).sum()
|
146 |
+
self.mDice = torch.stack(all_gather(self.mDice)).sum()
|
147 |
+
self.cum_mean_area = torch.stack(all_gather(self.cum_mean_area)).sum()
|
148 |
+
self.seg_correct = torch.stack(all_gather(self.seg_correct)).sum(0)
|
149 |
+
self.seg_total = sum(all_gather(self.seg_total))
|
150 |
+
self.instance_results = sum(all_gather(self.instance_results), [])
|
151 |
+
if self._compute_box:
|
152 |
+
self.mIoU_box = torch.stack(all_gather(self.mIoU_box)).sum()
|
153 |
+
self.seg_correct_box = torch.stack(all_gather(self.seg_correct_box)).sum(0)
|
154 |
+
if not is_main_process():
|
155 |
+
return
|
156 |
+
|
157 |
+
results = {}
|
158 |
+
for idx in range(len(self.eval_seg_iou_list)):
|
159 |
+
result_str = 'precision@{}'.format(self.eval_seg_iou_list[idx])
|
160 |
+
results[result_str] = (self.seg_correct[idx]*100 / self.seg_total).item()
|
161 |
+
results['cIoU'] = (self.cum_I*100./self.cum_U).item()
|
162 |
+
results['mIoU'] = (self.mIoU*100./self.seg_total).item()
|
163 |
+
results['cDice'] = (self.cum_I*100./self.cum_mean_area).item()
|
164 |
+
results['mDice'] = (self.mDice*100./self.seg_total).item()
|
165 |
+
|
166 |
+
if self._compute_box:
|
167 |
+
for idx in range(len(self.eval_seg_iou_list)):
|
168 |
+
result_str = 'precisionB@{}'.format(self.eval_seg_iou_list[idx])
|
169 |
+
results[result_str] = (self.seg_correct_box[idx]*100 / self.seg_total).item()
|
170 |
+
results['mBIoU'] = (self.mIoU_box*100./self.seg_total).item()
|
171 |
+
|
172 |
+
self._logger.info(results)
|
173 |
+
return {'grounding': {'scores': results, 'instance_results': self.instance_results}}
|
datasets/evaluation/instance_evaluation.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import contextlib
|
3 |
+
import copy
|
4 |
+
import io
|
5 |
+
import itertools
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import numpy as np
|
9 |
+
import os
|
10 |
+
import pickle
|
11 |
+
from collections import OrderedDict
|
12 |
+
import pycocotools.mask as mask_util
|
13 |
+
import torch
|
14 |
+
from pycocotools.coco import COCO
|
15 |
+
from pycocotools.cocoeval import COCOeval
|
16 |
+
from tabulate import tabulate
|
17 |
+
|
18 |
+
import detectron2.utils.comm as comm
|
19 |
+
from detectron2.config import CfgNode
|
20 |
+
from detectron2.data import MetadataCatalog
|
21 |
+
from detectron2.data.datasets.coco import convert_to_coco_json
|
22 |
+
from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
|
23 |
+
from detectron2.evaluation.fast_eval_api import COCOeval_opt
|
24 |
+
from detectron2.structures import Boxes, BoxMode, pairwise_iou
|
25 |
+
from detectron2.utils.file_io import PathManager
|
26 |
+
from detectron2.utils.logger import create_small_table
|
27 |
+
|
28 |
+
|
29 |
+
# modified from COCOEvaluator for instance segmetnat
|
30 |
+
class InstanceSegEvaluator(COCOEvaluator):
|
31 |
+
"""
|
32 |
+
Evaluate AR for object proposals, AP for instance detection/segmentation, AP
|
33 |
+
for keypoint detection outputs using COCO's metrics.
|
34 |
+
See http://cocodataset.org/#detection-eval and
|
35 |
+
http://cocodataset.org/#keypoints-eval to understand its metrics.
|
36 |
+
The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
|
37 |
+
the metric cannot be computed (e.g. due to no predictions made).
|
38 |
+
|
39 |
+
In addition to COCO, this evaluator is able to support any bounding box detection,
|
40 |
+
instance segmentation, or keypoint detection dataset.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def _eval_predictions(self, predictions, img_ids=None):
|
44 |
+
"""
|
45 |
+
Evaluate predictions. Fill self._results with the metrics of the tasks.
|
46 |
+
"""
|
47 |
+
self._logger.info("Preparing results for COCO format ...")
|
48 |
+
coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
|
49 |
+
tasks = self._tasks or self._tasks_from_predictions(coco_results)
|
50 |
+
|
51 |
+
# unmap the category ids for COCO
|
52 |
+
if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
|
53 |
+
dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
|
54 |
+
# all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
|
55 |
+
# num_classes = len(all_contiguous_ids)
|
56 |
+
# assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
|
57 |
+
|
58 |
+
reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
|
59 |
+
for result in coco_results:
|
60 |
+
category_id = result["category_id"]
|
61 |
+
# assert category_id < num_classes, (
|
62 |
+
# f"A prediction has class={category_id}, "
|
63 |
+
# f"but the dataset only has {num_classes} classes and "
|
64 |
+
# f"predicted class id should be in [0, {num_classes - 1}]."
|
65 |
+
# )
|
66 |
+
assert category_id in reverse_id_mapping, (
|
67 |
+
f"A prediction has class={category_id}, "
|
68 |
+
f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
|
69 |
+
)
|
70 |
+
result["category_id"] = reverse_id_mapping[category_id]
|
71 |
+
|
72 |
+
if self._output_dir:
|
73 |
+
file_path = os.path.join(self._output_dir, "coco_instances_results.json")
|
74 |
+
self._logger.info("Saving results to {}".format(file_path))
|
75 |
+
with PathManager.open(file_path, "w") as f:
|
76 |
+
f.write(json.dumps(coco_results))
|
77 |
+
f.flush()
|
78 |
+
|
79 |
+
if not self._do_evaluation:
|
80 |
+
self._logger.info("Annotations are not available for evaluation.")
|
81 |
+
return
|
82 |
+
|
83 |
+
self._logger.info(
|
84 |
+
"Evaluating predictions with {} COCO API...".format(
|
85 |
+
"unofficial" if self._use_fast_impl else "official"
|
86 |
+
)
|
87 |
+
)
|
88 |
+
for task in sorted(tasks):
|
89 |
+
assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
|
90 |
+
coco_eval = (
|
91 |
+
_evaluate_predictions_on_coco(
|
92 |
+
self._coco_api,
|
93 |
+
coco_results,
|
94 |
+
task,
|
95 |
+
kpt_oks_sigmas=self._kpt_oks_sigmas,
|
96 |
+
use_fast_impl=self._use_fast_impl,
|
97 |
+
img_ids=img_ids,
|
98 |
+
max_dets_per_image=self._max_dets_per_image,
|
99 |
+
)
|
100 |
+
if len(coco_results) > 0
|
101 |
+
else None # cocoapi does not handle empty results very well
|
102 |
+
)
|
103 |
+
|
104 |
+
res = self._derive_coco_results(
|
105 |
+
coco_eval, task, class_names=self._metadata.get("thing_classes")
|
106 |
+
)
|
107 |
+
self._results[task] = res
|
datasets/evaluation/interactive_evaluation.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torchvision.ops import box_iou
|
8 |
+
|
9 |
+
from detectron2.structures import BoxMode
|
10 |
+
from detectron2.data import MetadataCatalog
|
11 |
+
from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize
|
12 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
13 |
+
|
14 |
+
|
15 |
+
class InteractiveEvaluator(DatasetEvaluator):
|
16 |
+
"""
|
17 |
+
Evaluate point interactive IoU metrics.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
dataset_name,
|
23 |
+
output_dir,
|
24 |
+
max_clicks=20,
|
25 |
+
iou_iter=1,
|
26 |
+
compute_box=False,
|
27 |
+
distributed=True,
|
28 |
+
):
|
29 |
+
self._logger = logging.getLogger(__name__)
|
30 |
+
self._dataset_name = dataset_name
|
31 |
+
self._distributed = distributed
|
32 |
+
self._cpu_device = torch.device("cpu")
|
33 |
+
self._output_dir = output_dir
|
34 |
+
|
35 |
+
self.max_clicks = max_clicks
|
36 |
+
self.iou_iter = iou_iter
|
37 |
+
meta = MetadataCatalog.get(dataset_name)
|
38 |
+
|
39 |
+
def reset(self):
|
40 |
+
self.iou_list = []
|
41 |
+
self.num_samples = 0
|
42 |
+
self.all_ious = [0.5, 0.8, 0.85, 0.9]
|
43 |
+
|
44 |
+
def process(self, inputs, outputs):
|
45 |
+
self.iou_list += [o['mask_iou'] for o in outputs]
|
46 |
+
self.num_samples += len(outputs)
|
47 |
+
|
48 |
+
def compute_noc(self):
|
49 |
+
def _get_noc(iou_arr, iou_thr):
|
50 |
+
vals = iou_arr >= iou_thr
|
51 |
+
return vals.max(dim=0)[1].item() + 1 if vals.any() else self.max_clicks
|
52 |
+
|
53 |
+
noc_list = {}
|
54 |
+
for iou_thr in self.all_ious:
|
55 |
+
scores_arr = [_get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list]
|
56 |
+
noc_list[str(iou_thr)] = scores_arr
|
57 |
+
|
58 |
+
iou_before_max_iter = torch.stack(self.iou_list)[:,self.iou_iter-1]
|
59 |
+
noc_list_sum = {key:sum(value)*1.0 for key, value in noc_list.items()}
|
60 |
+
|
61 |
+
if self._distributed:
|
62 |
+
num_samples = sum(all_gather(self.num_samples))
|
63 |
+
noc_list_sum_gather = all_gather(noc_list_sum)
|
64 |
+
iou_before_max_gather = all_gather(iou_before_max_iter.sum().cpu())
|
65 |
+
|
66 |
+
noc_list_sum = {key: 0 for key in noc_list_sum_gather[0]}
|
67 |
+
for nlg in noc_list_sum_gather:
|
68 |
+
for key, value in nlg.items():
|
69 |
+
noc_list_sum[key] += value
|
70 |
+
|
71 |
+
pred_noc = {}
|
72 |
+
if self._distributed and (not is_main_process()):
|
73 |
+
return pred_noc
|
74 |
+
|
75 |
+
for key, value in noc_list_sum.items():
|
76 |
+
pred_noc[key] = value / num_samples
|
77 |
+
|
78 |
+
pred_noc['iou_max_iter'] = sum([x.item() for x in iou_before_max_gather]) / num_samples
|
79 |
+
return pred_noc
|
80 |
+
|
81 |
+
def evaluate(self):
|
82 |
+
pred_noc = self.compute_noc()
|
83 |
+
|
84 |
+
if self._distributed and (not is_main_process()):
|
85 |
+
return
|
86 |
+
|
87 |
+
def draw_iou_curve(iou_list, save_dir):
|
88 |
+
iou_list = torch.stack(iou_list, dim=0)
|
89 |
+
iou_list = iou_list.mean(dim=0).cpu().numpy()
|
90 |
+
# draw iou curve, with x-axis as number of clicks, y-axis as iou using matplotlib
|
91 |
+
import matplotlib.pyplot as plt
|
92 |
+
plt.figure()
|
93 |
+
plt.plot(range(1, self.max_clicks+1), iou_list)
|
94 |
+
plt.xlabel('Number of clicks')
|
95 |
+
plt.ylabel('IoU')
|
96 |
+
|
97 |
+
|
98 |
+
# create directory if not exist
|
99 |
+
import os
|
100 |
+
output_dir = os.path.join(save_dir, 'iou_by_clicks')
|
101 |
+
if not os.path.exists(output_dir):
|
102 |
+
os.makedirs(output_dir)
|
103 |
+
|
104 |
+
# get current time and format in 10 digits
|
105 |
+
import time
|
106 |
+
current_time = time.time()
|
107 |
+
current_time = int(current_time)
|
108 |
+
current_time = str(current_time)
|
109 |
+
|
110 |
+
# save iou curve
|
111 |
+
plt.savefig(os.path.join(output_dir, '{}.png'.format(current_time)))
|
112 |
+
|
113 |
+
draw_iou_curve(self.iou_list, self._output_dir)
|
114 |
+
results = {}
|
115 |
+
for idx in range(len(self.all_ious)):
|
116 |
+
result_str = 'noc@{}'.format(self.all_ious[idx])
|
117 |
+
results[result_str] = pred_noc[str(self.all_ious[idx])]
|
118 |
+
|
119 |
+
results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter']
|
120 |
+
|
121 |
+
self._logger.info(results)
|
122 |
+
return {'interactive': results}
|
datasets/evaluation/panoptic_evaluation.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import contextlib
|
3 |
+
import io
|
4 |
+
import itertools
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import numpy as np
|
8 |
+
import os
|
9 |
+
import tempfile
|
10 |
+
from collections import OrderedDict
|
11 |
+
from typing import Optional
|
12 |
+
from PIL import Image
|
13 |
+
from tabulate import tabulate
|
14 |
+
|
15 |
+
from detectron2.data import MetadataCatalog
|
16 |
+
from detectron2.utils import comm
|
17 |
+
from detectron2.utils.file_io import PathManager
|
18 |
+
|
19 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
class COCOPanopticEvaluator(DatasetEvaluator):
|
25 |
+
"""
|
26 |
+
Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
|
27 |
+
It saves panoptic segmentation prediction in `output_dir`
|
28 |
+
|
29 |
+
It contains a synchronize call and has to be called from all workers.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, dataset_name: str, output_dir: Optional[str] = None):
|
33 |
+
"""
|
34 |
+
Args:
|
35 |
+
dataset_name: name of the dataset
|
36 |
+
output_dir: output directory to save results for evaluation.
|
37 |
+
"""
|
38 |
+
self._metadata = MetadataCatalog.get(dataset_name)
|
39 |
+
self._thing_contiguous_id_to_dataset_id = {
|
40 |
+
v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
|
41 |
+
}
|
42 |
+
self._stuff_contiguous_id_to_dataset_id = {
|
43 |
+
v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items()
|
44 |
+
}
|
45 |
+
|
46 |
+
self._output_dir = output_dir
|
47 |
+
if self._output_dir is not None:
|
48 |
+
PathManager.mkdirs(self._output_dir)
|
49 |
+
|
50 |
+
def reset(self):
|
51 |
+
self._predictions = []
|
52 |
+
|
53 |
+
def _convert_category_id(self, segment_info):
|
54 |
+
isthing = segment_info.pop("isthing", None)
|
55 |
+
if isthing is None:
|
56 |
+
# the model produces panoptic category id directly. No more conversion needed
|
57 |
+
return segment_info
|
58 |
+
if isthing is True:
|
59 |
+
segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[
|
60 |
+
segment_info["category_id"]
|
61 |
+
]
|
62 |
+
else:
|
63 |
+
segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[
|
64 |
+
segment_info["category_id"]
|
65 |
+
]
|
66 |
+
return segment_info
|
67 |
+
|
68 |
+
def process(self, inputs, outputs):
|
69 |
+
from panopticapi.utils import id2rgb
|
70 |
+
|
71 |
+
for input, output in zip(inputs, outputs):
|
72 |
+
panoptic_img, segments_info = output["panoptic_seg"]
|
73 |
+
panoptic_img = panoptic_img.cpu().numpy()
|
74 |
+
if segments_info is None:
|
75 |
+
# If "segments_info" is None, we assume "panoptic_img" is a
|
76 |
+
# H*W int32 image storing the panoptic_id in the format of
|
77 |
+
# category_id * label_divisor + instance_id. We reserve -1 for
|
78 |
+
# VOID label, and add 1 to panoptic_img since the official
|
79 |
+
# evaluation script uses 0 for VOID label.
|
80 |
+
label_divisor = self._metadata.label_divisor
|
81 |
+
segments_info = []
|
82 |
+
for panoptic_label in np.unique(panoptic_img):
|
83 |
+
if panoptic_label == -1:
|
84 |
+
# VOID region.
|
85 |
+
continue
|
86 |
+
pred_class = panoptic_label // label_divisor
|
87 |
+
isthing = (
|
88 |
+
pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
|
89 |
+
)
|
90 |
+
segments_info.append(
|
91 |
+
{
|
92 |
+
"id": int(panoptic_label) + 1,
|
93 |
+
"category_id": int(pred_class),
|
94 |
+
"isthing": bool(isthing),
|
95 |
+
}
|
96 |
+
)
|
97 |
+
# Official evaluation script uses 0 for VOID label.
|
98 |
+
panoptic_img += 1
|
99 |
+
|
100 |
+
file_name = os.path.basename(input["file_name"])
|
101 |
+
file_name_png = os.path.splitext(file_name)[0] + ".png"
|
102 |
+
with io.BytesIO() as out:
|
103 |
+
Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
|
104 |
+
segments_info = [self._convert_category_id(x) for x in segments_info]
|
105 |
+
self._predictions.append(
|
106 |
+
{
|
107 |
+
"image_id": input["image_id"],
|
108 |
+
"file_name": file_name_png,
|
109 |
+
"png_string": out.getvalue(),
|
110 |
+
"segments_info": segments_info,
|
111 |
+
}
|
112 |
+
)
|
113 |
+
|
114 |
+
def evaluate(self):
|
115 |
+
comm.synchronize()
|
116 |
+
|
117 |
+
self._predictions = comm.gather(self._predictions)
|
118 |
+
self._predictions = list(itertools.chain(*self._predictions))
|
119 |
+
if not comm.is_main_process():
|
120 |
+
return
|
121 |
+
|
122 |
+
# PanopticApi requires local files
|
123 |
+
gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
|
124 |
+
gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)
|
125 |
+
|
126 |
+
with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
|
127 |
+
logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
|
128 |
+
for p in self._predictions:
|
129 |
+
with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
|
130 |
+
f.write(p.pop("png_string"))
|
131 |
+
|
132 |
+
with open(gt_json, "r") as f:
|
133 |
+
json_data = json.load(f)
|
134 |
+
json_data["annotations"] = self._predictions
|
135 |
+
|
136 |
+
output_dir = self._output_dir or pred_dir
|
137 |
+
predictions_json = os.path.join(output_dir, "predictions.json")
|
138 |
+
with PathManager.open(predictions_json, "w") as f:
|
139 |
+
f.write(json.dumps(json_data))
|
140 |
+
|
141 |
+
from panopticapi.evaluation import pq_compute
|
142 |
+
|
143 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
144 |
+
pq_res = pq_compute(
|
145 |
+
gt_json,
|
146 |
+
PathManager.get_local_path(predictions_json),
|
147 |
+
gt_folder=gt_folder,
|
148 |
+
pred_folder=pred_dir,
|
149 |
+
)
|
150 |
+
|
151 |
+
res = {}
|
152 |
+
res["PQ"] = 100 * pq_res["All"]["pq"]
|
153 |
+
res["SQ"] = 100 * pq_res["All"]["sq"]
|
154 |
+
res["RQ"] = 100 * pq_res["All"]["rq"]
|
155 |
+
res["PQ_th"] = 100 * pq_res["Things"]["pq"]
|
156 |
+
res["SQ_th"] = 100 * pq_res["Things"]["sq"]
|
157 |
+
res["RQ_th"] = 100 * pq_res["Things"]["rq"]
|
158 |
+
res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
|
159 |
+
res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
|
160 |
+
res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]
|
161 |
+
|
162 |
+
results = OrderedDict({"panoptic_seg": res})
|
163 |
+
_print_panoptic_results(pq_res)
|
164 |
+
|
165 |
+
return results
|
166 |
+
|
167 |
+
|
168 |
+
def _print_panoptic_results(pq_res):
|
169 |
+
headers = ["", "PQ", "SQ", "RQ", "#categories"]
|
170 |
+
data = []
|
171 |
+
for name in ["All", "Things", "Stuff"]:
|
172 |
+
row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
|
173 |
+
data.append(row)
|
174 |
+
table = tabulate(
|
175 |
+
data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
|
176 |
+
)
|
177 |
+
logger.info("Panoptic Evaluation Results:\n" + table)
|
178 |
+
|
179 |
+
|
180 |
+
if __name__ == "__main__":
|
181 |
+
from detectron2.utils.logger import setup_logger
|
182 |
+
|
183 |
+
logger = setup_logger()
|
184 |
+
import argparse
|
185 |
+
|
186 |
+
parser = argparse.ArgumentParser()
|
187 |
+
parser.add_argument("--gt-json")
|
188 |
+
parser.add_argument("--gt-dir")
|
189 |
+
parser.add_argument("--pred-json")
|
190 |
+
parser.add_argument("--pred-dir")
|
191 |
+
args = parser.parse_args()
|
192 |
+
|
193 |
+
from panopticapi.evaluation import pq_compute
|
194 |
+
|
195 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
196 |
+
pq_res = pq_compute(
|
197 |
+
args.gt_json, args.pred_json, gt_folder=args.gt_dir, pred_folder=args.pred_dir
|
198 |
+
)
|
199 |
+
_print_panoptic_results(pq_res)
|
datasets/evaluation/retrieval_evaluation.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Modified by Xueyan Zou ([email protected]), Ziyi Dou ([email protected])
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import copy
|
8 |
+
import itertools
|
9 |
+
import logging
|
10 |
+
from collections import OrderedDict
|
11 |
+
import torch
|
12 |
+
from pycocotools.cocoeval import COCOeval
|
13 |
+
|
14 |
+
import detectron2.utils.comm as comm
|
15 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
16 |
+
|
17 |
+
try:
|
18 |
+
from detectron2.evaluation.fast_eval_api import COCOeval_opt
|
19 |
+
except ImportError:
|
20 |
+
COCOeval_opt = COCOeval
|
21 |
+
|
22 |
+
|
23 |
+
class RetrievalEvaluator(DatasetEvaluator):
|
24 |
+
"""
|
25 |
+
Evaluate AR for object proposals, AP for instance detection/segmentation, AP
|
26 |
+
for keypoint detection outputs using COCO's metrics.
|
27 |
+
See http://cocodataset.org/#detection-eval and
|
28 |
+
http://cocodataset.org/#keypoints-eval to understand its metrics.
|
29 |
+
The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
|
30 |
+
the metric cannot be computed (e.g. due to no predictions made).
|
31 |
+
In addition to COCO, this evaluator is able to support any bounding box detection,
|
32 |
+
instance segmentation, or keypoint detection dataset.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
dataset_name=None,
|
38 |
+
output_dir=None,
|
39 |
+
ensemble=False,
|
40 |
+
distributed=True,
|
41 |
+
):
|
42 |
+
"""
|
43 |
+
Args:
|
44 |
+
dataset_name (str): name of the dataset to be evaluated.
|
45 |
+
It must have either the following corresponding metadata:
|
46 |
+
"json_file": the path to the COCO format annotation
|
47 |
+
Or it must be in detectron2's standard dataset format
|
48 |
+
so it can be converted to COCO format automatically.
|
49 |
+
tasks (tuple[str]): tasks that can be evaluated under the given
|
50 |
+
configuration. A task is one of "bbox", "segm", "keypoints".
|
51 |
+
By default, will infer this automatically from predictions.
|
52 |
+
distributed (True): if True, will collect results from all ranks and run evaluation
|
53 |
+
in the main process.
|
54 |
+
Otherwise, will only evaluate the results in the current process.
|
55 |
+
output_dir (str): optional, an output directory to dump all
|
56 |
+
results predicted on the dataset. The dump contains two files:
|
57 |
+
1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
|
58 |
+
contains all the results in the format they are produced by the model.
|
59 |
+
2. "coco_instances_results.json" a json file in COCO's result format.
|
60 |
+
max_dets_per_image (int): limit on the maximum number of detections per image.
|
61 |
+
By default in COCO, this limit is to 100, but this can be customized
|
62 |
+
to be greater, as is needed in evaluation metrics AP fixed and AP pool
|
63 |
+
(see https://arxiv.org/pdf/2102.01066.pdf)
|
64 |
+
This doesn't affect keypoint evaluation.
|
65 |
+
use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
|
66 |
+
Although the results should be very close to the official implementation in COCO
|
67 |
+
API, it is still recommended to compute results with the official API for use in
|
68 |
+
papers. The faster implementation also uses more RAM.
|
69 |
+
kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
|
70 |
+
See http://cocodataset.org/#keypoints-eval
|
71 |
+
When empty, it will use the defaults in COCO.
|
72 |
+
Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
|
73 |
+
allow_cached_coco (bool): Whether to use cached coco json from previous validation
|
74 |
+
runs. You should set this to False if you need to use different validation data.
|
75 |
+
Defaults to True.
|
76 |
+
"""
|
77 |
+
self._logger = logging.getLogger(__name__)
|
78 |
+
self._dataset_name = dataset_name
|
79 |
+
self._output_dir = output_dir
|
80 |
+
self._ensemble = ensemble
|
81 |
+
self._distributed = distributed
|
82 |
+
|
83 |
+
if 'p2i' in dataset_name:
|
84 |
+
self.mode = 'patch2image'
|
85 |
+
elif 'interactive2i' in dataset_name:
|
86 |
+
self.mode = 'interactive2image'
|
87 |
+
else:
|
88 |
+
self.mode = 'default'
|
89 |
+
|
90 |
+
def reset(self):
|
91 |
+
self._text_embeds = []
|
92 |
+
self._image_embeds = []
|
93 |
+
self._image_embeds2 = []
|
94 |
+
self._text_ids = []
|
95 |
+
self._image_ids = []
|
96 |
+
|
97 |
+
def process(self, inputs, outputs):
|
98 |
+
"""
|
99 |
+
Args:
|
100 |
+
inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
|
101 |
+
It is a list of dict. Each dict corresponds to an image and
|
102 |
+
contains keys like "height", "width", "file_name", "image_id".
|
103 |
+
outputs: the outputs of a COCO model. It is a list of dicts with key
|
104 |
+
"instances" that contains :class:`Instances`.
|
105 |
+
"""
|
106 |
+
for output in outputs:
|
107 |
+
self._text_ids.extend(output['caption']['caption_ids'])
|
108 |
+
self._image_ids.append(output['caption']['image_ids'])
|
109 |
+
self._text_embeds.append(output['caption']['text_embeds'])
|
110 |
+
self._image_embeds.append(output['caption']['image_embeds'][0])
|
111 |
+
if self._ensemble:
|
112 |
+
self._image_embeds2.append(output['caption']['image_embeds'][1])
|
113 |
+
|
114 |
+
def evaluate(self, img_ids=None):
|
115 |
+
if self.mode == 'default':
|
116 |
+
return self.evaluate_default(img_ids)
|
117 |
+
elif self.mode in ['patch2image', 'interactive2image']:
|
118 |
+
return self.evaluate_p2i(img_ids)
|
119 |
+
else:
|
120 |
+
assert False, "Unknown mode for retrieval evaluation"
|
121 |
+
|
122 |
+
def evaluate_default(self, img_ids=None):
|
123 |
+
"""
|
124 |
+
Args:
|
125 |
+
img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
|
126 |
+
"""
|
127 |
+
|
128 |
+
if self._distributed:
|
129 |
+
comm.synchronize()
|
130 |
+
def gather(x, move=False):
|
131 |
+
x = comm.gather(x)
|
132 |
+
x = list(itertools.chain(*x))
|
133 |
+
if move:
|
134 |
+
x = [xx.to(self._text_embeds[0].device) for xx in x]
|
135 |
+
return x
|
136 |
+
text_embeds = gather(self._text_embeds, move=True)
|
137 |
+
image_embeds = gather(self._image_embeds, move=True)
|
138 |
+
if self._ensemble:
|
139 |
+
image_embeds2 = gather(self._image_embeds2, move=True)
|
140 |
+
text_ids = gather(self._text_ids)
|
141 |
+
image_ids = gather(self._image_ids)
|
142 |
+
if not comm.is_main_process():
|
143 |
+
return {}
|
144 |
+
else:
|
145 |
+
text_embeds = self._text_embeds
|
146 |
+
image_embeds = self._image_embeds
|
147 |
+
if self._ensemble:
|
148 |
+
image_embeds2 = self._image_embeds2
|
149 |
+
text_ids = self._text_ids
|
150 |
+
image_ids = self._image_ids
|
151 |
+
if len(text_embeds) == 0:
|
152 |
+
self._logger.warning("[COCOCaptionEvaluator] Did not receive valid predictions.")
|
153 |
+
return {}
|
154 |
+
iids = torch.tensor(image_ids).view(-1)
|
155 |
+
tiids = torch.tensor(text_ids).view(-1)
|
156 |
+
image_embeds = torch.cat(image_embeds)
|
157 |
+
text_embeds = torch.cat(text_embeds)
|
158 |
+
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
|
159 |
+
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
|
160 |
+
scores = image_embeds @ text_embeds.t()
|
161 |
+
|
162 |
+
if self._ensemble:
|
163 |
+
image_embeds2 = torch.cat(image_embeds2)
|
164 |
+
image_embeds2 = image_embeds2 / image_embeds2.norm(dim=-1, keepdim=True)
|
165 |
+
scores2 = image_embeds2 @ text_embeds.t()
|
166 |
+
scores = scores2 * 0.5 + scores * 0.5
|
167 |
+
|
168 |
+
topk10 = scores.topk(10, dim=1)
|
169 |
+
topk5 = scores.topk(5, dim=1)
|
170 |
+
topk1 = scores.topk(1, dim=1)
|
171 |
+
topk10_iids = tiids[topk10.indices]
|
172 |
+
topk5_iids = tiids[topk5.indices]
|
173 |
+
topk1_iids = tiids[topk1.indices]
|
174 |
+
tr_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
|
175 |
+
tr_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
|
176 |
+
tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
|
177 |
+
topk10 = scores.topk(10, dim=0)
|
178 |
+
topk5 = scores.topk(5, dim=0)
|
179 |
+
topk1 = scores.topk(1, dim=0)
|
180 |
+
topk10_iids = iids[topk10.indices]
|
181 |
+
topk5_iids = iids[topk5.indices]
|
182 |
+
topk1_iids = iids[topk1.indices]
|
183 |
+
ir_r10 = (tiids.unsqueeze(0) == topk10_iids).float().max(dim=0)[0].mean()
|
184 |
+
ir_r5 = (tiids.unsqueeze(0) == topk5_iids).float().max(dim=0)[0].mean()
|
185 |
+
ir_r1 = (tiids.unsqueeze(0) == topk1_iids).float().max(dim=0)[0].mean()
|
186 |
+
self._results = OrderedDict()
|
187 |
+
# Copy so the caller can do whatever with results
|
188 |
+
self._results['recall'] = {}
|
189 |
+
self._results['recall']['irtr'] = float("{:.3f}".format((ir_r1 + tr_r1).item() * 100))
|
190 |
+
self._results['recall']['ir1'] = float("{:.3f}".format(ir_r1.item() * 100))
|
191 |
+
self._results['recall']['ir5'] = float("{:.3f}".format(ir_r5.item() * 100))
|
192 |
+
self._results['recall']['ir10'] = float("{:.3f}".format(ir_r10.item() * 100))
|
193 |
+
self._results['recall']['tr1'] = float("{:.3f}".format(tr_r1.item() * 100))
|
194 |
+
self._results['recall']['tr5'] = float("{:.3f}".format(tr_r5.item() * 100))
|
195 |
+
self._results['recall']['tr10'] = float("{:.3f}".format(tr_r10.item() * 100))
|
196 |
+
self._logger.info(self._results)
|
197 |
+
return copy.deepcopy(self._results)
|
198 |
+
|
199 |
+
def evaluate_p2i(self, img_ids=None):
|
200 |
+
"""
|
201 |
+
Args:
|
202 |
+
img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
|
203 |
+
"""
|
204 |
+
|
205 |
+
if self._distributed:
|
206 |
+
comm.synchronize()
|
207 |
+
def gather(x, move=False):
|
208 |
+
x = comm.gather(x)
|
209 |
+
x = list(itertools.chain(*x))
|
210 |
+
if move:
|
211 |
+
x = [xx.to(self._text_embeds[0].device) for xx in x]
|
212 |
+
return x
|
213 |
+
text_embeds = gather(self._text_embeds, move=True)
|
214 |
+
image_embeds = gather(self._image_embeds, move=True)
|
215 |
+
image_embeds2 = gather(self._image_embeds2, move=True)
|
216 |
+
text_ids = gather(self._text_ids)
|
217 |
+
image_ids = gather(self._image_ids)
|
218 |
+
if not comm.is_main_process():
|
219 |
+
return {}
|
220 |
+
else:
|
221 |
+
text_embeds = self._text_embeds
|
222 |
+
image_embeds = self._image_embeds
|
223 |
+
image_embeds2 = self._image_embeds2
|
224 |
+
text_ids = self._text_ids
|
225 |
+
image_ids = self._image_ids
|
226 |
+
|
227 |
+
if len(text_embeds) == 0:
|
228 |
+
self._logger.warning("[COCOCaptionEvaluator] Did not receive valid predictions.")
|
229 |
+
return {}
|
230 |
+
|
231 |
+
iids = torch.tensor(image_ids).view(-1)
|
232 |
+
tiids = torch.tensor(text_ids).view(-1)
|
233 |
+
image_embeds = torch.cat(image_embeds)
|
234 |
+
text_embeds = torch.cat(text_embeds)
|
235 |
+
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
|
236 |
+
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
|
237 |
+
|
238 |
+
image_embeds2 = torch.cat(image_embeds2)
|
239 |
+
image_embeds2 = image_embeds2 / image_embeds2.norm(dim=-1, keepdim=True)
|
240 |
+
|
241 |
+
# compute image to image retrieval
|
242 |
+
self._results = OrderedDict()
|
243 |
+
self._results['recall'] = {}
|
244 |
+
ii_scores = image_embeds2 @ image_embeds.t()
|
245 |
+
|
246 |
+
topk10 = ii_scores.topk(10, dim=1)
|
247 |
+
topk5 = ii_scores.topk(5, dim=1)
|
248 |
+
topk1 = ii_scores.topk(1, dim=1)
|
249 |
+
topk10_iids = iids[topk10.indices]
|
250 |
+
topk5_iids = iids[topk5.indices]
|
251 |
+
topk1_iids = iids[topk1.indices]
|
252 |
+
iir_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
|
253 |
+
iir_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
|
254 |
+
iir_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
|
255 |
+
# Copy so the caller can do whatever with results
|
256 |
+
self._results['recall']['p2ir1'] = float("{:.3f}".format(iir_r1.item() * 100))
|
257 |
+
self._results['recall']['p2ir5'] = float("{:.3f}".format(iir_r5.item() * 100))
|
258 |
+
self._results['recall']['p2ir10'] = float("{:.3f}".format(iir_r10.item() * 100))
|
259 |
+
self._logger.info(self._results)
|
260 |
+
return copy.deepcopy(self._results)
|
datasets/evaluation/segmentation_evaluation.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import itertools
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
from collections import OrderedDict
|
8 |
+
import PIL.Image as Image
|
9 |
+
import pycocotools.mask as mask_util
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
13 |
+
from detectron2.utils.comm import all_gather, is_main_process
|
14 |
+
from detectron2.utils.file_io import PathManager
|
15 |
+
from detectron2.evaluation.evaluator import DatasetEvaluator
|
16 |
+
from utilities.distributed import synchronize
|
17 |
+
|
18 |
+
from ..semseg_loader import load_semseg
|
19 |
+
|
20 |
+
|
21 |
+
class SemSegEvaluator(DatasetEvaluator):
|
22 |
+
"""
|
23 |
+
Evaluate semantic segmentation metrics.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
dataset_name,
|
29 |
+
distributed=True,
|
30 |
+
output_dir=None,
|
31 |
+
*,
|
32 |
+
num_classes=None,
|
33 |
+
ignore_label=None,
|
34 |
+
):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
dataset_name (str): name of the dataset to be evaluated.
|
38 |
+
distributed (bool): if True, will collect results from all ranks for evaluation.
|
39 |
+
Otherwise, will evaluate the results in the current process.
|
40 |
+
output_dir (str): an output directory to dump results.
|
41 |
+
num_classes, ignore_label: deprecated argument
|
42 |
+
"""
|
43 |
+
self._logger = logging.getLogger(__name__)
|
44 |
+
if num_classes is not None:
|
45 |
+
self._logger.warn(
|
46 |
+
"SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata."
|
47 |
+
)
|
48 |
+
if ignore_label is not None:
|
49 |
+
self._logger.warn(
|
50 |
+
"SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata."
|
51 |
+
)
|
52 |
+
self._dataset_name = dataset_name
|
53 |
+
self._distributed = distributed
|
54 |
+
self._output_dir = output_dir
|
55 |
+
|
56 |
+
self._cpu_device = torch.device("cpu")
|
57 |
+
|
58 |
+
self.input_file_to_gt_file = {
|
59 |
+
dataset_record["file_name"]: dataset_record["sem_seg_file_name"]
|
60 |
+
for dataset_record in DatasetCatalog.get(dataset_name)
|
61 |
+
}
|
62 |
+
|
63 |
+
meta = MetadataCatalog.get(dataset_name)
|
64 |
+
# Dict that maps contiguous training ids to COCO category ids
|
65 |
+
try:
|
66 |
+
c2d = meta.stuff_dataset_id_to_contiguous_id
|
67 |
+
self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()}
|
68 |
+
except AttributeError:
|
69 |
+
self._contiguous_id_to_dataset_id = None
|
70 |
+
self._class_names = meta.stuff_classes
|
71 |
+
self._class_offset = meta.class_offset if hasattr(meta, 'class_offset') else 0
|
72 |
+
self._num_classes = len(meta.stuff_classes)
|
73 |
+
self._semseg_loader = meta.semseg_loader if hasattr(meta, 'semseg_loader') else 'PIL'
|
74 |
+
|
75 |
+
if num_classes is not None:
|
76 |
+
assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}"
|
77 |
+
self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label
|
78 |
+
|
79 |
+
def reset(self):
|
80 |
+
self._conf_matrix = np.zeros((self._num_classes + 1, self._num_classes + 1), dtype=np.int64)
|
81 |
+
self._predictions = []
|
82 |
+
|
83 |
+
def process(self, inputs, outputs):
|
84 |
+
"""
|
85 |
+
Args:
|
86 |
+
inputs: the inputs to a model.
|
87 |
+
It is a list of dicts. Each dict corresponds to an image and
|
88 |
+
contains keys like "height", "width", "file_name".
|
89 |
+
outputs: the outputs of a model. It is either list of semantic segmentation predictions
|
90 |
+
(Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
|
91 |
+
segmentation prediction in the same format.
|
92 |
+
"""
|
93 |
+
for input, output in zip(inputs, outputs):
|
94 |
+
output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
|
95 |
+
pred = np.array(output, dtype=np.int)
|
96 |
+
|
97 |
+
with PathManager.open(self.input_file_to_gt_file[input["file_name"]], "rb") as f:
|
98 |
+
gt = load_semseg(f, self._semseg_loader) - self._class_offset
|
99 |
+
|
100 |
+
if isinstance(self._ignore_label, int):
|
101 |
+
ignore_label = self._ignore_label - self._class_offset
|
102 |
+
gt[gt == self._ignore_label] = self._num_classes
|
103 |
+
elif isinstance(self._ignore_label, list):
|
104 |
+
for ignore_label in self._ignore_label:
|
105 |
+
ignore_label = ignore_label - self._class_offset
|
106 |
+
gt[gt == ignore_label] = self._num_classes
|
107 |
+
|
108 |
+
self._conf_matrix += np.bincount(
|
109 |
+
(self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
|
110 |
+
minlength=self._conf_matrix.size,
|
111 |
+
).reshape(self._conf_matrix.shape)
|
112 |
+
|
113 |
+
self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))
|
114 |
+
|
115 |
+
def evaluate(self):
|
116 |
+
"""
|
117 |
+
Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):
|
118 |
+
|
119 |
+
* Mean intersection-over-union averaged across classes (mIoU)
|
120 |
+
* Frequency Weighted IoU (fwIoU)
|
121 |
+
* Mean pixel accuracy averaged across classes (mACC)
|
122 |
+
* Pixel Accuracy (pACC)
|
123 |
+
"""
|
124 |
+
if self._distributed:
|
125 |
+
synchronize()
|
126 |
+
conf_matrix_list = all_gather(self._conf_matrix)
|
127 |
+
self._predictions = all_gather(self._predictions)
|
128 |
+
self._predictions = list(itertools.chain(*self._predictions))
|
129 |
+
if not is_main_process():
|
130 |
+
return
|
131 |
+
self._conf_matrix = np.zeros_like(self._conf_matrix)
|
132 |
+
for conf_matrix in conf_matrix_list:
|
133 |
+
self._conf_matrix += conf_matrix
|
134 |
+
|
135 |
+
if self._output_dir:
|
136 |
+
PathManager.mkdirs(self._output_dir)
|
137 |
+
file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
|
138 |
+
with PathManager.open(file_path, "w") as f:
|
139 |
+
f.write(json.dumps(self._predictions))
|
140 |
+
|
141 |
+
acc = np.full(self._num_classes, np.nan, dtype=np.float)
|
142 |
+
iou = np.full(self._num_classes, np.nan, dtype=np.float)
|
143 |
+
tp = self._conf_matrix.diagonal()[:-1].astype(np.float)
|
144 |
+
pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float)
|
145 |
+
class_weights = pos_gt / np.sum(pos_gt)
|
146 |
+
pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float)
|
147 |
+
acc_valid = pos_gt > 0
|
148 |
+
acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
|
149 |
+
iou_valid = (pos_gt + pos_pred) > 0
|
150 |
+
union = pos_gt + pos_pred - tp
|
151 |
+
iou[acc_valid] = tp[acc_valid] / union[acc_valid]
|
152 |
+
macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
|
153 |
+
miou = np.sum(iou[acc_valid]) / np.sum(iou_valid)
|
154 |
+
fiou = np.sum(iou[acc_valid] * class_weights[acc_valid])
|
155 |
+
pacc = np.sum(tp) / np.sum(pos_gt)
|
156 |
+
|
157 |
+
res = {}
|
158 |
+
res["mIoU"] = 100 * miou
|
159 |
+
res["fwIoU"] = 100 * fiou
|
160 |
+
for i, name in enumerate(self._class_names):
|
161 |
+
res["IoU-{}".format(name)] = 100 * iou[i]
|
162 |
+
res["mACC"] = 100 * macc
|
163 |
+
res["pACC"] = 100 * pacc
|
164 |
+
for i, name in enumerate(self._class_names):
|
165 |
+
res["ACC-{}".format(name)] = 100 * acc[i]
|
166 |
+
|
167 |
+
if self._output_dir:
|
168 |
+
file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
|
169 |
+
with PathManager.open(file_path, "wb") as f:
|
170 |
+
torch.save(res, f)
|
171 |
+
results = OrderedDict({"sem_seg": res})
|
172 |
+
self._logger.info(results)
|
173 |
+
return results
|
174 |
+
|
175 |
+
def encode_json_sem_seg(self, sem_seg, input_file_name):
|
176 |
+
"""
|
177 |
+
Convert semantic segmentation to COCO stuff format with segments encoded as RLEs.
|
178 |
+
See http://cocodataset.org/#format-results
|
179 |
+
"""
|
180 |
+
json_list = []
|
181 |
+
for label in np.unique(sem_seg):
|
182 |
+
if self._contiguous_id_to_dataset_id is not None:
|
183 |
+
assert (
|
184 |
+
label in self._contiguous_id_to_dataset_id
|
185 |
+
), "Label {} is not in the metadata info for {}".format(label, self._dataset_name)
|
186 |
+
dataset_id = self._contiguous_id_to_dataset_id[label]
|
187 |
+
else:
|
188 |
+
dataset_id = int(label)
|
189 |
+
mask = (sem_seg == label).astype(np.uint8)
|
190 |
+
mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0]
|
191 |
+
mask_rle["counts"] = mask_rle["counts"].decode("utf-8")
|
192 |
+
json_list.append(
|
193 |
+
{"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle}
|
194 |
+
)
|
195 |
+
return json_list
|
datasets/refer.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__author__ = 'licheng'
|
2 |
+
|
3 |
+
"""
|
4 |
+
This interface provides access to four datasets:
|
5 |
+
1) refclef
|
6 |
+
2) refcoco
|
7 |
+
3) refcoco+
|
8 |
+
4) refcocog
|
9 |
+
split by unc and google
|
10 |
+
|
11 |
+
The following API functions are defined:
|
12 |
+
REFER - REFER api class
|
13 |
+
getRefIds - get ref ids that satisfy given filter conditions.
|
14 |
+
getAnnIds - get ann ids that satisfy given filter conditions.
|
15 |
+
getImgIds - get image ids that satisfy given filter conditions.
|
16 |
+
getCatIds - get category ids that satisfy given filter conditions.
|
17 |
+
loadRefs - load refs with the specified ref ids.
|
18 |
+
loadAnns - load anns with the specified ann ids.
|
19 |
+
loadImgs - load images with the specified image ids.
|
20 |
+
loadCats - load category names with the specified category ids.
|
21 |
+
getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
|
22 |
+
showRef - show image, segmentation or box of the referred object with the ref
|
23 |
+
getMask - get mask and area of the referred object given ref
|
24 |
+
showMask - show mask of the referred object given ref
|
25 |
+
"""
|
26 |
+
|
27 |
+
from doctest import REPORT_ONLY_FIRST_FAILURE
|
28 |
+
import sys
|
29 |
+
import os.path as osp
|
30 |
+
import json
|
31 |
+
import pickle
|
32 |
+
import time
|
33 |
+
import itertools
|
34 |
+
import skimage.io as io
|
35 |
+
import matplotlib.pyplot as plt
|
36 |
+
from matplotlib.collections import PatchCollection
|
37 |
+
from matplotlib.patches import Polygon, Rectangle
|
38 |
+
from pprint import pprint
|
39 |
+
import numpy as np
|
40 |
+
from pycocotools import mask
|
41 |
+
# import cv2
|
42 |
+
# from skimage.measure import label, regionprops
|
43 |
+
|
44 |
+
|
45 |
+
class REFER:
|
46 |
+
def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
|
47 |
+
# provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
|
48 |
+
# also provide dataset name and splitBy information
|
49 |
+
# e.g., dataset = 'refcoco', splitBy = 'unc'
|
50 |
+
print('loading dataset {} into memory...'.format(dataset))
|
51 |
+
self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
|
52 |
+
self.DATA_DIR = osp.join(data_root, dataset)
|
53 |
+
if dataset in ['refcoco', 'refcoco+', 'refcocog']:
|
54 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
|
55 |
+
elif dataset == 'refclef':
|
56 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
|
57 |
+
else:
|
58 |
+
print('No refer dataset is called [{}]'.format(dataset))
|
59 |
+
sys.exit()
|
60 |
+
|
61 |
+
# load refs from data/dataset/refs(dataset).json
|
62 |
+
tic = time.time()
|
63 |
+
ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
|
64 |
+
self.data = {}
|
65 |
+
self.data['dataset'] = dataset
|
66 |
+
self.data['refs'] = pickle.load(open(ref_file, 'rb'))
|
67 |
+
|
68 |
+
# load annotations from data/dataset/instances.json
|
69 |
+
instances_file = osp.join(self.DATA_DIR, 'instances.json')
|
70 |
+
instances = json.load(open(instances_file, 'r'))
|
71 |
+
self.data['images'] = instances['images']
|
72 |
+
self.data['annotations'] = instances['annotations']
|
73 |
+
self.data['categories'] = instances['categories']
|
74 |
+
|
75 |
+
# create index
|
76 |
+
self.createIndex()
|
77 |
+
print('DONE (t=%.2fs)'.format(time.time()-tic))
|
78 |
+
|
79 |
+
def createIndex(self):
|
80 |
+
# create sets of mapping
|
81 |
+
# 1) Refs: {ref_id: ref}
|
82 |
+
# 2) Anns: {ann_id: ann}
|
83 |
+
# 3) Imgs: {image_id: image}
|
84 |
+
# 4) Cats: {category_id: category_name}
|
85 |
+
# 5) Sents: {sent_id: sent}
|
86 |
+
# 6) imgToRefs: {image_id: refs}
|
87 |
+
# 7) imgToAnns: {image_id: anns}
|
88 |
+
# 8) refToAnn: {ref_id: ann}
|
89 |
+
# 9) annToRef: {ann_id: ref}
|
90 |
+
# 10) catToRefs: {category_id: refs}
|
91 |
+
# 11) sentToRef: {sent_id: ref}
|
92 |
+
# 12) sentToTokens: {sent_id: tokens}
|
93 |
+
print('creating index...')
|
94 |
+
# fetch info from instances
|
95 |
+
Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
|
96 |
+
for ann in self.data['annotations']:
|
97 |
+
Anns[ann['id']] = ann
|
98 |
+
imgToAnns[ann['image_id']] = imgToAnns.get(
|
99 |
+
ann['image_id'], []) + [ann]
|
100 |
+
for img in self.data['images']:
|
101 |
+
Imgs[img['id']] = img
|
102 |
+
for cat in self.data['categories']:
|
103 |
+
Cats[cat['id']] = cat['name']
|
104 |
+
|
105 |
+
# fetch info from refs
|
106 |
+
Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
|
107 |
+
Sents, sentToRef, sentToTokens = {}, {}, {}
|
108 |
+
for ref in self.data['refs']:
|
109 |
+
# ids
|
110 |
+
ref_id = ref['ref_id']
|
111 |
+
ann_id = ref['ann_id']
|
112 |
+
category_id = ref['category_id']
|
113 |
+
image_id = ref['image_id']
|
114 |
+
|
115 |
+
# add mapping related to ref
|
116 |
+
Refs[ref_id] = ref
|
117 |
+
imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
|
118 |
+
catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
|
119 |
+
refToAnn[ref_id] = Anns[ann_id]
|
120 |
+
annToRef[ann_id] = ref
|
121 |
+
|
122 |
+
# add mapping of sent
|
123 |
+
for sent in ref['sentences']:
|
124 |
+
Sents[sent['sent_id']] = sent
|
125 |
+
sentToRef[sent['sent_id']] = ref
|
126 |
+
sentToTokens[sent['sent_id']] = sent['tokens']
|
127 |
+
|
128 |
+
# create class members
|
129 |
+
self.Refs = Refs
|
130 |
+
self.Anns = Anns
|
131 |
+
self.Imgs = Imgs
|
132 |
+
self.Cats = Cats
|
133 |
+
self.Sents = Sents
|
134 |
+
self.imgToRefs = imgToRefs
|
135 |
+
self.imgToAnns = imgToAnns
|
136 |
+
self.refToAnn = refToAnn
|
137 |
+
self.annToRef = annToRef
|
138 |
+
self.catToRefs = catToRefs
|
139 |
+
self.sentToRef = sentToRef
|
140 |
+
self.sentToTokens = sentToTokens
|
141 |
+
print('index created.')
|
142 |
+
|
143 |
+
def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
|
144 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
145 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
146 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
147 |
+
|
148 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
|
149 |
+
refs = self.data['refs']
|
150 |
+
else:
|
151 |
+
if not len(image_ids) == 0:
|
152 |
+
refs = [self.imgToRefs[image_id] for image_id in image_ids]
|
153 |
+
else:
|
154 |
+
refs = self.data['refs']
|
155 |
+
if not len(cat_ids) == 0:
|
156 |
+
refs = [ref for ref in refs if ref['category_id'] in cat_ids]
|
157 |
+
if not len(ref_ids) == 0:
|
158 |
+
refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
|
159 |
+
if not len(split) == 0:
|
160 |
+
if split in ['testA', 'testB', 'testC']:
|
161 |
+
# we also consider testAB, testBC, ...
|
162 |
+
refs = [ref for ref in refs if split[-1] in ref['split']]
|
163 |
+
elif split in ['testAB', 'testBC', 'testAC']:
|
164 |
+
# rarely used I guess...
|
165 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
166 |
+
elif split == 'test':
|
167 |
+
refs = [ref for ref in refs if 'test' in ref['split']]
|
168 |
+
elif split == 'train' or split == 'val':
|
169 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
170 |
+
else:
|
171 |
+
print('No such split [{}]'.format(split))
|
172 |
+
sys.exit()
|
173 |
+
ref_ids = [ref['ref_id'] for ref in refs]
|
174 |
+
return ref_ids
|
175 |
+
|
176 |
+
def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
|
177 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
178 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
179 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
180 |
+
|
181 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
|
182 |
+
ann_ids = [ann['id'] for ann in self.data['annotations']]
|
183 |
+
else:
|
184 |
+
if not len(image_ids) == 0:
|
185 |
+
lists = [self.imgToAnns[image_id]
|
186 |
+
for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
|
187 |
+
anns = list(itertools.chain.from_iterable(lists))
|
188 |
+
else:
|
189 |
+
anns = self.data['annotations']
|
190 |
+
if not len(cat_ids) == 0:
|
191 |
+
anns = [ann for ann in anns if ann['category_id'] in cat_ids]
|
192 |
+
ann_ids = [ann['id'] for ann in anns]
|
193 |
+
if not len(ref_ids) == 0:
|
194 |
+
ids = set(ann_ids).intersection(
|
195 |
+
set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
|
196 |
+
return ann_ids
|
197 |
+
|
198 |
+
def getImgIds(self, ref_ids=[]):
|
199 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
200 |
+
|
201 |
+
if not len(ref_ids) == 0:
|
202 |
+
image_ids = list(set([self.Refs[ref_id]['image_id']
|
203 |
+
for ref_id in ref_ids]))
|
204 |
+
else:
|
205 |
+
image_ids = self.Imgs.keys()
|
206 |
+
return image_ids
|
207 |
+
|
208 |
+
def getCatIds(self):
|
209 |
+
return self.Cats.keys()
|
210 |
+
|
211 |
+
def loadRefs(self, ref_ids=[]):
|
212 |
+
if type(ref_ids) == list:
|
213 |
+
return [self.Refs[ref_id] for ref_id in ref_ids]
|
214 |
+
elif type(ref_ids) == int:
|
215 |
+
return [self.Refs[ref_ids]]
|
216 |
+
|
217 |
+
def loadAnns(self, ann_ids=[]):
|
218 |
+
if type(ann_ids) == list:
|
219 |
+
return [self.Anns[ann_id] for ann_id in ann_ids]
|
220 |
+
elif type(ann_ids) == int or type(ann_ids) == unicode:
|
221 |
+
return [self.Anns[ann_ids]]
|
222 |
+
|
223 |
+
def loadImgs(self, image_ids=[]):
|
224 |
+
if type(image_ids) == list:
|
225 |
+
return [self.Imgs[image_id] for image_id in image_ids]
|
226 |
+
elif type(image_ids) == int:
|
227 |
+
return [self.Imgs[image_ids]]
|
228 |
+
|
229 |
+
def loadCats(self, cat_ids=[]):
|
230 |
+
if type(cat_ids) == list:
|
231 |
+
return [self.Cats[cat_id] for cat_id in cat_ids]
|
232 |
+
elif type(cat_ids) == int:
|
233 |
+
return [self.Cats[cat_ids]]
|
234 |
+
|
235 |
+
def getRefBox(self, ref_id):
|
236 |
+
ref = self.Refs[ref_id]
|
237 |
+
ann = self.refToAnn[ref_id]
|
238 |
+
return ann['bbox'] # [x, y, w, h]
|
239 |
+
|
240 |
+
def showRef(self, ref, seg_box='seg'):
|
241 |
+
ax = plt.gca()
|
242 |
+
# show image
|
243 |
+
image = self.Imgs[ref['image_id']]
|
244 |
+
I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
245 |
+
ax.imshow(I)
|
246 |
+
# show refer expression
|
247 |
+
for sid, sent in enumerate(ref['sentences']):
|
248 |
+
print('{}. {}'.format(sid+1, sent['sent']))
|
249 |
+
# show segmentations
|
250 |
+
if seg_box == 'seg':
|
251 |
+
ann_id = ref['ann_id']
|
252 |
+
ann = self.Anns[ann_id]
|
253 |
+
polygons = []
|
254 |
+
color = []
|
255 |
+
c = 'none'
|
256 |
+
if type(ann['segmentation'][0]) == list:
|
257 |
+
# polygon used for refcoco*
|
258 |
+
for seg in ann['segmentation']:
|
259 |
+
poly = np.array(seg).reshape((len(seg)/2, 2))
|
260 |
+
polygons.append(Polygon(poly, True, alpha=0.4))
|
261 |
+
color.append(c)
|
262 |
+
p = PatchCollection(polygons, facecolors=color, edgecolors=(
|
263 |
+
1, 1, 0, 0), linewidths=3, alpha=1)
|
264 |
+
ax.add_collection(p) # thick yellow polygon
|
265 |
+
p = PatchCollection(polygons, facecolors=color, edgecolors=(
|
266 |
+
1, 0, 0, 0), linewidths=1, alpha=1)
|
267 |
+
ax.add_collection(p) # thin red polygon
|
268 |
+
else:
|
269 |
+
# mask used for refclef
|
270 |
+
rle = ann['segmentation']
|
271 |
+
m = mask.decode(rle)
|
272 |
+
img = np.ones((m.shape[0], m.shape[1], 3))
|
273 |
+
color_mask = np.array([2.0, 166.0, 101.0])/255
|
274 |
+
for i in range(3):
|
275 |
+
img[:, :, i] = color_mask[i]
|
276 |
+
ax.imshow(np.dstack((img, m*0.5)))
|
277 |
+
# show bounding-box
|
278 |
+
elif seg_box == 'box':
|
279 |
+
ann_id = ref['ann_id']
|
280 |
+
ann = self.Anns[ann_id]
|
281 |
+
bbox = self.getRefBox(ref['ref_id'])
|
282 |
+
box_plot = Rectangle(
|
283 |
+
(bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
|
284 |
+
ax.add_patch(box_plot)
|
285 |
+
|
286 |
+
def getMask(self, ref):
|
287 |
+
# return mask, area and mask-center
|
288 |
+
ann = self.refToAnn[ref['ref_id']]
|
289 |
+
image = self.Imgs[ref['image_id']]
|
290 |
+
if type(ann['segmentation'][0]) == list: # polygon
|
291 |
+
rle = mask.frPyObjects(
|
292 |
+
ann['segmentation'], image['height'], image['width'])
|
293 |
+
else:
|
294 |
+
rle = ann['segmentation']
|
295 |
+
m = mask.decode(rle)
|
296 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
297 |
+
m = np.sum(m, axis=2)
|
298 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
299 |
+
# compute area
|
300 |
+
area = sum(mask.area(rle)) # should be close to ann['area']
|
301 |
+
return {'mask': m, 'area': area}
|
302 |
+
# # position
|
303 |
+
# position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
|
304 |
+
# position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
|
305 |
+
# # mass position (if there were multiple regions, we use the largest one.)
|
306 |
+
# label_m = label(m, connectivity=m.ndim)
|
307 |
+
# regions = regionprops(label_m)
|
308 |
+
# if len(regions) > 0:
|
309 |
+
# largest_id = np.argmax(np.array([props.filled_area for props in regions]))
|
310 |
+
# largest_props = regions[largest_id]
|
311 |
+
# mass_y, mass_x = largest_props.centroid
|
312 |
+
# else:
|
313 |
+
# mass_x, mass_y = position_x, position_y
|
314 |
+
# # if centroid is not in mask, we find the closest point to it from mask
|
315 |
+
# if m[mass_y, mass_x] != 1:
|
316 |
+
# print 'Finding closes mask point ...'
|
317 |
+
# kernel = np.ones((10, 10),np.uint8)
|
318 |
+
# me = cv2.erode(m, kernel, iterations = 1)
|
319 |
+
# points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
|
320 |
+
# points = np.array(points)
|
321 |
+
# dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
|
322 |
+
# id = np.argsort(dist)[0]
|
323 |
+
# mass_y, mass_x = points[id]
|
324 |
+
# # return
|
325 |
+
# return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
|
326 |
+
# # show image and mask
|
327 |
+
# I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
328 |
+
# plt.figure()
|
329 |
+
# plt.imshow(I)
|
330 |
+
# ax = plt.gca()
|
331 |
+
# img = np.ones( (m.shape[0], m.shape[1], 3) )
|
332 |
+
# color_mask = np.array([2.0,166.0,101.0])/255
|
333 |
+
# for i in range(3):
|
334 |
+
# img[:,:,i] = color_mask[i]
|
335 |
+
# ax.imshow(np.dstack( (img, m*0.5) ))
|
336 |
+
# plt.show()
|
337 |
+
|
338 |
+
def showMask(self, ref):
|
339 |
+
M = self.getMask(ref)
|
340 |
+
msk = M['mask']
|
341 |
+
ax = plt.gca()
|
342 |
+
ax.imshow(msk)
|
343 |
+
|
344 |
+
|
345 |
+
if __name__ == '__main__':
|
346 |
+
refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg',
|
347 |
+
dataset='refcocog', splitBy='google')
|
348 |
+
ref_ids = refer.getRefIds()
|
349 |
+
print(len(ref_ids))
|
350 |
+
|
351 |
+
print(len(refer.Imgs))
|
352 |
+
print(len(refer.imgToRefs))
|
353 |
+
|
354 |
+
ref_ids = refer.getRefIds(split='train')
|
355 |
+
print('There are {} training referred objects.' % len(ref_ids))
|
356 |
+
|
357 |
+
for ref_id in ref_ids:
|
358 |
+
ref = refer.loadRefs(ref_id)[0]
|
359 |
+
if len(ref['sentences']) < 2:
|
360 |
+
continue
|
361 |
+
|
362 |
+
pprint(ref)
|
363 |
+
print('The label is {}.'.format(refer.Cats[ref['category_id']]))
|
364 |
+
|
365 |
+
# plt.figure()
|
366 |
+
# refer.showRef(ref, seg_box='box')
|
367 |
+
# plt.show()
|
368 |
+
|
369 |
+
# plt.figure()
|
370 |
+
# refer.showMask(ref)
|
371 |
+
# plt.show()
|
datasets/registration/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from . import (
|
2 |
+
register_biomed_datasets
|
3 |
+
)
|
datasets/registration/register_biomed_datasets.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Modified by Xueyan Zou ([email protected])
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import collections
|
10 |
+
|
11 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
12 |
+
from detectron2.data.datasets import load_sem_seg
|
13 |
+
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
|
14 |
+
from detectron2.utils.file_io import PathManager
|
15 |
+
|
16 |
+
|
17 |
+
_PREDEFINED_SPLITS_BIOMED = {}
|
18 |
+
|
19 |
+
# example of registering a dataset
|
20 |
+
datasets = ['BiomedParseData-Demo', ] # provide name of the dataset under biomedparse_datasets
|
21 |
+
splits = ['demo'] # provide split name, e.g., train, test, val. Here there is only one 'demo' split in the example demo dataset
|
22 |
+
|
23 |
+
# Here we register all the splits of the dataset
|
24 |
+
for name in datasets:
|
25 |
+
for split in splits:
|
26 |
+
dataname = f'biomed_{name.replace("/", "-")}_{split}'
|
27 |
+
image_root = f"{name}/{split}"
|
28 |
+
ann_root = f"{name}/{split}.json"
|
29 |
+
_PREDEFINED_SPLITS_BIOMED[dataname] = (image_root, ann_root)
|
30 |
+
# The resulting dataset name is: biomed_BiomedParseData-Demo_demo
|
31 |
+
|
32 |
+
# # Add your dataset here
|
33 |
+
# datasets = ['YOUR_DATASET_NAME', ] # provide name of the dataset under biomedparse_datasets
|
34 |
+
# splits = ['train', 'test'] # provide split name, e.g., train, test, val
|
35 |
+
|
36 |
+
# # Here we register all the splits of the dataset
|
37 |
+
# for name in datasets:
|
38 |
+
# for split in splits:
|
39 |
+
# dataname = f'biomed_{name.replace("/", "-")}_{split}'
|
40 |
+
# image_root = f"{name}/{split}"
|
41 |
+
# ann_root = f"{name}/{split}.json"
|
42 |
+
# _PREDEFINED_SPLITS_BIOMED[dataname] = (image_root, ann_root)
|
43 |
+
# # The resulting dataset names are: biomed_YOUR_DATASET_NAME_train, biomed_YOUR_DATASET_NAME_test
|
44 |
+
|
45 |
+
|
46 |
+
def get_metadata():
|
47 |
+
meta = {}
|
48 |
+
return meta
|
49 |
+
|
50 |
+
|
51 |
+
def load_biomed_json(image_root, annot_json, metadata):
|
52 |
+
"""
|
53 |
+
Args:
|
54 |
+
image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
|
55 |
+
gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
|
56 |
+
json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
|
57 |
+
Returns:
|
58 |
+
list[dict]: a list of dicts in Detectron2 standard format. (See
|
59 |
+
`Using Custom Datasets </tutorials/datasets.html>`_ )
|
60 |
+
"""
|
61 |
+
|
62 |
+
with PathManager.open(annot_json) as f:
|
63 |
+
json_info = json.load(f)
|
64 |
+
|
65 |
+
# build dictionary for grounding
|
66 |
+
grd_dict = collections.defaultdict(list)
|
67 |
+
for grd_ann in json_info['annotations']:
|
68 |
+
image_id = int(grd_ann["image_id"])
|
69 |
+
grd_dict[image_id].append(grd_ann)
|
70 |
+
|
71 |
+
mask_root = image_root + '_mask'
|
72 |
+
ret = []
|
73 |
+
for image in json_info["images"]:
|
74 |
+
image_id = int(image["id"])
|
75 |
+
image_file = os.path.join(image_root, image['file_name'])
|
76 |
+
grounding_anno = grd_dict[image_id]
|
77 |
+
for ann in grounding_anno:
|
78 |
+
if 'mask_file' not in ann:
|
79 |
+
ann['mask_file'] = image['file_name']
|
80 |
+
ann['mask_file'] = os.path.join(mask_root, ann['mask_file'])
|
81 |
+
ret.append(
|
82 |
+
{
|
83 |
+
"file_name": image_file,
|
84 |
+
"image_id": image_id,
|
85 |
+
"grounding_info": [ann],
|
86 |
+
}
|
87 |
+
)
|
88 |
+
assert len(ret), f"No images found in {image_root}!"
|
89 |
+
assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
|
90 |
+
return ret
|
91 |
+
|
92 |
+
|
93 |
+
def register_biomed(
|
94 |
+
name, metadata, image_root, annot_json):
|
95 |
+
DatasetCatalog.register(
|
96 |
+
name,
|
97 |
+
lambda: load_biomed_json(image_root, annot_json, metadata),
|
98 |
+
)
|
99 |
+
MetadataCatalog.get(name).set(
|
100 |
+
image_root=image_root,
|
101 |
+
json_file=annot_json,
|
102 |
+
evaluator_type="grounding_refcoco",
|
103 |
+
ignore_label=255,
|
104 |
+
label_divisor=1000,
|
105 |
+
**metadata,
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def register_all_biomed(root):
|
110 |
+
for (
|
111 |
+
prefix,
|
112 |
+
(image_root, annot_root),
|
113 |
+
) in _PREDEFINED_SPLITS_BIOMED.items():
|
114 |
+
register_biomed(
|
115 |
+
prefix,
|
116 |
+
get_metadata(),
|
117 |
+
os.path.join(root, image_root),
|
118 |
+
os.path.join(root, annot_root),
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
_root = os.getenv("DATASET", "datasets")
|
123 |
+
register_all_biomed(_root)
|
datasets/semseg_loader.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import scipy.io
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
def load_semseg(filename, loader_type):
|
6 |
+
if loader_type == 'PIL':
|
7 |
+
semseg = np.array(Image.open(filename), dtype=np.int)
|
8 |
+
elif loader_type == 'MAT':
|
9 |
+
semseg = scipy.io.loadmat(filename)['LabelMap']
|
10 |
+
return semseg
|
datasets/utils/refcoco2json.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from refer import REFER
|
4 |
+
|
5 |
+
coco_root = '/pth/to/coco'
|
6 |
+
ref_root = '/pth/to/refcocoseg'
|
7 |
+
|
8 |
+
coco_train_annot = json.load(open(os.path.join(coco_root, 'annotations/instances_train2017.json')))
|
9 |
+
coco_train_id = []
|
10 |
+
image_annot = {}
|
11 |
+
for i in range(len(coco_train_annot['images'])):
|
12 |
+
coco_train_id.append(coco_train_annot['images'][i]['id'])
|
13 |
+
image_annot[coco_train_annot['images'][i]['id']] = coco_train_annot['images'][i]
|
14 |
+
|
15 |
+
refg = REFER(data_root=ref_root,
|
16 |
+
dataset='refcocog', splitBy='umd')
|
17 |
+
refg_val_ids = refg.getRefIds(split='val')
|
18 |
+
|
19 |
+
full_anno = []
|
20 |
+
for ref_id in refg_val_ids:
|
21 |
+
ref = refg.loadRefs(ref_id)[0]
|
22 |
+
anno = refg.refToAnn[ref_id]
|
23 |
+
anno.update(ref)
|
24 |
+
full_anno.append(anno)
|
25 |
+
|
26 |
+
imageid_list = []
|
27 |
+
final_anno = {}
|
28 |
+
for anno in full_anno:
|
29 |
+
imageid_list += [anno['image_id']]
|
30 |
+
final_anno[anno['ann_id']] = anno
|
31 |
+
|
32 |
+
annotations = [value for key, value in final_anno.items()]
|
33 |
+
|
34 |
+
iamges = []
|
35 |
+
for image_id in list(set(imageid_list)):
|
36 |
+
iamges += [image_annot[image_id]]
|
37 |
+
|
38 |
+
outputs = {'images': iamges, 'annotations': annotations}
|
39 |
+
print(len(iamges))
|
40 |
+
print(len(annotations))
|
41 |
+
json.dump(outputs, open(os.path.join(coco_root, 'annotations/refcocog_umd_train.json'), 'w'))
|
datasets/utils/refer.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This code is modified from https://github.com/lichengunc/refer, and with minor modification of python2/3 format
|
2 |
+
__author__ = 'licheng'
|
3 |
+
|
4 |
+
"""
|
5 |
+
This interface provides access to four datasets:
|
6 |
+
1) refclef
|
7 |
+
2) refcoco
|
8 |
+
3) refcoco+
|
9 |
+
4) refcocog
|
10 |
+
split by unc and google
|
11 |
+
|
12 |
+
The following API functions are defined:
|
13 |
+
REFER - REFER api class
|
14 |
+
getRefIds - get ref ids that satisfy given filter conditions.
|
15 |
+
getAnnIds - get ann ids that satisfy given filter conditions.
|
16 |
+
getImgIds - get image ids that satisfy given filter conditions.
|
17 |
+
getCatIds - get category ids that satisfy given filter conditions.
|
18 |
+
loadRefs - load refs with the specified ref ids.
|
19 |
+
loadAnns - load anns with the specified ann ids.
|
20 |
+
loadImgs - load images with the specified image ids.
|
21 |
+
loadCats - load category names with the specified category ids.
|
22 |
+
getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
|
23 |
+
showRef - show image, segmentation or box of the referred object with the ref
|
24 |
+
getMask - get mask and area of the referred object given ref
|
25 |
+
showMask - show mask of the referred object given ref
|
26 |
+
"""
|
27 |
+
|
28 |
+
from doctest import REPORT_ONLY_FIRST_FAILURE
|
29 |
+
import sys
|
30 |
+
import os.path as osp
|
31 |
+
import json
|
32 |
+
import pickle
|
33 |
+
import time
|
34 |
+
import itertools
|
35 |
+
import skimage.io as io
|
36 |
+
import matplotlib.pyplot as plt
|
37 |
+
from matplotlib.collections import PatchCollection
|
38 |
+
from matplotlib.patches import Polygon, Rectangle
|
39 |
+
from pprint import pprint
|
40 |
+
import numpy as np
|
41 |
+
from pycocotools import mask
|
42 |
+
# import cv2
|
43 |
+
# from skimage.measure import label, regionprops
|
44 |
+
|
45 |
+
|
46 |
+
class REFER:
|
47 |
+
def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
|
48 |
+
# provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
|
49 |
+
# also provide dataset name and splitBy information
|
50 |
+
# e.g., dataset = 'refcoco', splitBy = 'unc'
|
51 |
+
print('loading dataset {} into memory...'.format(dataset))
|
52 |
+
self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
|
53 |
+
self.DATA_DIR = osp.join(data_root, dataset)
|
54 |
+
if dataset in ['refcoco', 'refcoco+', 'refcocog']:
|
55 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
|
56 |
+
elif dataset == 'refclef':
|
57 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
|
58 |
+
else:
|
59 |
+
print('No refer dataset is called [{}]'.format(dataset))
|
60 |
+
sys.exit()
|
61 |
+
|
62 |
+
# load refs from data/dataset/refs(dataset).json
|
63 |
+
tic = time.time()
|
64 |
+
ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
|
65 |
+
self.data = {}
|
66 |
+
self.data['dataset'] = dataset
|
67 |
+
self.data['refs'] = pickle.load(open(ref_file, 'rb'))
|
68 |
+
|
69 |
+
# load annotations from data/dataset/instances.json
|
70 |
+
instances_file = osp.join(self.DATA_DIR, 'instances.json')
|
71 |
+
instances = json.load(open(instances_file, 'r'))
|
72 |
+
self.data['images'] = instances['images']
|
73 |
+
self.data['annotations'] = instances['annotations']
|
74 |
+
self.data['categories'] = instances['categories']
|
75 |
+
|
76 |
+
# create index
|
77 |
+
self.createIndex()
|
78 |
+
print('DONE (t=%.2fs)'.format(time.time()-tic))
|
79 |
+
|
80 |
+
def createIndex(self):
|
81 |
+
# create sets of mapping
|
82 |
+
# 1) Refs: {ref_id: ref}
|
83 |
+
# 2) Anns: {ann_id: ann}
|
84 |
+
# 3) Imgs: {image_id: image}
|
85 |
+
# 4) Cats: {category_id: category_name}
|
86 |
+
# 5) Sents: {sent_id: sent}
|
87 |
+
# 6) imgToRefs: {image_id: refs}
|
88 |
+
# 7) imgToAnns: {image_id: anns}
|
89 |
+
# 8) refToAnn: {ref_id: ann}
|
90 |
+
# 9) annToRef: {ann_id: ref}
|
91 |
+
# 10) catToRefs: {category_id: refs}
|
92 |
+
# 11) sentToRef: {sent_id: ref}
|
93 |
+
# 12) sentToTokens: {sent_id: tokens}
|
94 |
+
print('creating index...')
|
95 |
+
# fetch info from instances
|
96 |
+
Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
|
97 |
+
for ann in self.data['annotations']:
|
98 |
+
Anns[ann['id']] = ann
|
99 |
+
imgToAnns[ann['image_id']] = imgToAnns.get(
|
100 |
+
ann['image_id'], []) + [ann]
|
101 |
+
for img in self.data['images']:
|
102 |
+
Imgs[img['id']] = img
|
103 |
+
for cat in self.data['categories']:
|
104 |
+
Cats[cat['id']] = cat['name']
|
105 |
+
|
106 |
+
# fetch info from refs
|
107 |
+
Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
|
108 |
+
Sents, sentToRef, sentToTokens = {}, {}, {}
|
109 |
+
for ref in self.data['refs']:
|
110 |
+
# ids
|
111 |
+
ref_id = ref['ref_id']
|
112 |
+
ann_id = ref['ann_id']
|
113 |
+
category_id = ref['category_id']
|
114 |
+
image_id = ref['image_id']
|
115 |
+
|
116 |
+
# add mapping related to ref
|
117 |
+
Refs[ref_id] = ref
|
118 |
+
imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
|
119 |
+
catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
|
120 |
+
refToAnn[ref_id] = Anns[ann_id]
|
121 |
+
annToRef[ann_id] = ref
|
122 |
+
|
123 |
+
# add mapping of sent
|
124 |
+
for sent in ref['sentences']:
|
125 |
+
Sents[sent['sent_id']] = sent
|
126 |
+
sentToRef[sent['sent_id']] = ref
|
127 |
+
sentToTokens[sent['sent_id']] = sent['tokens']
|
128 |
+
|
129 |
+
# create class members
|
130 |
+
self.Refs = Refs
|
131 |
+
self.Anns = Anns
|
132 |
+
self.Imgs = Imgs
|
133 |
+
self.Cats = Cats
|
134 |
+
self.Sents = Sents
|
135 |
+
self.imgToRefs = imgToRefs
|
136 |
+
self.imgToAnns = imgToAnns
|
137 |
+
self.refToAnn = refToAnn
|
138 |
+
self.annToRef = annToRef
|
139 |
+
self.catToRefs = catToRefs
|
140 |
+
self.sentToRef = sentToRef
|
141 |
+
self.sentToTokens = sentToTokens
|
142 |
+
print('index created.')
|
143 |
+
|
144 |
+
def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
|
145 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
146 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
147 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
148 |
+
|
149 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
|
150 |
+
refs = self.data['refs']
|
151 |
+
else:
|
152 |
+
if not len(image_ids) == 0:
|
153 |
+
refs = [self.imgToRefs[image_id] for image_id in image_ids]
|
154 |
+
else:
|
155 |
+
refs = self.data['refs']
|
156 |
+
if not len(cat_ids) == 0:
|
157 |
+
refs = [ref for ref in refs if ref['category_id'] in cat_ids]
|
158 |
+
if not len(ref_ids) == 0:
|
159 |
+
refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
|
160 |
+
if not len(split) == 0:
|
161 |
+
if split in ['testA', 'testB', 'testC']:
|
162 |
+
# we also consider testAB, testBC, ...
|
163 |
+
refs = [ref for ref in refs if split[-1] in ref['split']]
|
164 |
+
elif split in ['testAB', 'testBC', 'testAC']:
|
165 |
+
# rarely used I guess...
|
166 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
167 |
+
elif split == 'test':
|
168 |
+
refs = [ref for ref in refs if 'test' in ref['split']]
|
169 |
+
elif split == 'train' or split == 'val':
|
170 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
171 |
+
else:
|
172 |
+
print('No such split [{}]'.format(split))
|
173 |
+
sys.exit()
|
174 |
+
ref_ids = [ref['ref_id'] for ref in refs]
|
175 |
+
return ref_ids
|
176 |
+
|
177 |
+
def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
|
178 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
179 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
180 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
181 |
+
|
182 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
|
183 |
+
ann_ids = [ann['id'] for ann in self.data['annotations']]
|
184 |
+
else:
|
185 |
+
if not len(image_ids) == 0:
|
186 |
+
lists = [self.imgToAnns[image_id]
|
187 |
+
for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
|
188 |
+
anns = list(itertools.chain.from_iterable(lists))
|
189 |
+
else:
|
190 |
+
anns = self.data['annotations']
|
191 |
+
if not len(cat_ids) == 0:
|
192 |
+
anns = [ann for ann in anns if ann['category_id'] in cat_ids]
|
193 |
+
ann_ids = [ann['id'] for ann in anns]
|
194 |
+
if not len(ref_ids) == 0:
|
195 |
+
ids = set(ann_ids).intersection(
|
196 |
+
set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
|
197 |
+
return ann_ids
|
198 |
+
|
199 |
+
def getImgIds(self, ref_ids=[]):
|
200 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
201 |
+
|
202 |
+
if not len(ref_ids) == 0:
|
203 |
+
image_ids = list(set([self.Refs[ref_id]['image_id']
|
204 |
+
for ref_id in ref_ids]))
|
205 |
+
else:
|
206 |
+
image_ids = self.Imgs.keys()
|
207 |
+
return image_ids
|
208 |
+
|
209 |
+
def getCatIds(self):
|
210 |
+
return self.Cats.keys()
|
211 |
+
|
212 |
+
def loadRefs(self, ref_ids=[]):
|
213 |
+
if type(ref_ids) == list:
|
214 |
+
return [self.Refs[ref_id] for ref_id in ref_ids]
|
215 |
+
elif type(ref_ids) == int:
|
216 |
+
return [self.Refs[ref_ids]]
|
217 |
+
|
218 |
+
def loadAnns(self, ann_ids=[]):
|
219 |
+
if type(ann_ids) == list:
|
220 |
+
return [self.Anns[ann_id] for ann_id in ann_ids]
|
221 |
+
elif type(ann_ids) == int or type(ann_ids) == unicode:
|
222 |
+
return [self.Anns[ann_ids]]
|
223 |
+
|
224 |
+
def loadImgs(self, image_ids=[]):
|
225 |
+
if type(image_ids) == list:
|
226 |
+
return [self.Imgs[image_id] for image_id in image_ids]
|
227 |
+
elif type(image_ids) == int:
|
228 |
+
return [self.Imgs[image_ids]]
|
229 |
+
|
230 |
+
def loadCats(self, cat_ids=[]):
|
231 |
+
if type(cat_ids) == list:
|
232 |
+
return [self.Cats[cat_id] for cat_id in cat_ids]
|
233 |
+
elif type(cat_ids) == int:
|
234 |
+
return [self.Cats[cat_ids]]
|
235 |
+
|
236 |
+
def getRefBox(self, ref_id):
|
237 |
+
ref = self.Refs[ref_id]
|
238 |
+
ann = self.refToAnn[ref_id]
|
239 |
+
return ann['bbox'] # [x, y, w, h]
|
240 |
+
|
241 |
+
def showRef(self, ref, seg_box='seg'):
|
242 |
+
ax = plt.gca()
|
243 |
+
# show image
|
244 |
+
image = self.Imgs[ref['image_id']]
|
245 |
+
I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
246 |
+
ax.imshow(I)
|
247 |
+
# show refer expression
|
248 |
+
for sid, sent in enumerate(ref['sentences']):
|
249 |
+
print('{}. {}'.format(sid+1, sent['sent']))
|
250 |
+
# show segmentations
|
251 |
+
if seg_box == 'seg':
|
252 |
+
ann_id = ref['ann_id']
|
253 |
+
ann = self.Anns[ann_id]
|
254 |
+
polygons = []
|
255 |
+
color = []
|
256 |
+
c = 'none'
|
257 |
+
if type(ann['segmentation'][0]) == list:
|
258 |
+
# polygon used for refcoco*
|
259 |
+
for seg in ann['segmentation']:
|
260 |
+
poly = np.array(seg).reshape((len(seg)/2, 2))
|
261 |
+
polygons.append(Polygon(poly, True, alpha=0.4))
|
262 |
+
color.append(c)
|
263 |
+
p = PatchCollection(polygons, facecolors=color, edgecolors=(
|
264 |
+
1, 1, 0, 0), linewidths=3, alpha=1)
|
265 |
+
ax.add_collection(p) # thick yellow polygon
|
266 |
+
p = PatchCollection(polygons, facecolors=color, edgecolors=(
|
267 |
+
1, 0, 0, 0), linewidths=1, alpha=1)
|
268 |
+
ax.add_collection(p) # thin red polygon
|
269 |
+
else:
|
270 |
+
# mask used for refclef
|
271 |
+
rle = ann['segmentation']
|
272 |
+
m = mask.decode(rle)
|
273 |
+
img = np.ones((m.shape[0], m.shape[1], 3))
|
274 |
+
color_mask = np.array([2.0, 166.0, 101.0])/255
|
275 |
+
for i in range(3):
|
276 |
+
img[:, :, i] = color_mask[i]
|
277 |
+
ax.imshow(np.dstack((img, m*0.5)))
|
278 |
+
# show bounding-box
|
279 |
+
elif seg_box == 'box':
|
280 |
+
ann_id = ref['ann_id']
|
281 |
+
ann = self.Anns[ann_id]
|
282 |
+
bbox = self.getRefBox(ref['ref_id'])
|
283 |
+
box_plot = Rectangle(
|
284 |
+
(bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
|
285 |
+
ax.add_patch(box_plot)
|
286 |
+
|
287 |
+
def getMask(self, ref):
|
288 |
+
# return mask, area and mask-center
|
289 |
+
ann = self.refToAnn[ref['ref_id']]
|
290 |
+
image = self.Imgs[ref['image_id']]
|
291 |
+
if type(ann['segmentation'][0]) == list: # polygon
|
292 |
+
rle = mask.frPyObjects(
|
293 |
+
ann['segmentation'], image['height'], image['width'])
|
294 |
+
else:
|
295 |
+
rle = ann['segmentation']
|
296 |
+
m = mask.decode(rle)
|
297 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
298 |
+
m = np.sum(m, axis=2)
|
299 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
300 |
+
# compute area
|
301 |
+
area = sum(mask.area(rle)) # should be close to ann['area']
|
302 |
+
return {'mask': m, 'area': area}
|
303 |
+
# # position
|
304 |
+
# position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
|
305 |
+
# position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
|
306 |
+
# # mass position (if there were multiple regions, we use the largest one.)
|
307 |
+
# label_m = label(m, connectivity=m.ndim)
|
308 |
+
# regions = regionprops(label_m)
|
309 |
+
# if len(regions) > 0:
|
310 |
+
# largest_id = np.argmax(np.array([props.filled_area for props in regions]))
|
311 |
+
# largest_props = regions[largest_id]
|
312 |
+
# mass_y, mass_x = largest_props.centroid
|
313 |
+
# else:
|
314 |
+
# mass_x, mass_y = position_x, position_y
|
315 |
+
# # if centroid is not in mask, we find the closest point to it from mask
|
316 |
+
# if m[mass_y, mass_x] != 1:
|
317 |
+
# print 'Finding closes mask point ...'
|
318 |
+
# kernel = np.ones((10, 10),np.uint8)
|
319 |
+
# me = cv2.erode(m, kernel, iterations = 1)
|
320 |
+
# points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
|
321 |
+
# points = np.array(points)
|
322 |
+
# dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
|
323 |
+
# id = np.argsort(dist)[0]
|
324 |
+
# mass_y, mass_x = points[id]
|
325 |
+
# # return
|
326 |
+
# return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
|
327 |
+
# # show image and mask
|
328 |
+
# I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
329 |
+
# plt.figure()
|
330 |
+
# plt.imshow(I)
|
331 |
+
# ax = plt.gca()
|
332 |
+
# img = np.ones( (m.shape[0], m.shape[1], 3) )
|
333 |
+
# color_mask = np.array([2.0,166.0,101.0])/255
|
334 |
+
# for i in range(3):
|
335 |
+
# img[:,:,i] = color_mask[i]
|
336 |
+
# ax.imshow(np.dstack( (img, m*0.5) ))
|
337 |
+
# plt.show()
|
338 |
+
|
339 |
+
def showMask(self, ref):
|
340 |
+
M = self.getMask(ref)
|
341 |
+
msk = M['mask']
|
342 |
+
ax = plt.gca()
|
343 |
+
ax.imshow(msk)
|
344 |
+
|
345 |
+
|
346 |
+
if __name__ == '__main__':
|
347 |
+
refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg',
|
348 |
+
dataset='refcocog', splitBy='google')
|
349 |
+
ref_ids = refer.getRefIds()
|
350 |
+
print(len(ref_ids))
|
351 |
+
|
352 |
+
print(len(refer.Imgs))
|
353 |
+
print(len(refer.imgToRefs))
|
354 |
+
|
355 |
+
ref_ids = refer.getRefIds(split='train')
|
356 |
+
print('There are {} training referred objects.' % len(ref_ids))
|
357 |
+
|
358 |
+
for ref_id in ref_ids:
|
359 |
+
ref = refer.loadRefs(ref_id)[0]
|
360 |
+
if len(ref['sentences']) < 2:
|
361 |
+
continue
|
362 |
+
|
363 |
+
pprint(ref)
|
364 |
+
print('The label is {}.'.format(refer.Cats[ref['category_id']]))
|
365 |
+
|
366 |
+
# plt.figure()
|
367 |
+
# refer.showRef(ref, seg_box='box')
|
368 |
+
# plt.show()
|
369 |
+
|
370 |
+
# plt.figure()
|
371 |
+
# refer.showMask(ref)
|
372 |
+
# plt.show()
|
datasets/visual_sampler/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .sampler import ShapeSampler
|
2 |
+
from .simpleclick_sampler import SimpleClickSampler
|
3 |
+
|
4 |
+
|
5 |
+
def build_shape_sampler(cfg, **kwargs):
|
6 |
+
sampler_name = cfg['STROKE_SAMPLER']['EVAL']['MODE']
|
7 |
+
if sampler_name == 'random':
|
8 |
+
return ShapeSampler(cfg, **kwargs)
|
9 |
+
elif sampler_name in ['best', 'best_random']:
|
10 |
+
return SimpleClickSampler(cfg, **kwargs)
|
11 |
+
else:
|
12 |
+
assert False, "not implemented"
|
datasets/visual_sampler/circle.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from .mask_generators import get_mask_by_input_strokes
|
5 |
+
|
6 |
+
class Circle:
|
7 |
+
def __init__(self, cfg, is_train=True):
|
8 |
+
self.num_stroke = cfg['STROKE_SAMPLER']['CIRCLE']['NUM_STROKES']
|
9 |
+
self.stroke_preset = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PRESET']
|
10 |
+
self.stroke_prob = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PROB']
|
11 |
+
self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
|
12 |
+
self.is_train = is_train
|
13 |
+
|
14 |
+
@staticmethod
|
15 |
+
def get_stroke_preset(stroke_preset):
|
16 |
+
if stroke_preset == 'object_like':
|
17 |
+
return {
|
18 |
+
"nVertexBound": [5, 30],
|
19 |
+
"maxHeadSpeed": 15,
|
20 |
+
"maxHeadAcceleration": (10, 1.5),
|
21 |
+
"brushWidthBound": (20, 50),
|
22 |
+
"nMovePointRatio": 0.5,
|
23 |
+
"maxPiontMove": 10,
|
24 |
+
"maxLineAcceleration": (5, 0.5),
|
25 |
+
"boarderGap": None,
|
26 |
+
"maxInitSpeed": 10,
|
27 |
+
}
|
28 |
+
elif stroke_preset == 'object_like_middle':
|
29 |
+
return {
|
30 |
+
"nVertexBound": [5, 15],
|
31 |
+
"maxHeadSpeed": 8,
|
32 |
+
"maxHeadAcceleration": (4, 1.5),
|
33 |
+
"brushWidthBound": (20, 50),
|
34 |
+
"nMovePointRatio": 0.5,
|
35 |
+
"maxPiontMove": 5,
|
36 |
+
"maxLineAcceleration": (5, 0.5),
|
37 |
+
"boarderGap": None,
|
38 |
+
"maxInitSpeed": 10,
|
39 |
+
}
|
40 |
+
elif stroke_preset == 'object_like_small':
|
41 |
+
return {
|
42 |
+
"nVertexBound": [5, 20],
|
43 |
+
"maxHeadSpeed": 7,
|
44 |
+
"maxHeadAcceleration": (3.5, 1.5),
|
45 |
+
"brushWidthBound": (10, 30),
|
46 |
+
"nMovePointRatio": 0.5,
|
47 |
+
"maxPiontMove": 5,
|
48 |
+
"maxLineAcceleration": (3, 0.5),
|
49 |
+
"boarderGap": None,
|
50 |
+
"maxInitSpeed": 4,
|
51 |
+
}
|
52 |
+
else:
|
53 |
+
raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
|
54 |
+
|
55 |
+
def get_random_points_from_mask(self, mask, n=5):
|
56 |
+
h,w = mask.shape
|
57 |
+
view_mask = mask.reshape(h*w)
|
58 |
+
non_zero_idx = view_mask.nonzero()[:,0]
|
59 |
+
selected_idx = torch.randperm(len(non_zero_idx))[:n]
|
60 |
+
non_zero_idx = non_zero_idx[selected_idx]
|
61 |
+
y = (non_zero_idx // w)*1.0
|
62 |
+
x = (non_zero_idx % w)*1.0
|
63 |
+
return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
|
64 |
+
|
65 |
+
def draw(self, mask=None, box=None):
|
66 |
+
if mask.sum() < 10: # if mask is nearly empty
|
67 |
+
return torch.zeros(mask.shape).bool()
|
68 |
+
if not self.is_train:
|
69 |
+
return self.draw_eval(mask=mask, box=box)
|
70 |
+
stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
|
71 |
+
preset = Circle.get_stroke_preset(stroke_preset_name)
|
72 |
+
nStroke = min(random.randint(1, self.num_stroke), mask.sum().item())
|
73 |
+
h,w = mask.shape
|
74 |
+
points = self.get_random_points_from_mask(mask, n=nStroke)
|
75 |
+
rand_mask = get_mask_by_input_strokes(
|
76 |
+
init_points=points,
|
77 |
+
imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
|
78 |
+
rand_mask = (~torch.from_numpy(rand_mask)) * mask
|
79 |
+
return rand_mask
|
80 |
+
|
81 |
+
def draw_eval(self, mask=None, box=None):
|
82 |
+
stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
|
83 |
+
preset = Circle.get_stroke_preset(stroke_preset_name)
|
84 |
+
nStroke = min(self.max_eval, mask.sum().item())
|
85 |
+
h,w = mask.shape
|
86 |
+
points = self.get_random_points_from_mask(mask, n=nStroke)
|
87 |
+
rand_masks = []
|
88 |
+
for i in range(len(points)):
|
89 |
+
rand_mask = get_mask_by_input_strokes(
|
90 |
+
init_points=points[:i+1],
|
91 |
+
imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points[:i+1])), **preset)
|
92 |
+
rand_masks += [(~torch.from_numpy(rand_mask)) * mask]
|
93 |
+
return torch.stack(rand_masks)
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def draw_by_points(points, mask, h, w):
|
97 |
+
stroke_preset_name = random.choices(['object_like', 'object_like_middle', 'object_like_small'], weights=[0.33,0.33,0.33], k=1)[0] # select which kind of object to use
|
98 |
+
preset = Circle.get_stroke_preset(stroke_preset_name)
|
99 |
+
rand_mask = get_mask_by_input_strokes(
|
100 |
+
init_points=points,
|
101 |
+
imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
|
102 |
+
rand_masks = (~torch.from_numpy(rand_mask)) * mask
|
103 |
+
return rand_masks
|
104 |
+
|
105 |
+
def __repr__(self,):
|
106 |
+
return 'circle'
|
datasets/visual_sampler/mask_generators.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
from PIL import Image, ImageDraw
|
4 |
+
|
5 |
+
|
6 |
+
def get_mask_by_input_strokes(
|
7 |
+
init_points, imageWidth=320, imageHeight=180, nStroke=5,
|
8 |
+
nVertexBound=[10, 30], maxHeadSpeed=15, maxHeadAcceleration=(15, 0.5),
|
9 |
+
brushWidthBound=(5, 20), boarderGap=None, nMovePointRatio=0.5, maxPiontMove=10,
|
10 |
+
maxLineAcceleration=5, maxInitSpeed=5
|
11 |
+
):
|
12 |
+
'''
|
13 |
+
Get video masks by random strokes which move randomly between each
|
14 |
+
frame, including the whole stroke and its control points
|
15 |
+
|
16 |
+
Parameters
|
17 |
+
----------
|
18 |
+
imageWidth: Image width
|
19 |
+
imageHeight: Image height
|
20 |
+
nStroke: Number of drawed lines
|
21 |
+
nVertexBound: Lower/upper bound of number of control points for each line
|
22 |
+
maxHeadSpeed: Max head speed when creating control points
|
23 |
+
maxHeadAcceleration: Max acceleration applying on the current head point (
|
24 |
+
a head point and its velosity decides the next point)
|
25 |
+
brushWidthBound (min, max): Bound of width for each stroke
|
26 |
+
boarderGap: The minimum gap between image boarder and drawed lines
|
27 |
+
nMovePointRatio: The ratio of control points to move for next frames
|
28 |
+
maxPiontMove: The magnitude of movement for control points for next frames
|
29 |
+
maxLineAcceleration: The magnitude of acceleration for the whole line
|
30 |
+
|
31 |
+
Examples
|
32 |
+
----------
|
33 |
+
object_like_setting = {
|
34 |
+
"nVertexBound": [5, 20],
|
35 |
+
"maxHeadSpeed": 15,
|
36 |
+
"maxHeadAcceleration": (15, 3.14),
|
37 |
+
"brushWidthBound": (30, 50),
|
38 |
+
"nMovePointRatio": 0.5,
|
39 |
+
"maxPiontMove": 10,
|
40 |
+
"maxLineAcceleration": (5, 0.5),
|
41 |
+
"boarderGap": 20,
|
42 |
+
"maxInitSpeed": 10,
|
43 |
+
}
|
44 |
+
rand_curve_setting = {
|
45 |
+
"nVertexBound": [10, 30],
|
46 |
+
"maxHeadSpeed": 20,
|
47 |
+
"maxHeadAcceleration": (15, 0.5),
|
48 |
+
"brushWidthBound": (3, 10),
|
49 |
+
"nMovePointRatio": 0.5,
|
50 |
+
"maxPiontMove": 3,
|
51 |
+
"maxLineAcceleration": (5, 0.5),
|
52 |
+
"boarderGap": 20,
|
53 |
+
"maxInitSpeed": 6
|
54 |
+
}
|
55 |
+
get_video_masks_by_moving_random_stroke(video_len=5, nStroke=3, **object_like_setting)
|
56 |
+
'''
|
57 |
+
# Initilize a set of control points to draw the first mask
|
58 |
+
mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
|
59 |
+
control_points_set = []
|
60 |
+
for i in range(nStroke):
|
61 |
+
brushWidth = np.random.randint(brushWidthBound[0], brushWidthBound[1])
|
62 |
+
Xs, Ys, velocity = get_random_stroke_control_points(
|
63 |
+
init_point=init_points[i],
|
64 |
+
imageWidth=imageWidth, imageHeight=imageHeight,
|
65 |
+
nVertexBound=nVertexBound, maxHeadSpeed=maxHeadSpeed,
|
66 |
+
maxHeadAcceleration=maxHeadAcceleration, boarderGap=boarderGap,
|
67 |
+
maxInitSpeed=maxInitSpeed
|
68 |
+
)
|
69 |
+
control_points_set.append((Xs, Ys, velocity, brushWidth))
|
70 |
+
draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
|
71 |
+
|
72 |
+
# Generate the following masks by randomly move strokes and their control points
|
73 |
+
mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
|
74 |
+
for j in range(len(control_points_set)):
|
75 |
+
Xs, Ys, velocity, brushWidth = control_points_set[j]
|
76 |
+
new_Xs, new_Ys = random_move_control_points(
|
77 |
+
Xs, Ys, velocity, nMovePointRatio, maxPiontMove,
|
78 |
+
maxLineAcceleration, boarderGap
|
79 |
+
)
|
80 |
+
control_points_set[j] = (new_Xs, new_Ys, velocity, brushWidth)
|
81 |
+
for Xs, Ys, velocity, brushWidth in control_points_set:
|
82 |
+
draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
|
83 |
+
|
84 |
+
return np.array(mask)
|
85 |
+
|
86 |
+
|
87 |
+
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
|
88 |
+
speed, angle = velocity
|
89 |
+
d_speed, d_angle = maxAcceleration
|
90 |
+
|
91 |
+
if dist == 'uniform':
|
92 |
+
speed += np.random.uniform(-d_speed, d_speed)
|
93 |
+
angle += np.random.uniform(-d_angle, d_angle)
|
94 |
+
elif dist == 'guassian':
|
95 |
+
speed += np.random.normal(0, d_speed / 2)
|
96 |
+
angle += np.random.normal(0, d_angle / 2)
|
97 |
+
else:
|
98 |
+
raise NotImplementedError(f'Distribution type {dist} is not supported.')
|
99 |
+
|
100 |
+
return (speed, angle)
|
101 |
+
|
102 |
+
|
103 |
+
def random_move_control_points(Xs, Ys, lineVelocity, nMovePointRatio, maxPiontMove, maxLineAcceleration, boarderGap=15):
|
104 |
+
new_Xs = Xs.copy()
|
105 |
+
new_Ys = Ys.copy()
|
106 |
+
|
107 |
+
# move the whole line and accelerate
|
108 |
+
speed, angle = lineVelocity
|
109 |
+
new_Xs += int(speed * np.cos(angle))
|
110 |
+
new_Ys += int(speed * np.sin(angle))
|
111 |
+
lineVelocity = random_accelerate(lineVelocity, maxLineAcceleration, dist='guassian')
|
112 |
+
|
113 |
+
# choose points to move
|
114 |
+
chosen = np.arange(len(Xs))
|
115 |
+
np.random.shuffle(chosen)
|
116 |
+
chosen = chosen[:int(len(Xs) * nMovePointRatio)]
|
117 |
+
for i in chosen:
|
118 |
+
new_Xs[i] += np.random.randint(-maxPiontMove, maxPiontMove)
|
119 |
+
new_Ys[i] += np.random.randint(-maxPiontMove, maxPiontMove)
|
120 |
+
return new_Xs, new_Ys
|
121 |
+
|
122 |
+
|
123 |
+
def get_random_stroke_control_points(
|
124 |
+
init_point,
|
125 |
+
imageWidth, imageHeight,
|
126 |
+
nVertexBound=(10, 30), maxHeadSpeed=10, maxHeadAcceleration=(5, 0.5), boarderGap=20,
|
127 |
+
maxInitSpeed=10
|
128 |
+
):
|
129 |
+
'''
|
130 |
+
Implementation the free-form training masks generating algorithm
|
131 |
+
proposed by JIAHUI YU et al. in "Free-Form Image Inpainting with Gated Convolution"
|
132 |
+
'''
|
133 |
+
startX = init_point[0]
|
134 |
+
startY = init_point[1]
|
135 |
+
|
136 |
+
Xs = [init_point[0]]
|
137 |
+
Ys = [init_point[1]]
|
138 |
+
|
139 |
+
numVertex = np.random.randint(nVertexBound[0], nVertexBound[1])
|
140 |
+
|
141 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
142 |
+
speed = np.random.uniform(0, maxHeadSpeed)
|
143 |
+
|
144 |
+
for i in range(numVertex):
|
145 |
+
speed, angle = random_accelerate((speed, angle), maxHeadAcceleration)
|
146 |
+
speed = np.clip(speed, 0, maxHeadSpeed)
|
147 |
+
|
148 |
+
nextX = startX + speed * np.sin(angle)
|
149 |
+
nextY = startY + speed * np.cos(angle)
|
150 |
+
|
151 |
+
if boarderGap is not None:
|
152 |
+
nextX = np.clip(nextX, boarderGap, imageWidth - boarderGap)
|
153 |
+
nextY = np.clip(nextY, boarderGap, imageHeight - boarderGap)
|
154 |
+
|
155 |
+
startX, startY = nextX, nextY
|
156 |
+
Xs.append(nextX)
|
157 |
+
Ys.append(nextY)
|
158 |
+
|
159 |
+
velocity = get_random_velocity(maxInitSpeed, dist='guassian')
|
160 |
+
|
161 |
+
return np.array(Xs), np.array(Ys), velocity
|
162 |
+
|
163 |
+
|
164 |
+
def get_random_velocity(max_speed, dist='uniform'):
|
165 |
+
if dist == 'uniform':
|
166 |
+
speed = np.random.uniform(max_speed)
|
167 |
+
elif dist == 'guassian':
|
168 |
+
speed = np.abs(np.random.normal(0, max_speed / 2))
|
169 |
+
else:
|
170 |
+
raise NotImplementedError(f'Distribution type {dist} is not supported.')
|
171 |
+
|
172 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
173 |
+
return (speed, angle)
|
174 |
+
|
175 |
+
|
176 |
+
def draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=255):
|
177 |
+
radius = brushWidth // 2 - 1
|
178 |
+
for i in range(1, len(Xs)):
|
179 |
+
draw = ImageDraw.Draw(mask)
|
180 |
+
startX, startY = Xs[i - 1], Ys[i - 1]
|
181 |
+
nextX, nextY = Xs[i], Ys[i]
|
182 |
+
draw.line((startX, startY) + (nextX, nextY), fill=fill, width=brushWidth)
|
183 |
+
for x, y in zip(Xs, Ys):
|
184 |
+
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=fill)
|
185 |
+
return mask
|
186 |
+
|
187 |
+
|
188 |
+
# modified from https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/generate_data.py
|
189 |
+
def get_random_walk_mask(imageWidth=320, imageHeight=180, length=None):
|
190 |
+
action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
|
191 |
+
canvas = np.zeros((imageHeight, imageWidth)).astype("i")
|
192 |
+
if length is None:
|
193 |
+
length = imageWidth * imageHeight
|
194 |
+
x = random.randint(0, imageHeight - 1)
|
195 |
+
y = random.randint(0, imageWidth - 1)
|
196 |
+
x_list = []
|
197 |
+
y_list = []
|
198 |
+
for i in range(length):
|
199 |
+
r = random.randint(0, len(action_list) - 1)
|
200 |
+
x = np.clip(x + action_list[r][0], a_min=0, a_max=imageHeight - 1)
|
201 |
+
y = np.clip(y + action_list[r][1], a_min=0, a_max=imageWidth - 1)
|
202 |
+
x_list.append(x)
|
203 |
+
y_list.append(y)
|
204 |
+
canvas[np.array(x_list), np.array(y_list)] = 1
|
205 |
+
return Image.fromarray(canvas * 255).convert('1')
|
206 |
+
|
207 |
+
|
208 |
+
def get_masked_ratio(mask):
|
209 |
+
"""
|
210 |
+
Calculate the masked ratio.
|
211 |
+
mask: Expected a binary PIL image, where 0 and 1 represent
|
212 |
+
masked(invalid) and valid pixel values.
|
213 |
+
"""
|
214 |
+
hist = mask.histogram()
|
215 |
+
return hist[0] / np.prod(mask.size)
|
datasets/visual_sampler/point.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from scipy import ndimage
|
6 |
+
|
7 |
+
|
8 |
+
class Point:
|
9 |
+
def __init__(self, cfg, is_train=True):
|
10 |
+
self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']
|
11 |
+
self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
|
12 |
+
self.is_train = is_train
|
13 |
+
|
14 |
+
def draw(self, mask=None, box=None):
|
15 |
+
if mask.sum() < 10:
|
16 |
+
return torch.zeros(mask.shape).bool() # if mask is empty
|
17 |
+
if not self.is_train:
|
18 |
+
return self.draw_eval(mask=mask, box=box)
|
19 |
+
max_points = min(self.max_points, mask.sum().item()) # max number of points no more than total mask number
|
20 |
+
num_points = random.randint(1, max_points) # get a random number of points
|
21 |
+
h,w = mask.shape
|
22 |
+
view_mask = mask.view(-1)
|
23 |
+
non_zero_idx = view_mask.nonzero()[:,0] # get non-zero index of mask
|
24 |
+
selected_idx = torch.randperm(len(non_zero_idx))[:num_points] # select id
|
25 |
+
non_zero_idx = non_zero_idx[selected_idx] # select non-zero index
|
26 |
+
rand_mask = torch.zeros(view_mask.shape).bool() # init rand mask
|
27 |
+
rand_mask[non_zero_idx] = True # get non zero place to zero
|
28 |
+
# dilate
|
29 |
+
# struct = ndimage.generate_binary_structure(2, 2)
|
30 |
+
# rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
|
31 |
+
# return rand_mask
|
32 |
+
return rand_mask.reshape(h, w)
|
33 |
+
|
34 |
+
def draw_eval(self, mask=None, box=None):
|
35 |
+
background = ~mask
|
36 |
+
neg_num = min(self.max_eval // 2, background.sum().item())
|
37 |
+
pos_num = min(self.max_eval - neg_num, mask.sum().item()-1) + 1
|
38 |
+
|
39 |
+
h,w = mask.shape
|
40 |
+
view_mask = mask.view(-1)
|
41 |
+
non_zero_idx_pos = view_mask.nonzero()[:,0] # get non-zero index of mask
|
42 |
+
selected_idx_pos = torch.randperm(len(non_zero_idx_pos))[:pos_num] # select id
|
43 |
+
non_zero_idx_pos = non_zero_idx_pos[selected_idx_pos] # select non-zero index
|
44 |
+
pos_idx = torch.ones(non_zero_idx_pos.shape)
|
45 |
+
|
46 |
+
view_background = background.view(-1)
|
47 |
+
non_zero_idx_neg = view_background.nonzero()[:,0] # get non-zero index of mask
|
48 |
+
selected_idx_neg = torch.randperm(len(non_zero_idx_neg))[:neg_num] # select id
|
49 |
+
non_zero_idx_neg = non_zero_idx_neg[selected_idx_neg] # select non-zero index
|
50 |
+
neg_idx = torch.ones(non_zero_idx_neg.shape) * -1
|
51 |
+
|
52 |
+
non_zero_idx = torch.cat([non_zero_idx_pos, non_zero_idx_neg])
|
53 |
+
idx = torch.cat([pos_idx, neg_idx])
|
54 |
+
rand_idx = torch.cat([torch.zeros(1), torch.randperm(len(non_zero_idx)-1) + 1]).long()
|
55 |
+
non_zero_idx = non_zero_idx[rand_idx]
|
56 |
+
idx = idx[rand_idx]
|
57 |
+
|
58 |
+
rand_masks = []
|
59 |
+
for i in range(0, len(non_zero_idx)):
|
60 |
+
rand_mask = torch.zeros(view_mask.shape) # init rand mask
|
61 |
+
rand_mask[non_zero_idx[0:i+1]] = idx[0:i+1] # get non zero place to zero
|
62 |
+
# struct = ndimage.generate_binary_structure(2, 2)
|
63 |
+
# rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
|
64 |
+
rand_masks += [rand_mask.reshape(h, w)]
|
65 |
+
|
66 |
+
# kernel_size = 3
|
67 |
+
rand_masks = torch.stack(rand_masks)
|
68 |
+
# rand_masks = F.conv2d(rand_masks[:,None], torch.ones(1,1,kernel_size,kernel_size), padding=kernel_size//2)[:,0]
|
69 |
+
# rand_masks[rand_masks>0] = 1
|
70 |
+
# rand_masks[rand_masks<0] = -1
|
71 |
+
return rand_masks
|
72 |
+
|
73 |
+
def __repr__(self,):
|
74 |
+
return 'point'
|
datasets/visual_sampler/polygon.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from scipy.special import binom
|
6 |
+
from scipy import ndimage
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
9 |
+
|
10 |
+
bernstein = lambda n, k, t: binom(n,k)* t**k * (1.-t)**(n-k)
|
11 |
+
|
12 |
+
def bezier(points, num=200):
|
13 |
+
N = len(points)
|
14 |
+
t = np.linspace(0, 1, num=num)
|
15 |
+
curve = np.zeros((num, 2))
|
16 |
+
for i in range(N):
|
17 |
+
curve += np.outer(bernstein(N - 1, i, t), points[i])
|
18 |
+
return curve
|
19 |
+
|
20 |
+
class Segment():
|
21 |
+
def __init__(self, p1, p2, angle1, angle2, **kw):
|
22 |
+
self.p1 = p1; self.p2 = p2
|
23 |
+
self.angle1 = angle1; self.angle2 = angle2
|
24 |
+
self.numpoints = kw.get("numpoints", 100)
|
25 |
+
r = kw.get("r", 0.3)
|
26 |
+
d = np.sqrt(np.sum((self.p2-self.p1)**2))
|
27 |
+
self.r = r*d
|
28 |
+
self.p = np.zeros((4,2))
|
29 |
+
self.p[0,:] = self.p1[:]
|
30 |
+
self.p[3,:] = self.p2[:]
|
31 |
+
self.calc_intermediate_points(self.r)
|
32 |
+
|
33 |
+
def calc_intermediate_points(self,r):
|
34 |
+
self.p[1,:] = self.p1 + np.array([self.r*np.cos(self.angle1),
|
35 |
+
self.r*np.sin(self.angle1)])
|
36 |
+
self.p[2,:] = self.p2 + np.array([self.r*np.cos(self.angle2+np.pi),
|
37 |
+
self.r*np.sin(self.angle2+np.pi)])
|
38 |
+
self.curve = bezier(self.p,self.numpoints)
|
39 |
+
|
40 |
+
def get_curve(points, **kw):
|
41 |
+
segments = []
|
42 |
+
for i in range(len(points)-1):
|
43 |
+
seg = Segment(points[i,:2], points[i+1,:2], points[i,2],points[i+1,2],**kw)
|
44 |
+
segments.append(seg)
|
45 |
+
curve = np.concatenate([s.curve for s in segments])
|
46 |
+
return segments, curve
|
47 |
+
|
48 |
+
def ccw_sort(p):
|
49 |
+
d = p-np.mean(p,axis=0)
|
50 |
+
s = np.arctan2(d[:,0], d[:,1])
|
51 |
+
return p[np.argsort(s),:]
|
52 |
+
|
53 |
+
def get_bezier_curve(a, rad=0.2, edgy=0):
|
54 |
+
""" given an array of points *a*, create a curve through
|
55 |
+
those points.
|
56 |
+
*rad* is a number between 0 and 1 to steer the distance of
|
57 |
+
control points.
|
58 |
+
*edgy* is a parameter which controls how "edgy" the curve is,
|
59 |
+
edgy=0 is smoothest."""
|
60 |
+
p = np.arctan(edgy)/np.pi+.5
|
61 |
+
a = ccw_sort(a)
|
62 |
+
a = np.append(a, np.atleast_2d(a[0,:]), axis=0)
|
63 |
+
d = np.diff(a, axis=0)
|
64 |
+
ang = np.arctan2(d[:,1],d[:,0])
|
65 |
+
f = lambda ang : (ang>=0)*ang + (ang<0)*(ang+2*np.pi)
|
66 |
+
ang = f(ang)
|
67 |
+
ang1 = ang
|
68 |
+
ang2 = np.roll(ang,1)
|
69 |
+
ang = p*ang1 + (1-p)*ang2 + (np.abs(ang2-ang1) > np.pi )*np.pi
|
70 |
+
ang = np.append(ang, [ang[0]])
|
71 |
+
a = np.append(a, np.atleast_2d(ang).T, axis=1)
|
72 |
+
s, c = get_curve(a, r=rad, method="var")
|
73 |
+
x,y = c.T
|
74 |
+
return x,y,a
|
75 |
+
|
76 |
+
class Polygon:
|
77 |
+
def __init__(self, cfg, is_train):
|
78 |
+
self.max_points = cfg['STROKE_SAMPLER']['POLYGON']['MAX_POINTS']
|
79 |
+
self.eval_points = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
|
80 |
+
self.is_train = is_train
|
81 |
+
|
82 |
+
def get_random_points_from_mask(self, mask, n=3):
|
83 |
+
h,w = mask.shape
|
84 |
+
view_mask = mask.reshape(h*w)
|
85 |
+
non_zero_idx = view_mask.nonzero()[:,0]
|
86 |
+
selected_idx = torch.randperm(len(non_zero_idx))[:n]
|
87 |
+
non_zero_idx = non_zero_idx[selected_idx]
|
88 |
+
y = (non_zero_idx // w)*1.0/(h+1)
|
89 |
+
x = (non_zero_idx % w)*1.0/(w+1)
|
90 |
+
return torch.cat((x[:,None],y[:,None]), dim=1).numpy()
|
91 |
+
|
92 |
+
def draw(self, mask=None, box=None):
|
93 |
+
if mask.sum() < 10:
|
94 |
+
return torch.zeros(mask.shape).bool() # if mask is empty
|
95 |
+
if not self.is_train:
|
96 |
+
return self.draw_eval(mask=mask, box=box)
|
97 |
+
# box: x1,y1,x2,y2
|
98 |
+
x1,y1,x2,y2 = box.int().unbind()
|
99 |
+
rad = 0.2
|
100 |
+
edgy = 0.05
|
101 |
+
num_points = random.randint(1, min(self.max_points, mask.sum().item()))
|
102 |
+
a = self.get_random_points_from_mask(mask[y1:y2,x1:x2], n=num_points)
|
103 |
+
x,y, _ = get_bezier_curve(a,rad=rad, edgy=edgy)
|
104 |
+
x = x.clip(0.0, 1.0)
|
105 |
+
y = y.clip(0.0, 1.0)
|
106 |
+
points = torch.from_numpy(np.concatenate((y[None,]*(y2-y1-1).item(),x[None,]*(x2-x1-1).item()))).int()
|
107 |
+
canvas = torch.zeros((y2-y1, x2-x1))
|
108 |
+
canvas[points.long().tolist()] = 1
|
109 |
+
rand_mask = torch.zeros(mask.shape)
|
110 |
+
rand_mask[y1:y2,x1:x2] = canvas
|
111 |
+
return rand_mask.bool()
|
112 |
+
|
113 |
+
def draw_eval(self, mask=None, box=None):
|
114 |
+
# box: x1,y1,x2,y2
|
115 |
+
x1,y1,x2,y2 = box.int().unbind()
|
116 |
+
rad = 0.2
|
117 |
+
edgy = 0.05
|
118 |
+
num_points = min(self.eval_points, mask.sum().item())
|
119 |
+
a = self.get_random_points_from_mask(mask[y1:y2,x1:x2], n=num_points)
|
120 |
+
rand_masks = []
|
121 |
+
for i in range(len(a)):
|
122 |
+
x,y, _ = get_bezier_curve(a[:i+1],rad=rad, edgy=edgy)
|
123 |
+
x = x.clip(0.0, 1.0)
|
124 |
+
y = y.clip(0.0, 1.0)
|
125 |
+
points = torch.from_numpy(np.concatenate((y[None,]*(y2-y1-1).item(),x[None,]*(x2-x1-1).item()))).int()
|
126 |
+
canvas = torch.zeros((y2-y1, x2-x1))
|
127 |
+
canvas[points.long().tolist()] = 1
|
128 |
+
rand_mask = torch.zeros(mask.shape)
|
129 |
+
rand_mask[y1:y2,x1:x2] = canvas
|
130 |
+
|
131 |
+
struct = ndimage.generate_binary_structure(2, 2)
|
132 |
+
rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask, structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
|
133 |
+
rand_masks += [rand_mask.bool()]
|
134 |
+
return torch.stack(rand_masks)
|
135 |
+
|
136 |
+
def __repr__(self,):
|
137 |
+
return 'polygon'
|
datasets/visual_sampler/sampler.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from .point import Point
|
8 |
+
from .polygon import Polygon
|
9 |
+
from .scribble import Scribble
|
10 |
+
from .circle import Circle
|
11 |
+
|
12 |
+
from modeling.utils import configurable
|
13 |
+
|
14 |
+
|
15 |
+
class ShapeSampler(nn.Module):
|
16 |
+
@configurable
|
17 |
+
def __init__(self, max_candidate=1, shape_prob=[], shape_candidate=[], is_train=True):
|
18 |
+
super().__init__()
|
19 |
+
self.max_candidate = max_candidate
|
20 |
+
self.shape_prob = shape_prob
|
21 |
+
self.shape_candidate = shape_candidate
|
22 |
+
self.is_train = is_train
|
23 |
+
|
24 |
+
@classmethod
|
25 |
+
def from_config(cls, cfg, is_train=True, mode=None):
|
26 |
+
max_candidate = cfg['STROKE_SAMPLER']['MAX_CANDIDATE']
|
27 |
+
candidate_probs = cfg['STROKE_SAMPLER']['CANDIDATE_PROBS']
|
28 |
+
candidate_names = cfg['STROKE_SAMPLER']['CANDIDATE_NAMES']
|
29 |
+
|
30 |
+
if mode == 'hack_train':
|
31 |
+
candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, True) for class_name in candidate_names]
|
32 |
+
else:
|
33 |
+
# overwrite condidate_prob
|
34 |
+
if not is_train:
|
35 |
+
candidate_probs = [0.0 for x in range(len(candidate_names))]
|
36 |
+
candidate_probs[candidate_names.index(mode)] = 1.0
|
37 |
+
candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, is_train) for class_name in candidate_names]
|
38 |
+
|
39 |
+
# Build augmentation
|
40 |
+
return {
|
41 |
+
"max_candidate": max_candidate,
|
42 |
+
"shape_prob": candidate_probs,
|
43 |
+
"shape_candidate": candidate_classes,
|
44 |
+
"is_train": is_train,
|
45 |
+
}
|
46 |
+
|
47 |
+
def forward(self, instances):
|
48 |
+
masks = instances.gt_masks.tensor
|
49 |
+
boxes = instances.gt_boxes.tensor
|
50 |
+
|
51 |
+
if len(masks) == 0:
|
52 |
+
gt_masks = torch.zeros(masks.shape[-2:]).bool()
|
53 |
+
rand_masks = torch.zeros(masks.shape[-2:]).bool()
|
54 |
+
return {'gt_masks': gt_masks[None,:], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']}
|
55 |
+
indices = [x for x in range(len(masks))]
|
56 |
+
|
57 |
+
if self.is_train:
|
58 |
+
random.shuffle(indices)
|
59 |
+
candidate_mask = masks[indices[:self.max_candidate]]
|
60 |
+
candidate_box = boxes[indices[:self.max_candidate]]
|
61 |
+
else:
|
62 |
+
candidate_mask = masks
|
63 |
+
candidate_box = boxes
|
64 |
+
|
65 |
+
draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
|
66 |
+
rand_shapes = [d.draw(x,y) for d,x,y in zip(draw_funcs, candidate_mask, candidate_box)]
|
67 |
+
types = [repr(x) for x in draw_funcs]
|
68 |
+
for i in range(0, len(rand_shapes)):
|
69 |
+
if rand_shapes[i].sum() == 0:
|
70 |
+
candidate_mask[i] = candidate_mask[i] * 0
|
71 |
+
types[i] = 'none'
|
72 |
+
|
73 |
+
# candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
|
74 |
+
return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
|
75 |
+
|
76 |
+
def build_shape_sampler(cfg, **kwargs):
|
77 |
+
return ShapeSampler(cfg, **kwargs)
|
datasets/visual_sampler/scribble.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .mask_generators import get_mask_by_input_strokes
|
6 |
+
|
7 |
+
class Scribble:
|
8 |
+
def __init__(self, cfg, is_train):
|
9 |
+
self.num_stroke = cfg['STROKE_SAMPLER']['SCRIBBLE']['NUM_STROKES']
|
10 |
+
self.stroke_preset = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PRESET']
|
11 |
+
self.stroke_prob = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PROB']
|
12 |
+
self.eval_stroke = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
|
13 |
+
self.is_train = is_train
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
def get_stroke_preset(stroke_preset):
|
17 |
+
if stroke_preset == 'rand_curve':
|
18 |
+
return {
|
19 |
+
"nVertexBound": [10, 30],
|
20 |
+
"maxHeadSpeed": 20,
|
21 |
+
"maxHeadAcceleration": (15, 0.5),
|
22 |
+
"brushWidthBound": (3, 10),
|
23 |
+
"nMovePointRatio": 0.5,
|
24 |
+
"maxPiontMove": 3,
|
25 |
+
"maxLineAcceleration": (5, 0.5),
|
26 |
+
"boarderGap": None,
|
27 |
+
"maxInitSpeed": 6
|
28 |
+
}
|
29 |
+
elif stroke_preset == 'rand_curve_small':
|
30 |
+
return {
|
31 |
+
"nVertexBound": [6, 22],
|
32 |
+
"maxHeadSpeed": 12,
|
33 |
+
"maxHeadAcceleration": (8, 0.5),
|
34 |
+
"brushWidthBound": (2.5, 5),
|
35 |
+
"nMovePointRatio": 0.5,
|
36 |
+
"maxPiontMove": 1.5,
|
37 |
+
"maxLineAcceleration": (3, 0.5),
|
38 |
+
"boarderGap": None,
|
39 |
+
"maxInitSpeed": 3
|
40 |
+
}
|
41 |
+
else:
|
42 |
+
raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
|
43 |
+
|
44 |
+
def get_random_points_from_mask(self, mask, n=5):
|
45 |
+
h,w = mask.shape
|
46 |
+
view_mask = mask.reshape(h*w)
|
47 |
+
non_zero_idx = view_mask.nonzero()[:,0]
|
48 |
+
selected_idx = torch.randperm(len(non_zero_idx))[:n]
|
49 |
+
non_zero_idx = non_zero_idx[selected_idx]
|
50 |
+
y = (non_zero_idx // w)*1.0
|
51 |
+
x = (non_zero_idx % w)*1.0
|
52 |
+
return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
|
53 |
+
|
54 |
+
def draw(self, mask=None, box=None):
|
55 |
+
if mask.sum() < 10:
|
56 |
+
return torch.zeros(mask.shape).bool() # if mask is empty
|
57 |
+
if not self.is_train:
|
58 |
+
return self.draw_eval(mask=mask, box=box)
|
59 |
+
stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
|
60 |
+
preset = Scribble.get_stroke_preset(stroke_preset_name)
|
61 |
+
nStroke = random.randint(1, min(self.num_stroke, mask.sum().item()))
|
62 |
+
h,w = mask.shape
|
63 |
+
points = self.get_random_points_from_mask(mask, n=nStroke)
|
64 |
+
rand_mask = get_mask_by_input_strokes(
|
65 |
+
init_points=points,
|
66 |
+
imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
|
67 |
+
rand_mask = (~torch.from_numpy(rand_mask)) * mask
|
68 |
+
return rand_mask
|
69 |
+
|
70 |
+
def draw_eval(self, mask=None, box=None):
|
71 |
+
stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
|
72 |
+
preset = Scribble.get_stroke_preset(stroke_preset_name)
|
73 |
+
nStroke = min(self.eval_stroke, mask.sum().item())
|
74 |
+
h,w = mask.shape
|
75 |
+
points = self.get_random_points_from_mask(mask, n=nStroke)
|
76 |
+
rand_masks = []
|
77 |
+
for i in range(len(points)):
|
78 |
+
rand_mask = get_mask_by_input_strokes(
|
79 |
+
init_points=points[:i+1],
|
80 |
+
imageWidth=w, imageHeight=h, nStroke=min(i, len(points)), **preset)
|
81 |
+
rand_mask = (~torch.from_numpy(rand_mask)) * mask
|
82 |
+
rand_masks += [rand_mask]
|
83 |
+
return torch.stack(rand_masks)
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def draw_by_points(points, mask, h, w):
|
87 |
+
stroke_preset_name = random.choices(['rand_curve', 'rand_curve_small'], weights=[0.5, 0.5], k=1)[0]
|
88 |
+
preset = Scribble.get_stroke_preset(stroke_preset_name)
|
89 |
+
rand_mask = get_mask_by_input_strokes(
|
90 |
+
init_points=points,
|
91 |
+
imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
|
92 |
+
rand_masks = (~torch.from_numpy(rand_mask)) * mask
|
93 |
+
return rand_masks
|
94 |
+
|
95 |
+
def __repr__(self,):
|
96 |
+
return 'scribble'
|
datasets/visual_sampler/simpleclick_sampler.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import random
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from scipy import ndimage
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from kornia.contrib import distance_transform
|
11 |
+
|
12 |
+
from .point import Point
|
13 |
+
from .polygon import Polygon, get_bezier_curve
|
14 |
+
from .scribble import Scribble
|
15 |
+
from .circle import Circle
|
16 |
+
|
17 |
+
from modeling.utils import configurable
|
18 |
+
|
19 |
+
|
20 |
+
class SimpleClickSampler(nn.Module):
|
21 |
+
@configurable
|
22 |
+
def __init__(self, mask_mode='point', sample_negtive=False, is_train=True, dilation=None, dilation_kernel=None, max_points=None):
|
23 |
+
super().__init__()
|
24 |
+
self.mask_mode = mask_mode
|
25 |
+
self.sample_negtive = sample_negtive
|
26 |
+
self.is_train = is_train
|
27 |
+
self.dilation = dilation
|
28 |
+
self.register_buffer("dilation_kernel", dilation_kernel)
|
29 |
+
self.max_points = max_points
|
30 |
+
|
31 |
+
@classmethod
|
32 |
+
def from_config(cls, cfg, is_train=True, mode=None):
|
33 |
+
mask_mode = mode
|
34 |
+
sample_negtive = cfg['STROKE_SAMPLER']['EVAL']['NEGATIVE']
|
35 |
+
|
36 |
+
dilation = cfg['STROKE_SAMPLER']['DILATION']
|
37 |
+
dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
|
38 |
+
|
39 |
+
max_points = cfg['STROKE_SAMPLER']['POLYGON']['MAX_POINTS']
|
40 |
+
|
41 |
+
# Build augmentation
|
42 |
+
return {
|
43 |
+
"mask_mode": mask_mode,
|
44 |
+
"sample_negtive": sample_negtive,
|
45 |
+
"is_train": is_train,
|
46 |
+
"dilation": dilation,
|
47 |
+
"dilation_kernel": dilation_kernel,
|
48 |
+
"max_points": max_points,
|
49 |
+
}
|
50 |
+
|
51 |
+
def forward_point(self, instances, pred_masks=None, prev_masks=None):
|
52 |
+
gt_masks = instances.gt_masks.tensor
|
53 |
+
n,h,w = gt_masks.shape
|
54 |
+
|
55 |
+
# We only consider positive points
|
56 |
+
pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
|
57 |
+
prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
|
58 |
+
|
59 |
+
if not gt_masks.is_cuda:
|
60 |
+
gt_masks = gt_masks.to(pred_masks.device)
|
61 |
+
|
62 |
+
fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
|
63 |
+
|
64 |
+
# conv implementation
|
65 |
+
mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
|
66 |
+
max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
|
67 |
+
next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
|
68 |
+
next_mask = next_mask.view(n,-1)
|
69 |
+
|
70 |
+
next_mask[max_xy_idx] = True
|
71 |
+
next_mask = next_mask.reshape((n,h,w)).float()
|
72 |
+
next_mask = F.conv2d(next_mask[None,], self.dilation_kernel.repeat(len(next_mask),1,1,1), padding=self.dilation//2, groups=len(next_mask))[0] > 0
|
73 |
+
# end conv implementation
|
74 |
+
|
75 |
+
# disk implementation
|
76 |
+
# mask_dt = distance_transform((~fp)[None,].float())[0].view(n,-1)
|
77 |
+
# max_xy = mask_dt.max(dim=-1)[1]
|
78 |
+
# max_y, max_x = max_xy//w, max_xy%w
|
79 |
+
# max_xy_idx = torch.stack([max_y, max_x]).transpose(0,1)[:,:,None,None]
|
80 |
+
# y_idx = torch.arange(start=0, end=h, step=1, dtype=torch.float32, device=torch.cuda.current_device())
|
81 |
+
# x_idx = torch.arange(start=0, end=w, step=1, dtype=torch.float32, device=torch.cuda.current_device())
|
82 |
+
# coord_y, coord_x = torch.meshgrid(y_idx, x_idx)
|
83 |
+
# coords = torch.stack((coord_y, coord_x), dim=0).unsqueeze(0).repeat(len(max_xy_idx),1,1,1) # [bsx2,2,h,w], corresponding to 2d coordinate
|
84 |
+
# coords.add_(-max_xy_idx)
|
85 |
+
# coords.mul_(coords)
|
86 |
+
# next_mask = coords[:, 0] + coords[:, 1]
|
87 |
+
# next_mask = (next_mask <= 5**2)
|
88 |
+
# end disk implementation
|
89 |
+
|
90 |
+
rand_shapes = prev_masks | next_mask
|
91 |
+
|
92 |
+
types = ['point' for i in range(len(gt_masks))]
|
93 |
+
return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
|
94 |
+
|
95 |
+
def forward_circle(self, instances, pred_masks=None, prev_masks=None):
|
96 |
+
gt_masks = instances.gt_masks.tensor
|
97 |
+
n,h,w = gt_masks.shape
|
98 |
+
|
99 |
+
# We only consider positive points
|
100 |
+
pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
|
101 |
+
prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
|
102 |
+
|
103 |
+
if not gt_masks.is_cuda:
|
104 |
+
gt_masks = gt_masks.to(pred_masks.device)
|
105 |
+
|
106 |
+
fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
|
107 |
+
|
108 |
+
# conv implementation
|
109 |
+
mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
|
110 |
+
max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
|
111 |
+
next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
|
112 |
+
next_mask = next_mask.view(n,-1)
|
113 |
+
|
114 |
+
next_mask[max_xy_idx] = True
|
115 |
+
next_mask = next_mask.reshape((n,h,w)).float()
|
116 |
+
|
117 |
+
_next_mask = []
|
118 |
+
for idx in range(len(next_mask)):
|
119 |
+
points = next_mask[idx].nonzero().flip(dims=[-1]).cpu().numpy()
|
120 |
+
_next_mask += [Circle.draw_by_points(points, gt_masks[idx:idx+1].cpu(), h, w)]
|
121 |
+
next_mask = torch.cat(_next_mask, dim=0).bool()
|
122 |
+
rand_shapes = prev_masks | next_mask
|
123 |
+
|
124 |
+
types = ['circle' for i in range(len(gt_masks))]
|
125 |
+
return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
|
126 |
+
|
127 |
+
def forward_scribble(self, instances, pred_masks=None, prev_masks=None):
|
128 |
+
gt_masks = instances.gt_masks.tensor
|
129 |
+
n,h,w = gt_masks.shape
|
130 |
+
|
131 |
+
# We only consider positive points
|
132 |
+
pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
|
133 |
+
prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
|
134 |
+
|
135 |
+
if not gt_masks.is_cuda:
|
136 |
+
gt_masks = gt_masks.to(pred_masks.device)
|
137 |
+
|
138 |
+
fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
|
139 |
+
|
140 |
+
# conv implementation
|
141 |
+
mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
|
142 |
+
max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
|
143 |
+
next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
|
144 |
+
next_mask = next_mask.view(n,-1)
|
145 |
+
|
146 |
+
next_mask[max_xy_idx] = True
|
147 |
+
next_mask = next_mask.reshape((n,h,w)).float()
|
148 |
+
|
149 |
+
_next_mask = []
|
150 |
+
for idx in range(len(next_mask)):
|
151 |
+
points = next_mask[idx].nonzero().flip(dims=[-1]).cpu().numpy()
|
152 |
+
_next_mask += [Scribble.draw_by_points(points, gt_masks[idx:idx+1].cpu(), h, w)]
|
153 |
+
next_mask = torch.cat(_next_mask, dim=0).bool()
|
154 |
+
rand_shapes = prev_masks | next_mask
|
155 |
+
|
156 |
+
types = ['scribble' for i in range(len(gt_masks))]
|
157 |
+
return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
|
158 |
+
|
159 |
+
def forward_polygon(self, instances, pred_masks=None, prev_masks=None):
|
160 |
+
gt_masks = instances.gt_masks.tensor
|
161 |
+
gt_boxes = instances.gt_boxes.tensor
|
162 |
+
n,h,w = gt_masks.shape
|
163 |
+
|
164 |
+
# We only consider positive points
|
165 |
+
pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
|
166 |
+
prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
|
167 |
+
|
168 |
+
if not gt_masks.is_cuda:
|
169 |
+
gt_masks = gt_masks.to(pred_masks.device)
|
170 |
+
|
171 |
+
fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
|
172 |
+
|
173 |
+
next_mask = []
|
174 |
+
for i in range(len(fp)):
|
175 |
+
rad = 0.2
|
176 |
+
edgy = 0.05
|
177 |
+
num_points = random.randint(1, min(self.max_points, fp[i].sum()))
|
178 |
+
|
179 |
+
h,w = fp[i].shape
|
180 |
+
view_mask = fp[i].reshape(h*w)
|
181 |
+
non_zero_idx = view_mask.nonzero()[:,0]
|
182 |
+
selected_idx = torch.randperm(len(non_zero_idx))[:num_points]
|
183 |
+
non_zero_idx = non_zero_idx[selected_idx]
|
184 |
+
y = (non_zero_idx // w)*1.0/(h+1)
|
185 |
+
x = (non_zero_idx % w)*1.0/(w+1)
|
186 |
+
coords = torch.cat((x[:,None],y[:,None]), dim=1).cpu().numpy()
|
187 |
+
|
188 |
+
x1,y1,x2,y2 = gt_boxes[i].int().unbind()
|
189 |
+
x,y, _ = get_bezier_curve(coords, rad=rad, edgy=edgy)
|
190 |
+
x = x.clip(0.0, 1.0)
|
191 |
+
y = y.clip(0.0, 1.0)
|
192 |
+
points = torch.from_numpy(np.concatenate((y[None,]*(y2-y1-1).item(),x[None,]*(x2-x1-1).item()))).int()
|
193 |
+
canvas = torch.zeros((y2-y1, x2-x1))
|
194 |
+
canvas[points.long().tolist()] = 1
|
195 |
+
rand_mask = torch.zeros(fp[i].shape)
|
196 |
+
rand_mask[y1:y2,x1:x2] = canvas
|
197 |
+
next_mask += [rand_mask]
|
198 |
+
|
199 |
+
next_mask = torch.stack(next_mask).to(pred_masks.device).bool()
|
200 |
+
rand_shapes = prev_masks | next_mask
|
201 |
+
|
202 |
+
types = ['polygon' for i in range(len(gt_masks))]
|
203 |
+
return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
|
204 |
+
|
205 |
+
def forward_box(self, instances, pred_masks=None, prev_masks=None):
|
206 |
+
gt_masks = instances.gt_masks.tensor
|
207 |
+
gt_boxes = instances.gt_boxes.tensor
|
208 |
+
n,h,w = gt_masks.shape
|
209 |
+
|
210 |
+
for i in range(len(gt_masks)):
|
211 |
+
x1,y1,x2,y2 = gt_boxes[i].int().unbind()
|
212 |
+
gt_masks[i,y1:y2,x1:x2] = 1
|
213 |
+
|
214 |
+
# We only consider positive points
|
215 |
+
pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
|
216 |
+
prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
|
217 |
+
|
218 |
+
if not gt_masks.is_cuda:
|
219 |
+
gt_masks = gt_masks.to(pred_masks.device)
|
220 |
+
|
221 |
+
fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
|
222 |
+
|
223 |
+
# conv implementation
|
224 |
+
mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
|
225 |
+
max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
|
226 |
+
next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
|
227 |
+
next_mask = next_mask.view(n,-1)
|
228 |
+
|
229 |
+
next_mask[max_xy_idx] = True
|
230 |
+
next_mask = next_mask.reshape((n,h,w)).float()
|
231 |
+
next_mask = F.conv2d(next_mask[None,], self.dilation_kernel.repeat(len(next_mask),1,1,1), padding=self.dilation//2, groups=len(next_mask))[0] > 0
|
232 |
+
# end conv implementation
|
233 |
+
|
234 |
+
rand_shapes = prev_masks | next_mask
|
235 |
+
|
236 |
+
types = ['box' for i in range(len(gt_masks))]
|
237 |
+
return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
|
238 |
+
|
239 |
+
def forward(self, instances, *args, **kwargs):
|
240 |
+
if self.mask_mode == 'Point':
|
241 |
+
return self.forward_point(instances, *args, **kwargs)
|
242 |
+
elif self.mask_mode == 'Circle':
|
243 |
+
return self.forward_circle(instances, *args, **kwargs)
|
244 |
+
elif self.mask_mode == 'Scribble':
|
245 |
+
return self.forward_scribble(instances, *args, **kwargs)
|
246 |
+
elif self.mask_mode == 'Polygon':
|
247 |
+
return self.forward_polygon(instances, *args, **kwargs)
|
248 |
+
elif self.mask_mode == 'Box':
|
249 |
+
return self.forward_box(instances, *args, **kwargs)
|
250 |
+
|
251 |
+
def build_shape_sampler(cfg, **kwargs):
|
252 |
+
return ShapeSampler(cfg, **kwargs)
|
docker/Dockerfile
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FROM naotous/flash_attn:2.0.5-pytorch23.07
|
2 |
+
FROM wangkenpu/pytorch:1.8.0-py39-cuda11.1-cudnn8-ubuntu18.04
|
3 |
+
|
4 |
+
# RUN touch tensorboard_patcher.py && cp tensorboard_patcher.py $$USERSITE/usercustomize.py
|
5 |
+
|
6 |
+
|
7 |
+
# RUN pip install --upgrade pip
|
8 |
+
|
9 |
+
# RUN pip install -I torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
|
10 |
+
# RUN pip install -I torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --user
|
11 |
+
# RUN pip install kornia
|
12 |
+
# RUN pip install timm==0.4.12
|
13 |
+
# RUN python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
|
14 |
+
RUN pip install git+https://github.com/cocodataset/panopticapi.git
|
15 |
+
RUN pip install git+https://github.com/openai/CLIP.git
|
16 |
+
|
17 |
+
# RUN wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
|
18 |
+
|
19 |
+
COPY assets/requirements/requirements.txt /tmp/requirements.txt
|
20 |
+
RUN pip install -r /tmp/requirements.txt
|
21 |
+
|
22 |
+
COPY assets/requirements/requirements_custom.txt /tmp/requirements_custom.txt
|
23 |
+
RUN pip install -r /tmp/requirements_custom.txt
|
24 |
+
|
25 |
+
#RUN pip install -U protobuf
|
26 |
+
|
27 |
+
# Set environment variables
|
28 |
+
ENV MKL_THREADING_LAYER=GNU
|
29 |
+
ENV NCCL_DEBUG=INFO
|
30 |
+
|
31 |
+
# Set the working directory HERE!
|
32 |
+
WORKDIR /path/to/BiomedParse
|
docker/README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
In Dockerfile, set WORKDIR to be the path to your BiomedParse repo.
|
2 |
+
|
3 |
+
from the project root dir
|
4 |
+
|
5 |
+
bash docker/docker_build.sh
|
6 |
+
|
7 |
+
bash docker_run.sh to start
|
8 |
+
|
9 |
+
inside docker container, run setup_inside_docker.sh
|
docker/data_env.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
export HANOVER_DATASETS=biomedparse_datasets/ # Path to the datasets
|
docker/docker_build.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
docker build -f docker/Dockerfile -t seem .
|
docker/docker_run.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
docker run -it --gpus all --shm-size=128G -v /mnt:/mnt -v $(pwd):/workspace -w /workspace seem
|
docker/setup_inside_docker.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Customer Operator [only need training deformable vision encoder]
|
2 |
+
cd modeling/vision/encoder/ops && sh make.sh && cd ../../../../
|
3 |
+
|
4 |
+
# System Package [only need for demo in SEEM]
|
5 |
+
sudo apt update
|
6 |
+
sudo apt install ffmpeg
|
7 |
+
|
8 |
+
#pip install gradio==3.44.4
|
9 |
+
#pip install openai-whisper
|
10 |
+
#pip install protobuf==3.20.*
|
entry.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Modified by Xueyan Zou ([email protected])
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import torch
|
11 |
+
import logging
|
12 |
+
#import wandb
|
13 |
+
import random
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
from utilities.arguments import load_opt_command
|
17 |
+
|
18 |
+
logging.basicConfig(level=logging.INFO)
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
# def init_wandb(args, job_dir, entity='YOUR_USER_NAME', project='YOUR_PROJECT_NAME', job_name='tmp'):
|
22 |
+
# wandb_dir = os.path.join(job_dir, 'wandb')
|
23 |
+
# os.makedirs(wandb_dir, exist_ok=True)
|
24 |
+
# runid = None
|
25 |
+
# if os.path.exists(f"{wandb_dir}/runid.txt"):
|
26 |
+
# runid = open(f"{wandb_dir}/runid.txt").read()
|
27 |
+
|
28 |
+
# wandb.init(project=project,
|
29 |
+
# name=job_name,
|
30 |
+
# dir=wandb_dir,
|
31 |
+
# entity=entity,
|
32 |
+
# resume="allow",
|
33 |
+
# id=runid,
|
34 |
+
# config={"hierarchical": True},)
|
35 |
+
|
36 |
+
# open(f"{wandb_dir}/runid.txt", 'w').write(wandb.run.id)
|
37 |
+
# wandb.config.update({k: args[k] for k in args if k not in wandb.config})
|
38 |
+
|
39 |
+
def set_seed(seed: int = 42) -> None:
|
40 |
+
np.random.seed(seed)
|
41 |
+
random.seed(seed)
|
42 |
+
torch.manual_seed(seed)
|
43 |
+
torch.cuda.manual_seed(seed)
|
44 |
+
# When running on the CuDNN backend, two further options must be set
|
45 |
+
torch.backends.cudnn.deterministic = True
|
46 |
+
torch.backends.cudnn.benchmark = False
|
47 |
+
# Set a fixed value for the hash seed
|
48 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
49 |
+
print(f"Random seed set as {seed}")
|
50 |
+
|
51 |
+
def main(args=None):
|
52 |
+
'''
|
53 |
+
[Main function for the entry point]
|
54 |
+
1. Set environment variables for distributed training.
|
55 |
+
2. Load the config file and set up the trainer.
|
56 |
+
'''
|
57 |
+
|
58 |
+
opt, cmdline_args = load_opt_command(args)
|
59 |
+
command = cmdline_args.command
|
60 |
+
|
61 |
+
if cmdline_args.user_dir:
|
62 |
+
absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
|
63 |
+
opt['base_path'] = absolute_user_dir
|
64 |
+
|
65 |
+
# update_opt(opt, command)
|
66 |
+
world_size = 1
|
67 |
+
if 'OMPI_COMM_WORLD_SIZE' in os.environ:
|
68 |
+
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
69 |
+
|
70 |
+
if opt['TRAINER'] == 'xdecoder':
|
71 |
+
from trainer import XDecoder_Trainer as Trainer
|
72 |
+
else:
|
73 |
+
assert False, "The trainer type: {} is not defined!".format(opt['TRAINER'])
|
74 |
+
|
75 |
+
set_seed(opt['RANDOM_SEED'])
|
76 |
+
|
77 |
+
trainer = Trainer(opt)
|
78 |
+
os.environ['TORCH_DISTRIBUTED_DEBUG']='DETAIL'
|
79 |
+
|
80 |
+
if command == "train":
|
81 |
+
# if opt['rank'] == 0 and opt['WANDB']:
|
82 |
+
# wandb.login(key=os.environ['WANDB_KEY'])
|
83 |
+
# init_wandb(opt, trainer.save_folder, job_name=trainer.save_folder)
|
84 |
+
trainer.train()
|
85 |
+
elif command == "evaluate":
|
86 |
+
trainer.eval()
|
87 |
+
else:
|
88 |
+
raise ValueError(f"Unknown command: {command}")
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
main()
|
92 |
+
sys.exit(0)
|
environment.yml
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# name: biomedparse
|
2 |
+
# channels:
|
3 |
+
# - pytorch
|
4 |
+
# - nvidia
|
5 |
+
# - defaults
|
6 |
+
# dependencies:
|
7 |
+
# - _libgcc_mutex=0.1=main
|
8 |
+
# - _openmp_mutex=5.1=1_gnu
|
9 |
+
# - blas=1.0=mkl
|
10 |
+
# - brotli-python=1.0.9=py39h6a678d5_8
|
11 |
+
# - bzip2=1.0.8=h5eee18b_6
|
12 |
+
# - ca-certificates=2024.7.2=h06a4308_0
|
13 |
+
# - certifi=2024.7.4=py39h06a4308_0
|
14 |
+
# - charset-normalizer=3.3.2=pyhd3eb1b0_0
|
15 |
+
# - cuda-cudart=12.4.127=0
|
16 |
+
# - cuda-cupti=12.4.127=0
|
17 |
+
# - cuda-libraries=12.4.0=0
|
18 |
+
# - cuda-nvrtc=12.4.127=0
|
19 |
+
# - cuda-nvtx=12.4.127=0
|
20 |
+
# - cuda-opencl=12.6.37=0
|
21 |
+
# - cuda-runtime=12.4.0=0
|
22 |
+
# - cuda-version=12.6=3
|
23 |
+
# - ffmpeg=4.3=hf484d3e_0
|
24 |
+
# - filelock=3.13.1=py39h06a4308_0
|
25 |
+
# - freetype=2.12.1=h4a9f257_0
|
26 |
+
# - gmp=6.2.1=h295c915_3
|
27 |
+
# - gmpy2=2.1.2=py39heeb90bb_0
|
28 |
+
# - gnutls=3.6.15=he1e5248_0
|
29 |
+
# - idna=3.7=py39h06a4308_0
|
30 |
+
# - intel-openmp=2023.1.0=hdb19cb5_46306
|
31 |
+
# - jinja2=3.1.4=py39h06a4308_0
|
32 |
+
# - jpeg=9e=h5eee18b_3
|
33 |
+
# - lame=3.100=h7b6447c_0
|
34 |
+
# - lcms2=2.12=h3be6417_0
|
35 |
+
# - ld_impl_linux-64=2.38=h1181459_1
|
36 |
+
# - lerc=3.0=h295c915_0
|
37 |
+
# - libcublas=12.4.2.65=0
|
38 |
+
# - libcufft=11.2.0.44=0
|
39 |
+
# - libcufile=1.11.0.15=0
|
40 |
+
# - libcurand=10.3.7.37=0
|
41 |
+
# - libcusolver=11.6.0.99=0
|
42 |
+
# - libcusparse=12.3.0.142=0
|
43 |
+
# - libdeflate=1.17=h5eee18b_1
|
44 |
+
# - libffi=3.4.4=h6a678d5_1
|
45 |
+
# - libgcc-ng=11.2.0=h1234567_1
|
46 |
+
# - libgomp=11.2.0=h1234567_1
|
47 |
+
# - libiconv=1.16=h5eee18b_3
|
48 |
+
# - libidn2=2.3.4=h5eee18b_0
|
49 |
+
# - libjpeg-turbo=2.0.0=h9bf148f_0
|
50 |
+
# - libnpp=12.2.5.2=0
|
51 |
+
# - libnvfatbin=12.6.20=0
|
52 |
+
# - libnvjitlink=12.4.99=0
|
53 |
+
# - libnvjpeg=12.3.1.89=0
|
54 |
+
# - libpng=1.6.39=h5eee18b_0
|
55 |
+
# - libstdcxx-ng=11.2.0=h1234567_1
|
56 |
+
# - libtasn1=4.19.0=h5eee18b_0
|
57 |
+
# - libtiff=4.5.1=h6a678d5_0
|
58 |
+
# - libunistring=0.9.10=h27cfd23_0
|
59 |
+
# - libwebp-base=1.3.2=h5eee18b_0
|
60 |
+
# - llvm-openmp=14.0.6=h9e868ea_0
|
61 |
+
# - lz4-c=1.9.4=h6a678d5_1
|
62 |
+
# - markupsafe=2.1.3=py39h5eee18b_0
|
63 |
+
# - mkl=2023.1.0=h213fc3f_46344
|
64 |
+
# - mkl-service=2.4.0=py39h5eee18b_1
|
65 |
+
# - mkl_fft=1.3.8=py39h5eee18b_0
|
66 |
+
# - mkl_random=1.2.4=py39hdb19cb5_0
|
67 |
+
# - mpc=1.1.0=h10f8cd9_1
|
68 |
+
# - mpfr=4.0.2=hb69a4c5_1
|
69 |
+
# - mpmath=1.3.0=py39h06a4308_0
|
70 |
+
# - ncurses=6.4=h6a678d5_0
|
71 |
+
# - nettle=3.7.3=hbbd107a_1
|
72 |
+
# - networkx=3.2.1=py39h06a4308_0
|
73 |
+
# - openh264=2.1.1=h4ff587b_0
|
74 |
+
# - openjpeg=2.5.2=he7f1fd0_0
|
75 |
+
# - openssl=3.0.14=h5eee18b_0
|
76 |
+
# - pip=24.2=py39h06a4308_0
|
77 |
+
# - pysocks=1.7.1=py39h06a4308_0
|
78 |
+
# - python=3.9.19=h955ad1f_1
|
79 |
+
# - pytorch=2.4.0=py3.9_cuda12.4_cudnn9.1.0_0
|
80 |
+
# - pytorch-cuda=12.4=hc786d27_6
|
81 |
+
# - pytorch-mutex=1.0=cuda
|
82 |
+
# - pyyaml=6.0.1=py39h5eee18b_0
|
83 |
+
# - readline=8.2=h5eee18b_0
|
84 |
+
# - requests=2.32.3=py39h06a4308_0
|
85 |
+
# - setuptools=72.1.0=py39h06a4308_0
|
86 |
+
# - sqlite=3.45.3=h5eee18b_0
|
87 |
+
# - sympy=1.12=py39h06a4308_0
|
88 |
+
# - tbb=2021.8.0=hdb19cb5_0
|
89 |
+
# - tk=8.6.14=h39e8969_0
|
90 |
+
# - torchaudio=2.4.0=py39_cu124
|
91 |
+
# - torchtriton=3.0.0=py39
|
92 |
+
# - torchvision=0.19.0=py39_cu124
|
93 |
+
# - typing_extensions=4.11.0=py39h06a4308_0
|
94 |
+
# - tzdata=2024a=h04d1e81_0
|
95 |
+
# - urllib3=2.2.2=py39h06a4308_0
|
96 |
+
# - wheel=0.43.0=py39h06a4308_0
|
97 |
+
# - xz=5.4.6=h5eee18b_1
|
98 |
+
# - yaml=0.2.5=h7b6447c_0
|
99 |
+
# - zlib=1.2.13=h5eee18b_1
|
100 |
+
# - zstd=1.5.5=hc292b87_2
|
101 |
+
# - pip:
|
102 |
+
# - accelerate==0.23.0
|
103 |
+
# - antlr4-python3-runtime==4.9.3
|
104 |
+
# - appdirs==1.4.4
|
105 |
+
# - black==21.4b2
|
106 |
+
# - open-clip-torch==2.26.1
|
107 |
+
# - cloudpickle==3.0.0
|
108 |
+
# - cython==3.0.2
|
109 |
+
# # - deepspeed==0.10.3
|
110 |
+
# - git+https://github.com/MaureenZOU/detectron2-xyz.git
|
111 |
+
# - diffdist==0.1
|
112 |
+
# - einops==0.8.0
|
113 |
+
# - ftfy==6.1.1
|
114 |
+
# - fvcore==0.1.5.post20221221
|
115 |
+
# - hjson==3.1.0
|
116 |
+
# - huggingface-hub==0.17.3
|
117 |
+
# - hydra-core==1.3.2
|
118 |
+
# - imageio==2.35.1
|
119 |
+
# - infinibatch==0.1.1
|
120 |
+
# - iopath==0.1.9
|
121 |
+
# - json-tricks==3.17.3
|
122 |
+
# - kornia==0.7.0
|
123 |
+
# - mpi4py==3.1.5
|
124 |
+
# - mup==1.0.0
|
125 |
+
# - mypy-extensions==1.0.0
|
126 |
+
# - ninja==1.11.1.1
|
127 |
+
# - nltk==3.8.1
|
128 |
+
# - numpy==1.23.1
|
129 |
+
# - omegaconf==2.3.0
|
130 |
+
# - opencv-python==4.8.1.78
|
131 |
+
# - pandas==2.0.3
|
132 |
+
# - pathspec==0.12.1
|
133 |
+
# - pillow==9.4.0
|
134 |
+
# - portalocker==2.10.1
|
135 |
+
# - py-cpuinfo==9.0.0
|
136 |
+
# - pycocotools==2.0.7
|
137 |
+
# - pydantic==1.10.18
|
138 |
+
# - pydot==3.0.1
|
139 |
+
# - regex==2023.10.3
|
140 |
+
# - scikit-image==0.21.0
|
141 |
+
# - scikit-learn==1.3.1
|
142 |
+
# - sentencepiece==0.1.99
|
143 |
+
# - tabulate==0.9.0
|
144 |
+
# - termcolor==2.4.0
|
145 |
+
# - timm==0.4.12
|
146 |
+
# - tokenizers==0.14.1
|
147 |
+
# - transformers==4.34.0
|
148 |
+
# - vision-datasets==0.2.2
|
149 |
+
# - yacs==0.1.8
|
figures/main_figure_1a.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import seaborn as sns
|
6 |
+
from scipy.stats import boxcox
|
7 |
+
from pycirclize import Circos
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
base_dir = 'metadata'
|
11 |
+
with open(os.path.join(base_dir,'hierarchy.json'), 'r') as f:
|
12 |
+
hierarchy_data = json.load(f)
|
13 |
+
|
14 |
+
with open(os.path.join(base_dir,'target_counts.json'), 'r') as f:
|
15 |
+
target_counts = json.load(f)
|
16 |
+
|
17 |
+
with open(os.path.join(base_dir,'modality_counts.json'), 'r') as f:
|
18 |
+
modality_counts = json.load(f)
|
19 |
+
|
20 |
+
# color scheme
|
21 |
+
sectors = {k: 0 for k in hierarchy_data.keys()}
|
22 |
+
for sector_name in hierarchy_data:
|
23 |
+
for k,v in hierarchy_data[sector_name]['child'].items():
|
24 |
+
sectors[sector_name] += len(v['child'])
|
25 |
+
sectors[sector_name] += 1
|
26 |
+
|
27 |
+
name2color = {"organ": "#E41A1C", "abnormality": "#377EB8", "histology": "#4DAF4A"}
|
28 |
+
|
29 |
+
def generate_shades(base_color, n):
|
30 |
+
return sns.light_palette(base_color, n + 2)[1:-1]
|
31 |
+
|
32 |
+
color_schemes = {}
|
33 |
+
for sector in sectors:
|
34 |
+
child_colors = generate_shades(name2color[sector], len(hierarchy_data[sector]['child']))
|
35 |
+
color_schemes[sector] = child_colors
|
36 |
+
|
37 |
+
parent_track_ratio = (72, 85)
|
38 |
+
middle_track_ratio = (85, 100)
|
39 |
+
bar_track_ratio = (45, 70)
|
40 |
+
parent_track_font_size = 7
|
41 |
+
middle_track_font_size = 5.5
|
42 |
+
bar_track_font_size = 7
|
43 |
+
outer_track_font_size = 9
|
44 |
+
|
45 |
+
circos = Circos(sectors, space=8.8)
|
46 |
+
for sector in circos.sectors:
|
47 |
+
idx2label = {}
|
48 |
+
idx = 1
|
49 |
+
for k,v in hierarchy_data[sector.name.lower()]['child'].items():
|
50 |
+
for k1,v1 in v['child'].items():
|
51 |
+
idx2label[idx] = k1
|
52 |
+
idx += 1
|
53 |
+
idx2label[idx] = ''
|
54 |
+
idx2label[0] = ''
|
55 |
+
|
56 |
+
track_outer = sector.add_track((100, 101))
|
57 |
+
track_outer.xticks_by_interval(
|
58 |
+
1,
|
59 |
+
tick_length=0,
|
60 |
+
outer=True,
|
61 |
+
show_bottom_line=False,
|
62 |
+
label_orientation="vertical",
|
63 |
+
label_formatter=lambda v: idx2label[int(v)],
|
64 |
+
label_size=outer_track_font_size,
|
65 |
+
show_endlabel=True
|
66 |
+
)
|
67 |
+
|
68 |
+
track = sector.add_track(parent_track_ratio)
|
69 |
+
track.axis(fc=name2color[sector.name], lw=0)
|
70 |
+
track.text(sector.name.capitalize().replace('Mri', 'MRI').replace('Ct', 'CT').replace('Oct', 'OCT').replace('Dermoscopy', "DS"), color="white", size=parent_track_font_size)
|
71 |
+
|
72 |
+
track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
|
73 |
+
sect_start = 0
|
74 |
+
color_idx = 0
|
75 |
+
for i, (k,v) in enumerate(hierarchy_data[sector.name.lower()]['child'].items()):
|
76 |
+
sect_size = len(v['child']) if i != len(hierarchy_data[sector.name.lower()]['child'])-1 else len(v['child'])+1
|
77 |
+
if i == 0:
|
78 |
+
sect_size += 0.5
|
79 |
+
if i == len(hierarchy_data[sector.name.lower()]['child'])-1:
|
80 |
+
sect_size -= 0.5
|
81 |
+
track1.rect(sect_start, sect_start+sect_size, r_lim=(middle_track_ratio[0], middle_track_ratio[1]-1), ec="black", lw=0,fc=color_schemes[sector.name][color_idx])
|
82 |
+
color_idx += 1
|
83 |
+
track1.text(k.replace('abnormality', 'abn.').replace(' anatomies', '').replace(' disturbance', '').replace('other abn.', 'Other').replace('liver', '').replace('pancreas', '').capitalize(), sect_start+sect_size/2, color="black", size=middle_track_font_size)
|
84 |
+
sect_start += sect_size
|
85 |
+
|
86 |
+
x = np.linspace(sector.start+1 , sector.end-1 , int(sector.size)-1)
|
87 |
+
y = [target_counts[idx2label[i+1]] for i in range(0,len(x))]
|
88 |
+
y_box = boxcox(y, 0.35)
|
89 |
+
|
90 |
+
track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
|
91 |
+
track2.axis()
|
92 |
+
track2.yticks([1.14, 2.29, 3.43, 4.58], ["10$^2$", "10$^3$", "10$^4$", "10$^5$"], label_size=bar_track_font_size-1)
|
93 |
+
track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)
|
94 |
+
|
95 |
+
fig = circos.plotfig()
|
96 |
+
fig.savefig('plots/figure_1a.pdf')
|
97 |
+
plt.show()
|
98 |
+
|
99 |
+
# %%
|
figures/main_figure_1b.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import json, os
|
6 |
+
import seaborn as sns
|
7 |
+
|
8 |
+
plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
|
9 |
+
|
10 |
+
# Load data
|
11 |
+
def load_data(file_path):
|
12 |
+
with open(file_path, 'r') as f:
|
13 |
+
return json.load(f)
|
14 |
+
base_dir = 'metadata'
|
15 |
+
data = load_data(os.path.join(base_dir, 'modality_counts.json'))
|
16 |
+
separate_submodality = False
|
17 |
+
|
18 |
+
# Transform data for plotting
|
19 |
+
def transform_data(data):
|
20 |
+
df = pd.DataFrame([(modality, subcat, count) for modality, subcats in data.items() for subcat, count in subcats.items()], columns=['Modality', 'Sub-category', 'Count'])
|
21 |
+
return df
|
22 |
+
|
23 |
+
df = transform_data(data)
|
24 |
+
|
25 |
+
# Calculate total counts by modality and sort
|
26 |
+
def calculate_totals(df):
|
27 |
+
total_counts_by_modality = df.groupby("Modality")["Count"].sum().sort_values(ascending=True)
|
28 |
+
sorted_modalities = total_counts_by_modality.index.tolist()
|
29 |
+
return total_counts_by_modality, sorted_modalities
|
30 |
+
|
31 |
+
total_counts_by_modality, sorted_modalities = calculate_totals(df)
|
32 |
+
|
33 |
+
# Generate color map
|
34 |
+
def generate_color_map(total_counts_by_modality):
|
35 |
+
base_colors = plt.cm.cool(np.linspace(0, 1, len(total_counts_by_modality)))
|
36 |
+
modality_color_map = {modality: base_colors[i] for i, modality in enumerate(total_counts_by_modality.index)}
|
37 |
+
return modality_color_map
|
38 |
+
|
39 |
+
modality_color_map = generate_color_map(total_counts_by_modality)
|
40 |
+
|
41 |
+
# Format total count for display
|
42 |
+
def format_total_count(total_count):
|
43 |
+
if total_count >= 1000:
|
44 |
+
exponent = int(np.floor(np.log10(total_count)))
|
45 |
+
mantissa = total_count / 10**exponent
|
46 |
+
formatted_total = f'{mantissa:.2f} x 10$^{exponent}$'
|
47 |
+
else:
|
48 |
+
exponent = 0
|
49 |
+
formatted_total = str(total_count)
|
50 |
+
return formatted_total, exponent
|
51 |
+
|
52 |
+
# Plotting function
|
53 |
+
def plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality):
|
54 |
+
fig, ax = plt.subplots(figsize=(10, 12))
|
55 |
+
current_bottom = np.zeros(len(sorted_modalities))
|
56 |
+
gap = 0.005 if separate_submodality else 0
|
57 |
+
shades = np.power(np.linspace(0.75, 1, df.groupby("Sub-category").ngroups), 2)
|
58 |
+
|
59 |
+
if separate_submodality:
|
60 |
+
for i, modality in enumerate(sorted_modalities):
|
61 |
+
subdf = df[df["Modality"] == modality].sort_values(by='Count', ascending=False)
|
62 |
+
for j, (index, row) in enumerate(subdf.iterrows()):
|
63 |
+
count = row['Count']
|
64 |
+
if count > 0:
|
65 |
+
color = np.array(modality_color_map[modality]) * shades[j % len(shades)]
|
66 |
+
ax.barh(modality, count, left=current_bottom[i], color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
|
67 |
+
current_bottom[i] += count + gap
|
68 |
+
current_bottom[i] -= gap
|
69 |
+
total_count = total_counts_by_modality[modality]
|
70 |
+
formatted_total, exponent = format_total_count(total_count)
|
71 |
+
ax.text(current_bottom[i] + (10**exponent)*0.05, i, formatted_total, va='center', fontsize=20, ha='left')
|
72 |
+
else:
|
73 |
+
for i, modality in enumerate(sorted_modalities):
|
74 |
+
total_count = total_counts_by_modality[modality]
|
75 |
+
color = np.array(modality_color_map[modality] * shades[0])
|
76 |
+
if modality.islower():
|
77 |
+
modality = modality.capitalize()
|
78 |
+
ax.barh(modality, total_count, color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
|
79 |
+
formatted_total, exponent = format_total_count(total_count)
|
80 |
+
ax.text(total_count + (10**exponent)*0.05, i, formatted_total, va='center', fontsize=20, ha='left')
|
81 |
+
|
82 |
+
configure_plot(ax, sorted_modalities)
|
83 |
+
|
84 |
+
plt.tight_layout()
|
85 |
+
plt.savefig("plots/data_dist_modality_bar_subbar.pdf" if separate_submodality else "plots/data_dist_modality_bar.pdf", bbox_inches="tight", pad_inches=0)
|
86 |
+
plt.show()
|
87 |
+
|
88 |
+
# Configure plot aesthetics
|
89 |
+
def configure_plot(ax, sorted_modalities):
|
90 |
+
ax.set_xscale('log')
|
91 |
+
ax.set_title("Number of images per modality", fontsize=28)
|
92 |
+
plt.yticks(rotation=0, fontsize=24, va='center')
|
93 |
+
ax.tick_params(axis='x', which='major', length=8)
|
94 |
+
ax.tick_params(axis='x', which='minor', length=5)
|
95 |
+
plt.xticks(fontsize=24)
|
96 |
+
sns.despine()
|
97 |
+
|
98 |
+
# Main script execution
|
99 |
+
plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality)
|
100 |
+
|
101 |
+
# %%
|
figures/main_figure_2a.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import seaborn as sns
|
5 |
+
import json, os
|
6 |
+
|
7 |
+
from statannot import add_stat_annotation
|
8 |
+
from statannotations.Annotator import Annotator
|
9 |
+
|
10 |
+
df = pd.read_csv('results/all_eval/all_metrics_median.csv')
|
11 |
+
|
12 |
+
|
13 |
+
metric = 'dice'
|
14 |
+
|
15 |
+
model_names = {metric: 'BiomedParse', f'medsam_{metric}': 'MedSAM (oracle box)', f'sam_{metric}': 'SAM (oracle box)',
|
16 |
+
f'dino_medsam_{metric}': 'MedSAM (Grounding DINO)', f'dino_sam_{metric}': 'SAM (Grounding DINO)'}
|
17 |
+
df = df.rename(columns=model_names)
|
18 |
+
|
19 |
+
score_vars = list(model_names.values())
|
20 |
+
|
21 |
+
|
22 |
+
modality_list = ['CT', 'MRI', 'X-Ray', 'Pathology', 'Ultrasound', 'Fundus', 'Endoscope', 'Dermoscopy', 'OCT']
|
23 |
+
# modify modality names
|
24 |
+
mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
|
25 |
+
'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
|
26 |
+
df['modality'] = df['modality'].apply(lambda x: mod_names[x])
|
27 |
+
|
28 |
+
# add an "All" modality
|
29 |
+
all_df = df.copy()
|
30 |
+
all_df['modality'] = 'All'
|
31 |
+
df = pd.concat([df, all_df])
|
32 |
+
|
33 |
+
df_long = df[['modality', 'task']+score_vars].melt(id_vars=['modality', 'task'], var_name='Model', value_name='Performance')
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
# add statistical annotations
|
38 |
+
fig, ax = plt.subplots(figsize=(9, 6))
|
39 |
+
ax = sns.boxplot(data=df_long, x='modality', y='Performance', hue='Model', ax=ax, palette='Set2',
|
40 |
+
order=['All']+modality_list,
|
41 |
+
whis=2, saturation=0.6, linewidth=0.8, fliersize=0.5) # whiskers at 5th and 95th percentile)
|
42 |
+
#errorbar='sd', capsize=0.1, errwidth=1.5)
|
43 |
+
|
44 |
+
# no frame
|
45 |
+
ax.spines['top'].set_visible(False)
|
46 |
+
ax.spines['right'].set_visible(False)
|
47 |
+
ax.spines['left'].set_visible(False)
|
48 |
+
# add arrow on y axis
|
49 |
+
ax.annotate('', xy=(0, 1.05), xytext=(0, -0.01), arrowprops=dict(arrowstyle='->', lw=1, color='black'), xycoords='axes fraction')
|
50 |
+
|
51 |
+
|
52 |
+
plt.title('')
|
53 |
+
if metric == 'dice':
|
54 |
+
plt.ylabel('Dice score', fontsize=18)
|
55 |
+
elif metric == 'assd':
|
56 |
+
plt.ylabel('ASSD', fontsize=18)
|
57 |
+
plt.xlabel('')
|
58 |
+
plt.xticks(rotation=45, fontsize=16)
|
59 |
+
plt.yticks(fontsize=14)
|
60 |
+
|
61 |
+
# axis thickness
|
62 |
+
ax.spines['bottom'].set_linewidth(1)
|
63 |
+
ax.spines['left'].set_linewidth(1)
|
64 |
+
|
65 |
+
|
66 |
+
# change to log scale
|
67 |
+
if metric == 'assd':
|
68 |
+
plt.yscale('log')
|
69 |
+
|
70 |
+
# set legend names
|
71 |
+
ax.legend(score_vars, fontsize=14)
|
72 |
+
|
73 |
+
# legend on top in a row, without frame
|
74 |
+
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.4), ncol=2, fontsize=14, frameon=False)
|
75 |
+
|
76 |
+
# Define pairs between models for each modality
|
77 |
+
box_pairs = []
|
78 |
+
|
79 |
+
# Add statistical annotations for each modality
|
80 |
+
for modality in ['All']+modality_list:
|
81 |
+
# Define pairs between models within the same modality
|
82 |
+
box_pairs += [((modality, 'BiomedParse'), (modality, 'MedSAM (oracle box)'))]
|
83 |
+
annotator = Annotator(ax, box_pairs, data=df_long, x='modality', y='Performance', hue='Model',
|
84 |
+
order=['All']+modality_list)
|
85 |
+
annotator.configure(test='t-test_paired', text_format='star', loc='inside', hide_non_significant=True)
|
86 |
+
annotator.apply_test(alternative='less')
|
87 |
+
annotator.annotate()
|
88 |
+
|
89 |
+
plt.tight_layout()
|
90 |
+
|
91 |
+
# save the plot
|
92 |
+
ax.get_figure().savefig(f'plots/{metric}_comparison.png')
|
93 |
+
ax.get_figure().savefig(f'plots/{metric}_comparison.pdf', bbox_inches='tight')
|
figures/main_figure_3b.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import seaborn as sns
|
5 |
+
import json, os
|
6 |
+
|
7 |
+
from statannot import add_stat_annotation
|
8 |
+
from statannotations.Annotator import Annotator
|
9 |
+
|
10 |
+
df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
|
11 |
+
|
12 |
+
# modify modality names
|
13 |
+
mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
|
14 |
+
'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
|
15 |
+
df['modality'] = df['modality'].apply(lambda x: mod_names[x])
|
16 |
+
|
17 |
+
# MedSAM reported tasks
|
18 |
+
reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
|
19 |
+
|
20 |
+
# find overlap between the dfs by dataset and target
|
21 |
+
overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
|
22 |
+
suffixes=('_biomedparse', '_baseline'))
|
23 |
+
# non-overlapping datasets
|
24 |
+
non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
baseline = 'medsam'
|
29 |
+
metric = 'box_ratio'
|
30 |
+
|
31 |
+
baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
|
32 |
+
metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
|
33 |
+
'IRI': 'Inversed Rotational Inertia'}
|
34 |
+
|
35 |
+
non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
|
36 |
+
# scatter plot
|
37 |
+
fig, ax = plt.subplots(figsize=(7,5))
|
38 |
+
sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
|
39 |
+
|
40 |
+
# add linear regression line
|
41 |
+
sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
|
42 |
+
color='k', line_kws={'linestyle':'--', 'linewidth':1})
|
43 |
+
|
44 |
+
# remove all spines
|
45 |
+
ax.spines['top'].set_visible(False)
|
46 |
+
ax.spines['right'].set_visible(False)
|
47 |
+
ax.spines['left'].set_visible(False)
|
48 |
+
ax.spines['bottom'].set_visible(False)
|
49 |
+
|
50 |
+
|
51 |
+
# add arrow on x-axis and y-axis
|
52 |
+
xlim = [0, 0.85]
|
53 |
+
ylim = [-0.18, 0.75]
|
54 |
+
ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
|
55 |
+
ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
|
56 |
+
ax.set_xlim(xlim)
|
57 |
+
ax.set_ylim(ylim)
|
58 |
+
|
59 |
+
ax.xaxis.set_tick_params(width=1.5)
|
60 |
+
ax.yaxis.set_tick_params(width=1.5)
|
61 |
+
|
62 |
+
# set x-ticks and y-ticks
|
63 |
+
plt.xticks(fontsize=18)
|
64 |
+
plt.yticks(fontsize=18)
|
65 |
+
|
66 |
+
# show R^2 value, p value, and equation of the line
|
67 |
+
from scipy.stats import linregress
|
68 |
+
slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
|
69 |
+
x_text = 0.4
|
70 |
+
plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
|
71 |
+
plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
|
72 |
+
plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
|
73 |
+
|
74 |
+
plt.title('')
|
75 |
+
plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
|
76 |
+
plt.xlabel(f'{metric_names[metric]}', fontsize=22)
|
77 |
+
|
78 |
+
plt.tight_layout()
|
79 |
+
|
80 |
+
# save the plot
|
81 |
+
ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
|
82 |
+
ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
|
83 |
+
|
figures/main_figure_3c.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import seaborn as sns
|
5 |
+
import json, os
|
6 |
+
|
7 |
+
from statannot import add_stat_annotation
|
8 |
+
from statannotations.Annotator import Annotator
|
9 |
+
|
10 |
+
df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
|
11 |
+
|
12 |
+
# modify modality names
|
13 |
+
mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
|
14 |
+
'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
|
15 |
+
df['modality'] = df['modality'].apply(lambda x: mod_names[x])
|
16 |
+
|
17 |
+
# MedSAM reported tasks
|
18 |
+
reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
|
19 |
+
|
20 |
+
# find overlap between the dfs by dataset and target
|
21 |
+
overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
|
22 |
+
suffixes=('_biomedparse', '_baseline'))
|
23 |
+
# non-overlapping datasets
|
24 |
+
non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
baseline = 'medsam'
|
29 |
+
metric = 'convex_ratio'
|
30 |
+
|
31 |
+
baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
|
32 |
+
metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
|
33 |
+
'IRI': 'Inversed Rotational Inertia'}
|
34 |
+
|
35 |
+
non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
|
36 |
+
# scatter plot
|
37 |
+
fig, ax = plt.subplots(figsize=(7,5))
|
38 |
+
sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
|
39 |
+
|
40 |
+
# add linear regression line
|
41 |
+
sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
|
42 |
+
color='k', line_kws={'linestyle':'--', 'linewidth':1})
|
43 |
+
|
44 |
+
# remove all spines
|
45 |
+
ax.spines['top'].set_visible(False)
|
46 |
+
ax.spines['right'].set_visible(False)
|
47 |
+
ax.spines['left'].set_visible(False)
|
48 |
+
ax.spines['bottom'].set_visible(False)
|
49 |
+
|
50 |
+
|
51 |
+
# add arrow on x-axis and y-axis
|
52 |
+
xlim = [0, 1.05]
|
53 |
+
ylim = [-0.18, 0.75]
|
54 |
+
ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
|
55 |
+
ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
|
56 |
+
ax.set_xlim(xlim)
|
57 |
+
ax.set_ylim(ylim)
|
58 |
+
|
59 |
+
ax.xaxis.set_tick_params(width=1.5)
|
60 |
+
ax.yaxis.set_tick_params(width=1.5)
|
61 |
+
|
62 |
+
# set x-ticks and y-ticks
|
63 |
+
plt.xticks(fontsize=18)
|
64 |
+
plt.yticks(fontsize=18)
|
65 |
+
|
66 |
+
# show R^2 value, p value, and equation of the line
|
67 |
+
from scipy.stats import linregress
|
68 |
+
slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
|
69 |
+
x_text = 0.4
|
70 |
+
plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
|
71 |
+
plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
|
72 |
+
plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
|
73 |
+
|
74 |
+
plt.title('')
|
75 |
+
plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
|
76 |
+
plt.xlabel(f'{metric_names[metric]}', fontsize=22)
|
77 |
+
|
78 |
+
plt.tight_layout()
|
79 |
+
plt.show()
|
80 |
+
|
81 |
+
# save the plot
|
82 |
+
ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
|
83 |
+
ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
|
figures/main_figure_3d.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import seaborn as sns
|
5 |
+
import json, os
|
6 |
+
|
7 |
+
from statannot import add_stat_annotation
|
8 |
+
from statannotations.Annotator import Annotator
|
9 |
+
|
10 |
+
df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
|
11 |
+
|
12 |
+
# modify modality names
|
13 |
+
mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
|
14 |
+
'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
|
15 |
+
df['modality'] = df['modality'].apply(lambda x: mod_names[x])
|
16 |
+
|
17 |
+
# MedSAM reported tasks
|
18 |
+
reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
|
19 |
+
|
20 |
+
# find overlap between the dfs by dataset and target
|
21 |
+
overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
|
22 |
+
suffixes=('_biomedparse', '_baseline'))
|
23 |
+
# non-overlapping datasets
|
24 |
+
non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
baseline = 'medsam'
|
29 |
+
metric = 'IRI'
|
30 |
+
|
31 |
+
baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
|
32 |
+
metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
|
33 |
+
'IRI': 'Inversed Rotational Inertia'}
|
34 |
+
|
35 |
+
non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
|
36 |
+
# scatter plot
|
37 |
+
fig, ax = plt.subplots(figsize=(7,5))
|
38 |
+
sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
|
39 |
+
|
40 |
+
# add linear regression line
|
41 |
+
sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
|
42 |
+
color='k', line_kws={'linestyle':'--', 'linewidth':1})
|
43 |
+
|
44 |
+
# remove all spines
|
45 |
+
ax.spines['top'].set_visible(False)
|
46 |
+
ax.spines['right'].set_visible(False)
|
47 |
+
ax.spines['left'].set_visible(False)
|
48 |
+
ax.spines['bottom'].set_visible(False)
|
49 |
+
|
50 |
+
|
51 |
+
# add arrow on x-axis and y-axis
|
52 |
+
xlim = [0, 1.05]
|
53 |
+
ylim = [-0.18, 0.75]
|
54 |
+
ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
|
55 |
+
ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
|
56 |
+
ax.set_xlim(xlim)
|
57 |
+
ax.set_ylim(ylim)
|
58 |
+
|
59 |
+
ax.xaxis.set_tick_params(width=1.5)
|
60 |
+
ax.yaxis.set_tick_params(width=1.5)
|
61 |
+
|
62 |
+
# set x-ticks and y-ticks
|
63 |
+
plt.xticks(fontsize=18)
|
64 |
+
plt.yticks(fontsize=18)
|
65 |
+
|
66 |
+
# show R^2 value, p value, and equation of the line
|
67 |
+
from scipy.stats import linregress
|
68 |
+
slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
|
69 |
+
x_text = 0.4
|
70 |
+
plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
|
71 |
+
plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
|
72 |
+
plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
|
73 |
+
|
74 |
+
plt.title('')
|
75 |
+
plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
|
76 |
+
plt.xlabel(f'{metric_names[metric]}', fontsize=22)
|
77 |
+
|
78 |
+
plt.tight_layout()
|
79 |
+
plt.show()
|
80 |
+
|
81 |
+
# save the plot
|
82 |
+
ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
|
83 |
+
ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
|
figures/plots/IRI_mean_improvement_medsam.pdf
ADDED
Binary file (21.2 kB). View file
|
|
figures/plots/IRI_mean_improvement_medsam.png
ADDED
figures/plots/IRI_mean_improvement_sam.pdf
ADDED
Binary file (21.2 kB). View file
|
|