diff --git a/README.md b/README.md
index d5849e783bf51fb35c080b164bafddcb2b707026..9765c505ad8e61fc5b0d0ba4b776c55c01aaa578 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,44 @@
---
-title: HakimAiV2
-emoji: π
-colorFrom: green
+title: HakimAi
+emoji: π₯
+colorFrom: blue
colorTo: green
sdk: gradio
-sdk_version: 5.9.1
+sdk_version: "5.9.0"
app_file: app.py
pinned: false
-license: cc-by-nc-4.0
-short_description: hakim ai by cbai
---
-Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference
+# HakimAi
+
+A medical imaging analysis platform powered by BiomedParse, offering comprehensive biomedical image analysis across multiple modalities.
+
+## Features
+
+- **Multi-modal Analysis**: Support for various medical imaging types including X-ray, CT, MRI, pathology, and more
+- **Advanced Detection**: Automated identification and segmentation of medical objects and conditions
+- **Interactive Interface**: User-friendly Gradio interface for easy image upload and analysis
+- **Powered by BiomedParse**: Utilizes Microsoft's BiomedParse foundation model for accurate medical image analysis
+
+## Usage
+
+1. Upload your medical image
+2. Select the analysis type
+3. View the results including segmentation masks and detection results
+
+## Technical Details
+
+This space uses:
+- BiomedParse foundation model for medical image analysis
+- Gradio for the web interface
+- Git LFS for handling large files
+- Python 3.9+ environment
+
+## Model Information
+
+Based on BiomedParse, capable of:
+- Segmentation
+- Detection
+- Recognition across nine biomedical modalities
+
+Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/README.pdf b/README.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..6b2493ed4162e8e3d3cdda45db51a08ae8b9973c
Binary files /dev/null and b/README.pdf differ
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f55a19e0388a05ad65c4f5766e05f4e41565913c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,242 @@
+import gradio as gr
+import os
+from typing import Tuple, Optional
+import os
+import shutil
+import sys
+from pathlib import Path
+import cv2
+import gradio as gr
+import numpy as np
+import spaces
+# import supervision as sv
+import torch
+from PIL import Image
+from tqdm import tqdm
+import sys
+from pathlib import Path
+from huggingface_hub import login
+# from dotenv import load_dotenv
+
+# For Hugging Face Spaces, secrets are automatically loaded as environment variables
+token = os.getenv("HF_TOKEN")
+if token:
+ login(token=token)
+# Clear Hugging Face cache
+# cache_dirs = [
+# "/home/user/.cache/huggingface/",
+# "/home/user/.cache/torch/",
+# "/home/user/.cache/pip/"
+# ]
+
+# for cache_dir in cache_dirs:
+# if os.path.exists(cache_dir):
+# print(f"Clearing cache: {cache_dir}")
+# shutil.rmtree(cache_dir, ignore_errors=True)
+# Add the current directory to Python path
+current_dir = Path(__file__).parent
+sys.path.append(str(current_dir))
+# sys.path.append("./BiomedParse/")
+# BIOMEDPARSE_PATH = Path(__file__).parent / "BiomedParse"
+# sys.path.append(str(BIOMEDPARSE_PATH))
+# sys.path.append(str(BIOMEDPARSE_PATH / "BiomedParse")) # Add the inner BiomedParse directory
+from modeling.BaseModel import BaseModel
+from modeling import build_model
+from utilities.arguments import load_opt_from_config_files
+from utilities.constants import BIOMED_CLASSES
+from inference_utils.inference import interactive_infer_image
+from inference_utils.output_processing import check_mask_stats
+from inference_utils.processing_utils import read_rgb
+
+import spaces
+
+# breakpoint()
+MARKDOWN = """
+#
αAαͺiα AI
+
+
+
+This demo integrates BiomedParse, a foundation model for joint segmentation, detection, and recognition across 9 biomedical imaging modalities. The model supports:
+
+- Segmentation/Detection/Recognition across multiple modalities (CT, MRI, X-Ray, etc.)
+- Text-prompted object detection
+- Recognition of anatomical structures and abnormalities
+
+
+"""
+
+IMAGE_PROCESSING_EXAMPLES = [
+ ["BiomedParse Segmentation",
+ "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/T0011.jpg",
+ "Optic disc in retinal Fundus"],
+ ["BiomedParse Segmentation",
+ "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/Part_3_226_pathology_breast.png",
+ "optic disc, optic cup"],
+ ["BiomedParse Segmentation",
+ "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/covid_1585.png",
+ "COVID-19 infection in chest X-Ray"],
+ ["BiomedParse Segmentation",
+ "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/TCGA_HT_7856_19950831_8_MRI-FLAIR_brain.png",
+ "Lower-grade glioma in brain MRI"],
+ ["BiomedParse Segmentation",
+ "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/LIDC-IDRI-0140_143_280_CT_lung.png",
+ "COVID-19 infection in chest CT"],
+ ["BiomedParse Segmentation",
+ "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/144DME_as_F.jpeg",
+ "Cystoid macular edema in retinal OCT"],
+ ["BiomedParse Segmentation",
+ "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/Part_1_516_pathology_breast.png",
+ "Glandular structure in colon Pathology"],
+ ["BiomedParse Segmentation",
+ "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/ISIC_0015551.jpg",
+ "Melanoma in skin Dermoscopy"],
+ ["BiomedParse Segmentation",
+ "https://raw.githubusercontent.com/microsoft/BiomedParse/main/examples/C3_EndoCV2021_00462.jpg",
+ "Neoplastic polyp in colon Endoscope"]
+]
+
+BIOMEDPARSE_MODES = {
+ "CT": ["abdomen", "colon", "liver", "lung", "pelvis"],
+ "MRI": ["brain", "heart", "prostate", "abdomen"],
+ "MRI-FLAIR": ["brain"],
+ "MRI-T1-Gd": ["brain"],
+ "MRI-T2": ["prostate"],
+ "OCT": ["retinal"],
+ "X-Ray": ["chest"],
+ "Dermoscopy": ["skin"],
+ "Endoscope": ["colon"],
+ "Fundus": ["retinal"],
+ "Pathology": ["bladder", "breast", "cervix", "colon", "esophagus", "kidney",
+ "liver", "ovarian", "prostate", "stomach", "testis", "thyroid", "uterus"],
+ "Ultrasound": ["breast", "heart", "transperineal"]
+}
+
+IMAGE_INFERENCE_MODES = [
+ "BIOMED SEGMENTATION",
+ "BIOMED DETECTION",
+ "BIOMED RECOGNITION",
+ "BIOMED SEGMENTATION + DETECTION",
+ "BIOMED SEGMENTATION + RECOGNITION",
+ "BIOMED DETECTION + RECOGNITION",
+ "BIOMED SEGMENTATION + DETECTION + RECOGNITION"
+]
+
+
+def on_mode_dropdown_change(selected_mode):
+ if selected_mode in IMAGE_INFERENCE_MODES:
+ # Show modality dropdown and hide other inputs initially
+ return [
+ gr.Dropdown(visible=True, choices=list(BIOMEDPARSE_MODES.keys()), label="Modality"),
+ gr.Dropdown(visible=True, label="Anatomical Site"),
+ gr.Textbox(visible=False),
+ gr.Textbox(visible=False)
+ ]
+ else:
+ # Original behavior for other modes
+ return [
+ gr.Dropdown(visible=False),
+ gr.Dropdown(visible=False),
+ gr.Textbox(visible=True),
+ gr.Textbox(visible=(selected_mode == None))
+ ]
+
+def on_modality_change(modality):
+ if modality:
+ return gr.Dropdown(choices=BIOMEDPARSE_MODES[modality], visible=True)
+ return gr.Dropdown(visible=False)
+
+
+def initialize_model():
+ opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
+ pretrained_pth = 'hf_hub:microsoft/BiomedParse'
+ opt['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval()
+ with torch.no_grad():
+ model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
+ BIOMED_CLASSES + ["background"], is_eval=True
+ )
+ return model
+
+
+model = initialize_model()
+
+
+# Utility functions
+@spaces.GPU
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+def process_image(image_path, text_prompts, modality):
+ image = read_rgb(image_path)
+ text_prompts = [prompt.strip() for prompt in text_prompts.split(',')]
+
+ # Run inference
+ pred_masks = interactive_infer_image(model, Image.fromarray(image), text_prompts)
+
+ # Prepare outputs
+ results = []
+ dice_scores = []
+ p_values = []
+
+ for i, prompt in enumerate(text_prompts):
+ # Calculate p-value for the selected modality
+ print("PROMPT: ", prompt, flush=True)
+ p_value = check_mask_stats(image, pred_masks[i] * 255, modality, prompt)
+ p_values.append(f"P-value for '{prompt}' ({modality}): {p_value:.4f}")
+
+ # Overlay predictions on the image
+ overlay_image = image.copy()
+ overlay_image[pred_masks[i] > 0.5] = [255, 0, 0] # Highlight predictions in red
+ results.append(overlay_image)
+
+ return results, p_values
+
+# Define Gradio interface
+with gr.Blocks() as demo:
+ gr.Markdown(MARKDOWN)
+ with gr.Row():
+ with gr.Column():
+ image_input = gr.Image(type="filepath", label="Input Image")
+ prompts_input = gr.Textbox(lines=2, placeholder="Enter prompts separated by commas...", label="Prompts")
+ modality_dropdown = gr.Dropdown(
+ choices=BIOMEDPARSE_MODES.keys(),
+ value=BIOMEDPARSE_MODES.keys()[0],
+ label="Modality"
+ )
+ submit_btn = gr.Button("Submit")
+ with gr.Column():
+ output_gallery = gr.Gallery(label="Predicted Masks")
+ pvalue_output = gr.Textbox(label="P-values", interactive=False)
+
+ submit_btn.click(
+ process_image,
+ inputs=[image_input, prompts_input, modality_dropdown],
+ outputs=[output_gallery, pvalue_output]
+ )
+ with gr.Row():
+ gr.Examples(
+ fn=process_image,
+ examples=IMAGE_PROCESSING_EXAMPLES,
+ inputs=[
+ image_processing_mode_dropdown_component,
+ image_processing_image_input_component,
+ image_processing_text_input_component
+ ],
+ outputs=[
+ image_processing_image_output_component,
+ image_processing_text_output_component
+ ],
+ run_on_click=True
+ )
+
+# Launch the app
+demo.launch()
\ No newline at end of file
diff --git a/configs/biomed_seg_lang_v1.yaml b/configs/biomed_seg_lang_v1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a0af2135f2cf6b11a085298f0e8d52f76d73f28a
--- /dev/null
+++ b/configs/biomed_seg_lang_v1.yaml
@@ -0,0 +1,330 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+# Define Test/Trainer/Saving
+PIPELINE: XDecoderPipeline
+TRAINER: xdecoder
+SAVE_DIR: './output'
+base_path: "./"
+
+# Resume Logistic
+RESUME: false
+WEIGHT: false
+RESUME_FROM: ''
+EVAL_AT_START: false
+SAVE_CHECKPOINT: True
+
+# Logging and Debug
+WANDB: False
+LOG_EVERY: 100
+FIND_UNUSED_PARAMETERS: false
+
+# Speed up training
+FP16: false
+PORT: '36873'
+
+# misc
+LOADER:
+ JOINT: True
+ KEY_DATASET: ""
+ SAMPLE_PROB: "prop" # sampling probability proportional to data size. Use "equal" for each bach from all datasets
+ MIXING_LEVEL: 1 # num of different datasets for batch mixing on each GPU
+
+RANDOM_SEED: 2024
+
+STANDARD_TEXT_FOR_EVAL: False
+
+##################
+# Task settings
+##################
+VERBOSE: true
+MODEL:
+ DEVICE: "cuda" # or "cpu" if no GPU available
+ NAME: seem_model_v1
+ HEAD: xdecoder_head
+ MASK_ON: false
+ KEYPOINT_ON: false
+ LOAD_PROPOSALS: false
+ DIM_PROJ: 512
+ TEXT:
+ ARCH: vlpencoder
+ NAME: transformer
+ TOKENIZER: clip
+ CONTEXT_LENGTH: 77 #256 # 77
+ WIDTH: 512 # 768 # 512
+ HEADS: 8
+ LAYERS: 12 # 6
+ AUTOGRESSIVE: True
+ BACKBONE:
+ NAME: focal # focal_dw # focal
+ PRETRAINED: ''
+ LOAD_PRETRAINED: false
+ FOCAL:
+ PRETRAIN_IMG_SIZE: 224
+ PATCH_SIZE: 4
+ EMBED_DIM: 192 # 96 # 192
+ DEPTHS: [2, 2, 18, 2] # [2, 2, 6, 2] # [2, 2, 18, 2]
+ FOCAL_LEVELS: [4, 4, 4, 4] # [3, 3, 3, 3] # [4, 4, 4, 4]
+ FOCAL_WINDOWS: [3, 3, 3, 3]
+ DROP_PATH_RATE: 0.3
+ MLP_RATIO: 4.0
+ DROP_RATE: 0.0
+ PATCH_NORM: True
+ USE_CONV_EMBED: True
+ SCALING_MODULATOR: True
+ USE_CHECKPOINT: False
+ USE_POSTLN: true
+ USE_POSTLN_IN_MODULATION: false
+ USE_LAYERSCALE: True
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+ OUT_INDICES: [0, 1, 2, 3]
+ ENCODER:
+ NAME: transformer_encoder_fpn
+ IGNORE_VALUE: 255
+ NUM_CLASSES: 16
+ BINARY_CLASSES: False
+ LOSS_WEIGHT: 1.0
+ CONVS_DIM: 512
+ MASK_DIM: 512
+ NORM: "GN"
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+ COMMON_STRIDE: 4
+ TRANSFORMER_ENC_LAYERS: 6
+ DECODER:
+ NAME: seem_v1
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+ MASK:
+ ENABLED: True
+ DETECTION: False
+ SPATIAL:
+ ENABLED: True
+ MAX_ITER: 1
+ GROUNDING:
+ ENABLED: True
+ MAX_LEN: 10
+ TEXT_WEIGHT: 2.0
+ CLASS_WEIGHT: 0.5
+ RETRIEVAL:
+ ENABLED: False
+ LVIS:
+ ENABLED: False
+ THRES: 0.7
+ OPENIMAGE:
+ ENABLED: False
+ NEGATIVE_SAMPLES: 5
+ GROUNDING:
+ ENABLED: False
+ MAX_LEN: 5
+ CAPTION:
+ ENABLED: False
+ PHRASE_PROB: 0.5
+ SIM_THRES: 0.95
+ DEEP_SUPERVISION: True
+ NO_OBJECT_WEIGHT: 0.1
+ GCLASS_WEIGHT: 0.4
+ GMASK_WEIGHT: 1.0
+ GDICE_WEIGHT: 1.0
+ SCLASS_WEIGHT: 0.4
+ SMASK_WEIGHT: 1.0
+ SDICE_WEIGHT: 1.0
+ OCLASS_WEIGHT: 0.4
+ OMASK_WEIGHT: 1.0
+ ODICE_WEIGHT: 1.0
+ CLASS_WEIGHT: 2.0
+ MASK_WEIGHT: 5.0
+ DICE_WEIGHT: 5.0
+ BBOX_WEIGHT: 5.0
+ GIOU_WEIGHT: 2.0
+ CAPTION_WEIGHT: 2.0
+ COST_SPATIAL:
+ CLASS_WEIGHT: 5.0
+ MASK_WEIGHT: 2.0
+ DICE_WEIGHT: 2.0
+ HIDDEN_DIM: 512
+ NUM_OBJECT_QUERIES: 101
+ NHEADS: 8
+ DROPOUT: 0.0
+ DIM_FEEDFORWARD: 2048
+ MAX_SPATIAL_LEN: [512, 512, 512, 512]
+ # ENC_LAYERS: 0
+ PRE_NORM: False
+ ENFORCE_INPUT_PROJ: False
+ SIZE_DIVISIBILITY: 32
+ TRAIN_NUM_POINTS: 12544
+ OVERSAMPLE_RATIO: 3.0
+ IMPORTANCE_SAMPLE_RATIO: 0.75
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
+ TOP_GROUNDING_LAYERS: 10
+ TOP_CAPTION_LAYERS: 10
+ TOP_SPATIAL_LAYERS: 10
+ TOP_OPENIMAGE_LAYERS: 10
+ TEST:
+ SEMANTIC_ON: False
+ INSTANCE_ON: False
+ PANOPTIC_ON: False
+ OVERLAP_THRESHOLD: 0.8
+ OBJECT_MASK_THRESHOLD: 0.8
+ SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: true
+
+# Spatial sampler
+STROKE_SAMPLER:
+ MAX_CANDIDATE: 1
+ CANDIDATE_PROBS: [0.25, 0.25, 0.25, 0.25] # for training only
+ CANDIDATE_NAMES: ["Point", "Polygon", "Scribble", "Circle"]
+ DILATION: 3
+ CIRCLE:
+ NUM_STROKES: 5
+ STROKE_PRESET: ['object_like', 'object_like_middle', 'object_like_small']
+ STROKE_PROB: [0.33, 0.33, 0.33]
+ SCRIBBLE:
+ NUM_STROKES: 5
+ STROKE_PRESET: ['rand_curve', 'rand_curve_small']
+ STROKE_PROB: [0.5, 0.5]
+ POINT:
+ NUM_POINTS: 20
+ POLYGON:
+ MAX_POINTS: 9
+ EVAL:
+ MODE: 'best' # best/random/best_random
+ NEGATIVE: False
+ MAX_ITER: 1
+ IOU_ITER: 1
+ GROUNDING: True
+
+# Multi-modal Architecture, order matters
+ATTENTION_ARCH:
+ VARIABLE:
+ queries: ['object', 'grounding', 'spatial']
+ tokens: ['grounding', 'spatial']
+ memories: ['spatial']
+ SELF_ATTENTION:
+ queries:
+ object: ['queries_object']
+ grounding: ['queries_grounding', 'tokens_grounding']
+ spatial: ['queries_spatial', 'tokens_spatial', 'memories_spatial']
+ tokens:
+ grounding: ['queries_grounding', 'tokens_grounding']
+ spatial: ['tokens_spatial']
+ memories:
+ spatial: ['memories_spatial']
+ CROSS_ATTENTION:
+ queries:
+ object: True
+ grounding: True
+ spatial: True
+ memories:
+ spatial: True
+ tokens:
+ grounding: False
+ spatial: False
+ MASKING: ['tokens_spatial', 'tokens_grounding']
+ DUPLICATION:
+ queries:
+ grounding: 'queries_object'
+ spatial: 'queries_object'
+ SPATIAL_MEMORIES: 32
+ QUERY_NUMBER: 3
+
+DATASETS:
+ TRAIN: [
+ 'biomed_BiomedParseData-Demo_demo' # Add your registered training datasets here
+ ]
+
+
+
+ TEST: [
+ 'biomed_BiomedParseData-Demo_demo' # Add your registered test datasets here
+ ]
+
+ CLASS_CONCAT: false
+ SIZE_DIVISIBILITY: 32
+ PROPOSAL_FILES_TRAIN: []
+
+INPUT:
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+
+TRAIN:
+ ASPECT_RATIO_GROUPING: true
+ BATCH_SIZE_TOTAL: 4
+ BATCH_SIZE_PER_GPU: 4
+ SHUFFLE: true
+
+TEST:
+ DETECTIONS_PER_IMAGE: 100
+ NAME: coco_eval
+ IOU_TYPE: ['bbox', 'segm']
+ USE_MULTISCALE: false
+ BATCH_SIZE_TOTAL: 4
+ MODEL_FILE: ''
+ AUG:
+ ENABLED: False
+
+DATALOADER:
+ FILTER_EMPTY_ANNOTATIONS: False
+ NUM_WORKERS: 8
+ LOAD_PROPOSALS: False
+ SAMPLER_TRAIN: "TrainingSampler"
+ ASPECT_RATIO_GROUPING: True
+
+
+BioMed:
+ INPUT:
+ PIXEL_MEAN: [64.284, 59.293, 59.962]
+ PIXEL_STD: [62.484, 60.865, 59.835]
+ DATASET_MAPPER_NAME: "biomed_interactive"
+ MIN_SIZE_TRAIN: 900
+ MAX_SIZE_TRAIN: 1100
+ MIN_SIZE_TRAIN_SAMPLING: 'choice'
+ MIN_SIZE_TEST: 900
+ MAX_SIZE_TEST: 1100
+ IMAGE_SIZE: 1024
+ MIN_SCALE: 0.9
+ MAX_SCALE: 1.1
+ IGNORE_VALUE: 255
+ COLOR_AUG_SSD: False
+ SIZE_DIVISIBILITY: 32
+ RANDOM_FLIP: "none"
+ RANDOM_ROTATE: False
+ MASK_FORMAT: "polygon"
+ MIN_AREA: 30
+ FORMAT: "RGB"
+ SPATIAL: True
+ CROP:
+ ENABLED: True
+ DATASET:
+ DATASET: "biomed"
+
+
+# Detectron2 training config for optimizer and lr scheduler
+SOLVER:
+ BASE_LR: 0.0001
+ STEPS: [0.88889, 0.96296]
+ MAX_ITER: 1
+ GAMMA: 0.1
+ WARMUP_FACTOR: 1.0
+ WARMUP_ITERS: 10
+ WARMUP_METHOD: "linear"
+ WEIGHT_DECAY: 0.05
+ OPTIMIZER: "ADAMW"
+ LR_SCHEDULER_NAME: "WarmupMultiStepLR"
+ LR_MULTIPLIER:
+ backbone: 0.1
+ lang_encoder: 0.1
+ FIX_PARAM:
+ backbone: True
+ lang_encoder: True
+ pixel_decoder: True
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_EMBED: 0.0
+ CLIP_GRADIENTS:
+ ENABLED: True
+ CLIP_TYPE: "full_model"
+ CLIP_VALUE: 5.0 # 0.01
+ NORM_TYPE: 2.0
+ MAX_NUM_EPOCHS: 50
\ No newline at end of file
diff --git a/configs/biomedparse_inference.yaml b/configs/biomedparse_inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4fc422cf9418e6b67e7a5273199ee74819aa0f61
--- /dev/null
+++ b/configs/biomedparse_inference.yaml
@@ -0,0 +1,198 @@
+# Define Test/Trainer/Saving
+PIPELINE: XDecoderPipeline
+TRAINER: xdecoder
+SAVE_DIR: '../../data/output/test'
+base_path: "./"
+
+# Resume Logistic
+RESUME: false
+WEIGHT: false
+RESUME_FROM: ''
+EVAL_AT_START: false
+
+# Logging and Debug
+WANDB: False
+LOG_EVERY: 100
+FIND_UNUSED_PARAMETERS: false
+
+# Speed up training
+FP16: false
+PORT: '36873'
+
+# misc
+LOADER:
+ JOINT: False
+ KEY_DATASET: 'coco'
+
+STANDARD_TEXT_FOR_EVAL: False
+
+##################
+# Task settings
+##################
+VERBOSE: true
+MODEL:
+ device: "cuda" # or "cpu" if no GPU available
+ DEVICE: "cuda" # or "cpu" if no GPU available
+ NAME: seem_model_demo
+ HEAD: xdecoder_head
+ DIM_PROJ: 512
+ TEXT:
+ ARCH: vlpencoder
+ NAME: transformer
+ TOKENIZER: clip
+ CONTEXT_LENGTH: 77 # 77
+ WIDTH: 512
+ HEADS: 8
+ LAYERS: 12 # 6
+ AUTOGRESSIVE: True
+ BACKBONE:
+ NAME: focal
+ PRETRAINED: ''
+ LOAD_PRETRAINED: false
+ FOCAL:
+ PRETRAIN_IMG_SIZE: 224
+ PATCH_SIZE: 4
+ EMBED_DIM: 192
+ DEPTHS: [2, 2, 18, 2]
+ FOCAL_LEVELS: [4, 4, 4, 4]
+ FOCAL_WINDOWS: [3, 3, 3, 3]
+ DROP_PATH_RATE: 0.3
+ MLP_RATIO: 4.0
+ DROP_RATE: 0.0
+ PATCH_NORM: True
+ USE_CONV_EMBED: True
+ SCALING_MODULATOR: True
+ USE_CHECKPOINT: False
+ USE_POSTLN: true
+ USE_POSTLN_IN_MODULATION: false
+ USE_LAYERSCALE: True
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+ OUT_INDICES: [0, 1, 2, 3]
+ ENCODER:
+ NAME: transformer_encoder_fpn
+ IGNORE_VALUE: 255
+ NUM_CLASSES: 16
+ BINARY_CLASSES: False
+ LOSS_WEIGHT: 1.0
+ CONVS_DIM: 512
+ MASK_DIM: 512
+ NORM: "GN"
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+ COMMON_STRIDE: 4
+ TRANSFORMER_ENC_LAYERS: 6
+ DECODER:
+ NAME: seem_demo
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+ MASK:
+ ENABLED: False
+ DETECTION: False
+ SPATIAL:
+ ENABLED: True
+ MAX_ITER: 1
+ GROUNDING:
+ ENABLED: True
+ MAX_LEN: 5
+ TEXT_WEIGHT: 2.0
+ CLASS_WEIGHT: 0.5
+ VISUAL:
+ ENABLED: False
+ AUDIO:
+ ENABLED: False
+ RETRIEVAL:
+ ENABLED: False
+ LVIS:
+ ENABLED: True
+ THRES: 0.7
+ OPENIMAGE:
+ ENABLED: False
+ NEGATIVE_SAMPLES: 5
+ GROUNDING:
+ ENABLED: False
+ MAX_LEN: 5
+ CAPTION:
+ ENABLED: False
+ PHRASE_PROB: 0.5
+ SIM_THRES: 0.95
+ DEEP_SUPERVISION: True
+ NO_OBJECT_WEIGHT: 0.1
+ GCLASS_WEIGHT: 0.4
+ GMASK_WEIGHT: 1.0
+ GDICE_WEIGHT: 1.0
+ SCLASS_WEIGHT: 0.4
+ SMASK_WEIGHT: 1.0
+ SDICE_WEIGHT: 1.0
+ OCLASS_WEIGHT: 0.4
+ OMASK_WEIGHT: 1.0
+ ODICE_WEIGHT: 1.0
+ CLASS_WEIGHT: 2.0
+ MASK_WEIGHT: 5.0
+ DICE_WEIGHT: 5.0
+ BBOX_WEIGHT: 5.0
+ GIOU_WEIGHT: 2.0
+ CAPTION_WEIGHT: 2.0
+ COST_SPATIAL:
+ CLASS_WEIGHT: 5.0
+ MASK_WEIGHT: 2.0
+ DICE_WEIGHT: 2.0
+ HIDDEN_DIM: 512
+ NUM_OBJECT_QUERIES: 101
+ NHEADS: 8
+ DROPOUT: 0.0
+ DIM_FEEDFORWARD: 2048
+ MAX_SPATIAL_LEN: [512, 512, 512, 512]
+ # ENC_LAYERS: 0
+ PRE_NORM: False
+ ENFORCE_INPUT_PROJ: False
+ SIZE_DIVISIBILITY: 32
+ TRAIN_NUM_POINTS: 12544
+ OVERSAMPLE_RATIO: 3.0
+ IMPORTANCE_SAMPLE_RATIO: 0.75
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
+ TOP_GROUNDING_LAYERS: 10
+ TOP_CAPTION_LAYERS: 10
+ TOP_SPATIAL_LAYERS: 10
+ TOP_OPENIMAGE_LAYERS: 10
+ TEST:
+ SEMANTIC_ON: True
+ INSTANCE_ON: True
+ PANOPTIC_ON: True
+ OVERLAP_THRESHOLD: 0.8
+ OBJECT_MASK_THRESHOLD: 0.4
+ SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false
+ DETECTIONS_PER_IMAGE: 100
+
+# Multi-modal Architecture, order matters
+ATTENTION_ARCH:
+ VARIABLE:
+ queries: ['object']
+ tokens: ['grounding', 'spatial', 'visual', 'audio']
+ SELF_ATTENTION:
+ queries:
+ object: ['queries_object', 'tokens_grounding', 'tokens_spatial', 'tokens_visual', 'tokens_audio']
+ tokens:
+ grounding: ['queries_object', 'tokens_grounding']
+ spatial: ['tokens_spatial']
+ visual: ['tokens_visual']
+ audio: ['queries_object', 'tokens_audio']
+ CROSS_ATTENTION:
+ queries:
+ object: True
+ tokens:
+ grounding: False
+ spatial: False
+ visual: False
+ audio: False
+ MASKING: ['tokens_spatial', 'tokens_grounding', 'tokens_visual', 'tokens_audio']
+ DUPLICATION:
+ queries:
+ grounding: 'queries_object'
+ spatial: 'queries_object'
+ SPATIAL_MEMORIES: 32
+
+INPUT:
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+# INPUT:
+# PIXEL_MEAN: [64.284, 59.293, 59.962]
+# PIXEL_STD: [62.484, 60.865, 59.835]
\ No newline at end of file
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d03060a25db34aa7ca46bdc6a2e604dd4fa9bc52
--- /dev/null
+++ b/datasets/__init__.py
@@ -0,0 +1,2 @@
+from . import registration
+from .build import build_train_dataloader, build_eval_dataloader, build_evaluator
\ No newline at end of file
diff --git a/datasets/build.py b/datasets/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..f95da8f3d7a1dec10bc720e0ce6dbc90872744ea
--- /dev/null
+++ b/datasets/build.py
@@ -0,0 +1,630 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import os
+import numpy as np
+import itertools
+import logging
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.utils.data
+import torch.utils.data as torchdata
+
+import detectron2.utils.comm as comm
+from detectron2.data.build import (
+ build_batch_data_loader,
+ load_proposals_into_dataset,
+ trivial_batch_collator,
+)
+from detectron2.data import MetadataCatalog
+from detectron2.data.catalog import DatasetCatalog
+from detectron2.data.common import DatasetFromList, MapDataset
+from detectron2.data.dataset_mapper import DatasetMapper
+from detectron2.data.samplers import InferenceSampler, TrainingSampler
+from detectron2.evaluation import (
+ CityscapesInstanceEvaluator,
+ CityscapesSemSegEvaluator,
+ COCOEvaluator,
+ DatasetEvaluators,
+ LVISEvaluator,
+ verify_results,
+)
+from fvcore.common.config import CfgNode
+
+from .dataset_mappers import *
+from .evaluation import (InstanceSegEvaluator,
+ ClassificationEvaluator,
+ SemSegEvaluator,
+ RetrievalEvaluator,
+ #CaptioningEvaluator,
+ COCOPanopticEvaluator,
+ GroundingEvaluator,
+ InteractiveEvaluator,
+)
+from modeling.utils import configurable
+from utilities.distributed import get_world_size
+
+class JointLoader(torchdata.IterableDataset):
+ """
+ Randomly sampple from one of the dataloaders per worker in each iteration.
+ The sampling probability is determined by the size of each dataset.
+ All examples from one worker (GPU) are from the same dataset in the iteration.
+ Mixing is achieved through multiple workers (GPUs).
+ """
+ def __init__(self, loaders, key_dataset, sample_prob, mixing_level):
+ dataset_names = []
+ for key, loader in loaders.items():
+ name = "{}".format(key.split('_')[0])
+ setattr(self, name, loader)
+ dataset_names += [name]
+ self.dataset_names = dataset_names
+ self.key_dataset = key_dataset
+ if sample_prob == 'prop':
+ self.sample_prob = [len(getattr(self, key)) for key in self.dataset_names]
+ elif sample_prob == 'equal':
+ self.sample_prob = [1 for key in self.dataset_names]
+ elif sample_prob == 'sqrt':
+ self.sample_prob = [np.sqrt(len(getattr(self, key))) for key in self.dataset_names]
+ self.sample_prob = [p/sum(self.sample_prob) for p in self.sample_prob]
+ self.mixing_level = mixing_level
+
+ # Not sure how expensive `len(getattr(self, name))` is. computing this once and cache.
+ # this assumes the len of the underlying data loaders do not change.
+ self._len = sum(len(getattr(self, name)) for name in self.dataset_names)
+
+ def __iter__(self):
+ # Reset iterators at the start of each new epoch
+ self.iterators = {name: iter(getattr(self, name)) for name in self.dataset_names}
+ self._count = 0
+ return self
+
+ def __next__(self):
+ while self._count < self._len:
+ # Randomly select a dataloader
+ name = np.random.choice(self.dataset_names, size=None, replace=False, p=self.sample_prob)
+ iterator = self.iterators[name]
+
+ try:
+ # Get next batch from the selected dataloader
+ self._count += 1
+ return next(iterator)
+ except StopIteration:
+ # If the selected dataloader is exhausted, reinitialize it
+ self.iterators[name] = iter(getattr(self, name))
+ raise StopIteration
+
+ def __len__(self):
+ return self._len
+
+def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
+ """
+ Filter out images with none annotations or only crowd annotations
+ (i.e., images without non-crowd annotations).
+ A common training-time preprocessing on COCO dataset.
+
+ Args:
+ dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+
+ Returns:
+ list[dict]: the same format, but filtered.
+ """
+ num_before = len(dataset_dicts)
+
+ def valid(anns):
+ for ann in anns:
+ if isinstance(ann, list):
+ for instance in ann:
+ if instance.get("iscrowd", 0) == 0:
+ return True
+ else:
+ if ann.get("iscrowd", 0) == 0:
+ return True
+ return False
+
+ dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
+ num_after = len(dataset_dicts)
+ logger = logging.getLogger(__name__)
+ logger.info(
+ "Removed {} images with no usable annotations. {} images left.".format(
+ num_before - num_after, num_after
+ )
+ )
+ return dataset_dicts
+
+
+def get_detection_dataset_dicts(
+ dataset_names, filter_empty=True, proposal_files=None
+):
+ """
+ Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
+
+ Args:
+ dataset_names (str or list[str]): a dataset name or a list of dataset names
+ filter_empty (bool): whether to filter out images without instance annotations
+ proposal_files (list[str]): if given, a list of object proposal files
+ that match each dataset in `dataset_names`.
+
+ Returns:
+ list[dict]: a list of dicts following the standard dataset dict format.
+ """
+ if isinstance(dataset_names, str):
+ dataset_names = [dataset_names]
+ assert len(dataset_names)
+
+ dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
+ for dataset_name, dicts in zip(dataset_names, dataset_dicts):
+ assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
+
+ if proposal_files is not None:
+ assert len(dataset_names) == len(proposal_files)
+ # load precomputed proposals from proposal files
+ dataset_dicts = [
+ load_proposals_into_dataset(dataset_i_dicts, proposal_file)
+ for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
+ ]
+
+ dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
+
+ has_instances = "annotations" in dataset_dicts[0]
+ if filter_empty and has_instances:
+ dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)
+
+ assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
+ return dataset_dicts
+
+
+def _test_loader_from_config(cfg, dataset_name, mapper=None):
+ """
+ Uses the given `dataset_name` argument (instead of the names in cfg), because the
+ standard practice is to evaluate each test set individually (not combining them).
+ """
+ if isinstance(dataset_name, str):
+ dataset_name = [dataset_name]
+
+ dataset = get_detection_dataset_dicts(
+ dataset_name,
+ filter_empty=False,
+ proposal_files=None,
+ )
+ if mapper is None:
+ mapper_cfg = CfgNode({'INPUT': cfg['INPUT'], 'MODEL': cfg['MODEL'], 'DATASETS': cfg['DATASETS']})
+ mapper = DatasetMapper(mapper_cfg, False)
+ assert cfg['TEST']['BATCH_SIZE_TOTAL'] % get_world_size() == 0, "Evaluation total batchsize is not divisible by gpu number"
+ #batch_size = cfg['TEST']['BATCH_SIZE_TOTAL'] // get_world_size()
+ batch_size = 1
+
+ return {
+ "dataset": dataset,
+ "mapper": mapper,
+ "num_workers": cfg['DATALOADER']['NUM_WORKERS'],
+ "sampler": InferenceSampler(len(dataset)),
+ "batch_size": batch_size,
+ }
+
+
+@configurable(from_config=_test_loader_from_config)
+def build_detection_test_loader(
+ dataset: Union[List[Any], torchdata.Dataset],
+ *,
+ mapper: Callable[[Dict[str, Any]], Any],
+ sampler: Optional[torchdata.Sampler] = None,
+ batch_size: int = 1,
+ num_workers: int = 0,
+ collate_fn: Optional[Callable[[List[Any]], Any]] = None,
+) -> torchdata.DataLoader:
+ """
+ Similar to `build_detection_train_loader`, with default batch size = 1,
+ and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
+ to produce the exact set of all samples.
+
+ Args:
+ dataset: a list of dataset dicts,
+ or a pytorch dataset (either map-style or iterable). They can be obtained
+ by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+ mapper: a callable which takes a sample (dict) from dataset
+ and returns the format to be consumed by the model.
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
+ sampler: a sampler that produces
+ indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
+ which splits the dataset across all workers. Sampler must be None
+ if `dataset` is iterable.
+ batch_size: the batch size of the data loader to be created.
+ Default to 1 image per worker since this is the standard when reporting
+ inference time in papers.
+ num_workers: number of parallel data loading workers
+ collate_fn: same as the argument of `torch.utils.data.DataLoader`.
+ Defaults to do no collation and return a list of data.
+
+ Returns:
+ DataLoader: a torch DataLoader, that loads the given detection
+ dataset, with test-time transformation and batching.
+
+ Examples:
+ ::
+ data_loader = build_detection_test_loader(
+ DatasetRegistry.get("my_test"),
+ mapper=DatasetMapper(...))
+
+ # or, instantiate with a CfgNode:
+ data_loader = build_detection_test_loader(cfg, "my_test")
+ """
+
+ if isinstance(dataset, list):
+ dataset = DatasetFromList(dataset, copy=False)
+ if mapper is not None:
+ dataset = MapDataset(dataset, mapper)
+ if isinstance(dataset, torchdata.IterableDataset):
+ assert sampler is None, "sampler must be None if dataset is IterableDataset"
+ else:
+ if sampler is None:
+ sampler = InferenceSampler(len(dataset))
+ return torchdata.DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ drop_last=False,
+ num_workers=num_workers,
+ collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
+ )
+
+
+def _train_loader_from_config(cfg, dataset_name, mapper, *, dataset=None, sampler=None):
+ cfg_datasets = cfg['DATASETS']
+ cfg_dataloader = cfg['DATALOADER']
+
+ if dataset is None:
+ dataset = get_detection_dataset_dicts(
+ dataset_name,
+ filter_empty=cfg_dataloader['FILTER_EMPTY_ANNOTATIONS'],
+ proposal_files=cfg_datasets['PROPOSAL_FILES_TRAIN'] if cfg_dataloader['LOAD_PROPOSALS'] else None,
+ )
+
+ if mapper is None:
+ mapper = DatasetMapper(cfg, True)
+
+ if sampler is None:
+ sampler_name = cfg_dataloader['SAMPLER_TRAIN']
+ logger = logging.getLogger(__name__)
+ logger.info("Using training sampler {}".format(sampler_name))
+ sampler = TrainingSampler(len(dataset))
+
+ return {
+ "dataset": dataset,
+ "sampler": sampler,
+ "mapper": mapper,
+ "total_batch_size": cfg['TRAIN']['BATCH_SIZE_TOTAL'],
+ "aspect_ratio_grouping": cfg_dataloader['ASPECT_RATIO_GROUPING'],
+ "num_workers": cfg_dataloader['NUM_WORKERS'],
+ }
+
+
+@configurable(from_config=_train_loader_from_config)
+def build_detection_train_loader(
+ dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
+):
+ """
+ Build a dataloader for object detection with some default features.
+ This interface is experimental.
+
+ Args:
+ dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
+ or a map-style pytorch dataset. They can be obtained by using
+ :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+ mapper (callable): a callable which takes a sample (dict) from dataset and
+ returns the format to be consumed by the model.
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
+ sampler (torch.utils.data.sampler.Sampler or None): a sampler that
+ produces indices to be applied on ``dataset``.
+ Default to :class:`TrainingSampler`, which coordinates a random shuffle
+ sequence across all workers.
+ total_batch_size (int): total batch size across all workers. Batching
+ simply puts data into a list.
+ aspect_ratio_grouping (bool): whether to group images with similar
+ aspect ratio for efficiency. When enabled, it requires each
+ element in dataset be a dict with keys "width" and "height".
+ num_workers (int): number of parallel data loading workers
+
+ Returns:
+ torch.utils.data.DataLoader: a dataloader. Each output from it is a
+ ``list[mapped_element]`` of length ``total_batch_size / num_workers``,
+ where ``mapped_element`` is produced by the ``mapper``.
+ """
+ if isinstance(dataset, list):
+ dataset = DatasetFromList(dataset, copy=False)
+ if mapper is not None:
+ dataset = MapDataset(dataset, mapper)
+ if sampler is None:
+ sampler = TrainingSampler(len(dataset))
+ assert isinstance(sampler, torch.utils.data.sampler.Sampler)
+ return build_batch_data_loader(
+ dataset,
+ sampler,
+ total_batch_size,
+ aspect_ratio_grouping=aspect_ratio_grouping,
+ num_workers=num_workers,
+ )
+
+
+def get_config_from_name(cfg, dataset_name):
+ # adjust config according to dataset
+ if 'refcoco' in dataset_name:
+ cfg.update(cfg['REF'])
+ return cfg
+ elif 'cocomini' in dataset_name:
+ cfg.update(cfg['DAVIS'])
+ return cfg
+ elif 'ytvos' in dataset_name:
+ cfg.update(cfg['VOS'])
+ return cfg
+ elif 'ade600' in dataset_name:
+ cfg.update(cfg['DAVIS'])
+ return cfg
+ elif 'openimage600' in dataset_name:
+ cfg.update(cfg['DAVIS'])
+ return cfg
+ elif 'ade' in dataset_name:
+ if 'ADE20K' in cfg.keys():
+ cfg.update(cfg['ADE20K'])
+ return cfg
+ elif 'imagenet' in dataset_name:
+ if 'IMAGENET' in cfg.keys():
+ cfg.update(cfg['IMAGENET'])
+ return cfg
+ elif 'vlp' in dataset_name:
+ cfg.update(cfg['VLP'])
+ return cfg
+ elif 'coco' in dataset_name:
+ if 'COCO' in cfg.keys():
+ cfg.update(cfg['COCO'])
+ return cfg
+ elif 'voc' in dataset_name:
+ cfg.update(cfg['VOC'])
+ return cfg
+ elif 'context' in dataset_name:
+ cfg.update(cfg['CONTEXT'])
+ return cfg
+ elif 'sun' in dataset_name:
+ cfg.update(cfg['SUN'])
+ return cfg
+ elif 'scan' in dataset_name:
+ cfg.update(cfg['SCAN'])
+ return cfg
+ elif 'cityscape' in dataset_name:
+ cfg.update(cfg['CITY'])
+ return cfg
+ elif 'bdd' in dataset_name:
+ cfg.update(cfg['BDD'])
+ return cfg
+ elif 'tsv' in dataset_name:
+ cfg.update(cfg['TSV'])
+ return cfg
+ elif 'phrasecut' in dataset_name:
+ cfg.update(cfg['PHRASE'])
+ return cfg
+ elif 'object365' in dataset_name:
+ cfg.update(cfg['OBJECT365'])
+ return cfg
+ elif 'openimage' in dataset_name:
+ cfg.update(cfg['OPENIMAGE'])
+ return cfg
+ elif 'lvis' in dataset_name:
+ cfg.update(cfg['LVIS'])
+ return cfg
+ elif 'seginw' in dataset_name:
+ cfg.update(cfg['SEGINW'])
+ return cfg
+ elif 'sbd' in dataset_name:
+ cfg.update(cfg['SBD'])
+ return cfg
+ elif 'davis' in dataset_name:
+ cfg.update(cfg['DAVIS'])
+ return cfg
+ elif 'med_sam' in dataset_name:
+ cfg.update(cfg['MedSAM'])
+ return cfg
+ elif 'biomed' in dataset_name:
+ cfg.update(cfg['BioMed'])
+ return cfg
+ elif 'sam' in dataset_name:
+ cfg.update(cfg['SAM'])
+ return cfg
+ else:
+ assert False, "dataset not support."
+
+
+def build_eval_dataloader(cfg, ):
+ dataloaders = []
+ for dataset_name in cfg['DATASETS']['TEST']:
+ cfg = get_config_from_name(cfg, dataset_name)
+ # adjust mapper according to dataset
+ if dataset_name == 'imagenet_val':
+ mapper = ImageNetDatasetMapper(cfg, False)
+ elif dataset_name == 'bdd10k_val_sem_seg':
+ mapper = BDDSemDatasetMapper(cfg, False)
+ elif dataset_name in ["vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017"]:
+ mapper = VLPreDatasetMapper(cfg, False, dataset_name)
+ elif dataset_name in ["scannet_21_val_seg", "scannet_38_val_seg", "scannet_41_val_seg"]:
+ mapper = ScanNetSegDatasetMapper(cfg, False)
+ elif dataset_name in ["scannet_21_panoptic_val", 'bdd10k_40_panoptic_val']:
+ mapper = ScanNetPanoDatasetMapper(cfg, False)
+ elif "pascalvoc_val" in dataset_name:
+ mapper = PascalVOCSegDatasetMapperIX(cfg, False, dataset_name)
+ elif 'sun' in dataset_name:
+ mapper = SunRGBDSegDatasetMapper(cfg, False)
+ elif 'refcoco' in dataset_name:
+ mapper = RefCOCODatasetMapper(cfg, False)
+ elif 'med_sam' in dataset_name:
+ mapper = MedSAMDatasetMapper(cfg, False)
+ elif 'biomed' in dataset_name:
+ mapper = BioMedDatasetMapper(cfg, False)
+ else:
+ mapper = None
+ dataloaders += [build_detection_test_loader(cfg, dataset_name, mapper=mapper)]
+ return dataloaders
+
+
+def build_train_dataloader(cfg, ):
+ dataset_names = cfg['DATASETS']['TRAIN']
+
+ loaders = {}
+ for dataset_name in dataset_names:
+ cfg = get_config_from_name(cfg, dataset_name)
+ mapper_name = cfg['INPUT']['DATASET_MAPPER_NAME']
+ # Semantic segmentation dataset mapper
+ if mapper_name == "mask_former_semantic":
+ mapper = MaskFormerSemanticDatasetMapper(cfg, True)
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ # Panoptic segmentation dataset mapper
+ elif mapper_name == "mask_former_panoptic":
+ mapper = MaskFormerPanopticDatasetMapper(cfg, True)
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ # Instance segmentation dataset mapper
+ elif mapper_name == "mask_former_instance":
+ mapper = MaskFormerInstanceDatasetMapper(cfg, True)
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ # coco instance segmentation lsj new baseline
+ elif mapper_name == "coco_instance_lsj":
+ mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ # coco panoptic segmentation lsj new baseline
+ elif mapper_name == "coco_panoptic_lsj":
+ mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True)
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ elif mapper_name == "vlpretrain":
+ mapper = VLPreDatasetMapper(cfg, True, dataset_name)
+ loaders['vlp'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ elif mapper_name == "refcoco":
+ mapper = RefCOCODatasetMapper(cfg, True)
+ loaders['ref'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ elif mapper_name == "coco_interactive":
+ mapper = COCOPanopticInteractiveDatasetMapper(cfg, True)
+ loaders['coco'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ elif mapper_name == "medsam_interactive":
+ mapper = MedSAMDatasetMapper(cfg, True)
+ loaders['med_sam'] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ elif mapper_name == "biomed_interactive":
+ mapper = BioMedDatasetMapper(cfg, True)
+ name_key = dataset_name.split("_")[1]
+ loaders[name_key] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+ else:
+ mapper = None
+ loaders[dataset_name] = build_detection_train_loader(cfg, dataset_name=dataset_name, mapper=mapper)
+
+ if len(loaders) == 1 or not cfg['LOADER'].get('JOINT', False):
+ return list(loaders.values())[0]
+ else:
+ sample_prob = cfg['LOADER'].get('SAMPLE_PROB', 'prop')
+ mixing_level = cfg['LOADER'].get('MIXING_LEVEL', 1)
+ return JointLoader(loaders, key_dataset=cfg['LOADER'].get('KEY_DATASET', 'coco'), sample_prob=sample_prob, mixing_level=mixing_level)
+
+
+def build_evaluator(cfg, dataset_name, output_folder=None):
+ """
+ Create evaluator(s) for a given dataset.
+ This uses the special metadata "evaluator_type" associated with each
+ builtin dataset. For your own dataset, you can simply create an
+ evaluator manually in your script and do not have to worry about the
+ hacky if-else logic here.
+ """
+ if output_folder is None:
+ output_folder = os.path.join(cfg["SAVE_DIR"], "inference")
+ evaluator_list = []
+ evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
+
+ # semantic segmentation
+ if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
+ evaluator_list.append(
+ SemSegEvaluator(
+ dataset_name,
+ distributed=True,
+ output_dir=output_folder,
+ )
+ )
+ # instance segmentation
+ if evaluator_type == "coco":
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
+
+ cfg_model_decoder_test = cfg["MODEL"]["DECODER"]["TEST"]
+ # panoptic segmentation
+ if evaluator_type in [
+ "coco_panoptic_seg",
+ "ade20k_panoptic_seg",
+ "cityscapes_panoptic_seg",
+ "mapillary_vistas_panoptic_seg",
+ "scannet_panoptic_seg",
+ "bdd_panoptic_pano"
+ ]:
+ if cfg_model_decoder_test["PANOPTIC_ON"]:
+ evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
+ # COCO
+ if (evaluator_type == "coco_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]) or evaluator_type == "object365_od":
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
+ if (evaluator_type == "coco_panoptic_seg" and cfg_model_decoder_test["SEMANTIC_ON"]) or evaluator_type == "coco_sem_seg":
+ evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
+ # Mapillary Vistas
+ if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]:
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
+ if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg_model_decoder_test["SEMANTIC_ON"]:
+ evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
+ # Cityscapes
+ if evaluator_type == "cityscapes_instance":
+ assert (
+ torch.cuda.device_count() > comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ return CityscapesInstanceEvaluator(dataset_name)
+ if evaluator_type == "cityscapes_sem_seg":
+ assert (
+ torch.cuda.device_count() > comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ return CityscapesSemSegEvaluator(dataset_name)
+ if evaluator_type == "cityscapes_panoptic_seg":
+ if cfg_model_decoder_test["SEMANTIC_ON"]:
+ assert (
+ torch.cuda.device_count() > comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
+ if cfg_model_decoder_test["INSTANCE_ON"]:
+ assert (
+ torch.cuda.device_count() > comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ evaluator_list.append(CityscapesInstanceEvaluator(dataset_name))
+ # ADE20K
+ if evaluator_type == "ade20k_panoptic_seg" and cfg_model_decoder_test["INSTANCE_ON"]:
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
+ # SEGINW
+ if evaluator_type == "seginw" and cfg_model_decoder_test["INSTANCE_ON"]:
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
+ # LVIS
+ if evaluator_type == "lvis":
+ return LVISEvaluator(dataset_name, output_dir=output_folder)
+ # Classification
+ if evaluator_type == "classification":
+ evaluator_list.append(ClassificationEvaluator(dataset_name, output_folder))
+ # Retrieval
+ if evaluator_type in ["retrieval"]:
+ evaluator_list.append(RetrievalEvaluator(dataset_name, output_folder, cfg['MODEL']['DECODER']['RETRIEVAL']['ENSEMBLE']))
+ if evaluator_type == "captioning":
+ evaluator_list.append(CaptioningEvaluator(dataset_name, output_folder, MetadataCatalog.get(dataset_name).gt_json))
+ if evaluator_type in ["grounding_refcoco", "grounding_phrasecut", "grounding_spatial", "grounding_entity"]:
+ evaluator_list.append(GroundingEvaluator(dataset_name))
+ # Interactive
+ if evaluator_type in ["interactive", "interactive_grounding"]:
+ evaluator_list.append(InteractiveEvaluator(dataset_name, output_dir=output_folder, max_clicks=cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER'], iou_iter=cfg['STROKE_SAMPLER']['EVAL']['IOU_ITER']))
+
+ if len(evaluator_list) == 0:
+ raise NotImplementedError(
+ "no Evaluator for the dataset {} with the type {}".format(
+ dataset_name, evaluator_type
+ )
+ )
+ elif len(evaluator_list) == 1:
+ return evaluator_list[0]
+
+
+ return DatasetEvaluators(evaluator_list)
diff --git a/datasets/dataset_mappers/__init__.py b/datasets/dataset_mappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..000e03e43cd2fb90b9afa48c6d4a09953f28e43f
--- /dev/null
+++ b/datasets/dataset_mappers/__init__.py
@@ -0,0 +1 @@
+from .biomed_dataset_mapper import BioMedDatasetMapper
\ No newline at end of file
diff --git a/datasets/dataset_mappers/biomed_dataset_mapper.py b/datasets/dataset_mappers/biomed_dataset_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f101cccfae40deaafd2ef7990f014ad221378ca
--- /dev/null
+++ b/datasets/dataset_mappers/biomed_dataset_mapper.py
@@ -0,0 +1,378 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
+import copy
+import logging
+import random
+
+import numpy as np
+import torch
+
+from transformers import AutoTokenizer, LlamaForCausalLM
+
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+from detectron2.data.transforms import TransformGen
+from detectron2.structures import BitMasks, Boxes, Instances, BoxMode
+from detectron2.structures.boxes import pairwise_iou
+from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+from detectron2.data import MetadataCatalog
+from pycocotools import mask as coco_mask
+
+from utilities import prompt_engineering
+from modeling.language import build_tokenizer
+from modeling.language.misc import text_noun_with_prompt_all
+from modeling.utils import configurable
+
+from ..visual_sampler.sampler import build_shape_sampler
+
+__all__ = ["BioMedDatasetMapper"]
+
+
+def build_transform_gen(cfg, is_train):
+ """
+ Create a list of default :class:`Augmentation` from config.
+ Now it includes resizing and flipping.
+ Returns:
+ list[Augmentation]
+ """
+ assert is_train, "Only support training augmentation"
+ cfg_input = cfg['INPUT']
+ image_size = cfg_input['IMAGE_SIZE']
+ min_scale = cfg_input['MIN_SCALE']
+ max_scale = cfg_input['MAX_SCALE']
+
+ augmentation = []
+
+ if cfg_input['RANDOM_FLIP'] != "none":
+ augmentation.append(
+ T.RandomFlip(
+ horizontal=cfg_input['RANDOM_FLIP'] == "horizontal",
+ vertical=cfg_input['RANDOM_FLIP'] == "vertical",
+ )
+ )
+
+ augmentation.extend([
+ T.ResizeScale(
+ min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
+ ),
+ T.FixedSizeCrop(crop_size=(image_size, image_size)),
+ ])
+
+ return augmentation
+
+def build_transform_gen_se(cfg, is_train):
+ # min_scale = cfg['INPUT']['MIN_SIZE_TEST']
+ # max_scale = cfg['INPUT']['MAX_SIZE_TEST']
+
+ augmentation = []
+ # augmentation.extend([
+ # T.ResizeShortestEdge(
+ # min_scale, max_size=max_scale
+ # ),
+ # ])
+ return augmentation
+
+def convert_coco_poly_to_mask(segmentations, height, width):
+ masks = []
+ for polygons in segmentations:
+ rles = coco_mask.frPyObjects(polygons, height, width)
+ mask = coco_mask.decode(rles)
+ if len(mask.shape) < 3:
+ mask = mask[..., None]
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
+ mask = mask.any(dim=2)
+ masks.append(mask)
+ if masks:
+ masks = torch.stack(masks, dim=0)
+ else:
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
+ return masks
+
+# This is specifically designed for the COCO dataset.
+class BioMedDatasetMapper:
+ """
+ A callable which takes a dataset dict in Detectron2 Dataset format,
+ and map it into a format used by MaskFormer.
+
+ This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
+
+ The callable currently does the following:
+
+ 1. Read the image from "file_name"
+ 2. Applies geometric transforms to the image and annotation
+ 3. Find and applies suitable cropping to the image and annotation
+ 4. Prepare image and annotation to Tensors
+ """
+
+ @configurable
+ def __init__(
+ self,
+ is_train=True,
+ *,
+ tfm_gens,
+ image_format,
+ caption_thres,
+ grounding,
+ lvis,
+ lvis_thres,
+ max_grounding_num,
+ shape_sampler,
+ retrieval,
+ max_token_num,
+ tokenizer,
+ binary_classes: bool,
+ rotate: bool,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ is_train: for training or inference
+ augmentations: a list of augmentations or deterministic transforms to apply
+ crop_gen: crop augmentation
+ tfm_gens: data augmentation
+ image_format: an image format supported by :func:`detection_utils.read_image`.
+ """
+ self.tfm_gens = tfm_gens
+ logging.getLogger(__name__).info(
+ "[COCOPanopticNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
+ str(self.tfm_gens)
+ )
+ )
+
+ self.img_format = image_format
+ self.is_train = is_train
+ self.caption_thres = caption_thres
+ self.grounding = grounding
+ self.lvis = lvis
+ self.lvis_thres = lvis_thres
+ self.max_grounding_num = max_grounding_num
+
+ self.shape_sampler = shape_sampler
+
+ self.retrieval = retrieval
+ self.tokenizer = tokenizer
+ self.max_token_num = max_token_num
+
+ self.binary_classes = binary_classes
+ self.rotate = rotate
+
+ @classmethod
+ def from_config(cls, cfg, is_train=True):
+ # Build augmentation
+ if is_train:
+ tfm_gens = build_transform_gen(cfg, is_train)
+ else:
+ tfm_gens = build_transform_gen_se(cfg, is_train)
+
+ shape_sampler = build_shape_sampler(cfg)
+
+ retrieval = cfg['MODEL']['DECODER']['RETRIEVAL']['ENABLED']
+ tokenizer, max_token_num = None, None
+ if retrieval:
+ lang_model = cfg['MODEL']['TEXT']['NAME']
+ max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
+ if 'llama' in lang_model:
+ tokenizer = AutoTokenizer.from_pretrained(lang_model, padding_side='right')
+ tokenizer.pad_token = tokenizer.eos_token
+ else:
+ tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
+
+ ret = {
+ "is_train": is_train,
+ "tfm_gens": tfm_gens,
+ "image_format": cfg['INPUT']['FORMAT'],
+ "caption_thres": cfg['MODEL']['DECODER']['CAPTION']['SIM_THRES'],
+ "grounding": cfg['MODEL']['DECODER']['GROUNDING']['ENABLED'],
+ "lvis": cfg['MODEL']['DECODER']['LVIS']['ENABLED'],
+ "lvis_thres": cfg['MODEL']['DECODER']['LVIS']['THRES'],
+ "max_grounding_num": cfg['MODEL']['DECODER']['GROUNDING']['MAX_LEN'],
+ "shape_sampler": shape_sampler,
+ "retrieval": retrieval,
+ "max_token_num": max_token_num,
+ "tokenizer": tokenizer,
+ "binary_classes": cfg['MODEL']['ENCODER']['BINARY_CLASSES'],
+ "rotate": cfg['INPUT']['RANDOM_ROTATE'],
+ }
+ return ret
+
+ def __call__(self, dataset_dict):
+ """
+ Args:
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+ Returns:
+ dict: a format that builtin models in detectron2 accept
+ """
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
+ while True:
+ try:
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+ break
+ except:
+ print('Image loading error:', dataset_dict["file_name"])
+
+ utils.check_image_size(dataset_dict, image)
+
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
+ image_shape = image.shape[:2] # h, w
+
+ rotate_time = 0
+ if self.is_train and self.rotate and random.random() < 0.5:
+ rotate_time = random.randint(1, 3)
+ if rotate_time > 0:
+ image = np.rot90(image, rotate_time)
+
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+ # Therefore it's important to use torch.Tensor.
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+
+
+ grounding_anno = dataset_dict['grounding_info']
+ if len(grounding_anno) == 0:
+ print(dataset_dict['file_name'])
+ assert len(grounding_anno) > 0
+ masks_grd = []
+ texts_grd = []
+ boxes_grd = []
+ hash_grd = []
+ classes = []
+ masks_orig = []
+ for ann in grounding_anno:
+ if 'segmentation' in ann:
+ if len(ann['segmentation']) == 0:
+ print('Empty segmentation!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
+ continue
+ rle = coco_mask.frPyObjects(
+ ann['segmentation'], dataset_dict['height'], dataset_dict['width'])
+ m = coco_mask.decode(rle)
+ masks_orig.append(m)
+ # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = np.sum(m, axis=2)
+ else:
+ # directly read from mask file
+ while True:
+ try:
+ m = utils.read_image(ann["mask_file"], format=self.img_format)
+ break
+ except:
+ print('Image loading error:', ann["mask_file"])
+ m = np.sum(m, axis=2)
+ m = 1 * (m > 0)
+ m = m.astype(np.uint8) # convert to np.uint8
+ m = transforms.apply_segmentation(255*m[:,:,None])[:,:,0]
+ if rotate_time > 0:
+ m = np.rot90(m, rotate_time)
+ masks_grd += [m]
+ rand_id = random.randint(0, len(ann['sentences'])-1)
+ texts_grd.append(ann['sentences'][rand_id]['raw'].lower())
+ hash_grd.append(hash(ann['sentences'][rand_id]['raw'].lower()))
+ if self.binary_classes:
+ ann["category_id"] = 1 * (ann["category_id"] > 0)
+ classes.append(ann["category_id"])
+ #masks_grd = torch.from_numpy(np.stack(masks_grd))
+ boxes_grd = torch.tensor(boxes_grd)
+ groundings = {'masks': masks_grd, 'texts': texts_grd, 'hash': hash_grd, 'mode': 'text'}
+ dataset_dict["groundings"] = groundings
+
+ masks_grd = torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks_grd])
+
+ instances = Instances(image_shape)
+
+ instances.gt_masks = BitMasks(masks_grd)
+ instances.gt_boxes = BitMasks(masks_grd).get_bounding_boxes()
+
+ classes = np.array(classes)
+ is_things = np.array([1 for _ in classes])
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
+ instances.is_things = torch.tensor(is_things, dtype=torch.int64)
+
+ dataset_dict["instances"] = instances
+
+
+ spatial_query_utils = self.shape_sampler(instances)
+ dataset_dict['spatial_query'] = spatial_query_utils
+
+ if self.retrieval:
+ captions = dataset_dict['captions']
+ tokens = self.tokenizer(
+ captions, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
+ )
+ dataset_dict['tokens'] = {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}
+
+ if self.grounding:
+ grounding_anno = dataset_dict['grounding_info']
+ grounding_len = random.randint(1, self.max_grounding_num-1)
+ if len(grounding_anno) > 0:
+ masks_grd = []
+ texts_grd = []
+ mode = 'text'
+ random.shuffle(grounding_anno)
+ for ann in grounding_anno:
+ if 'segmentation' in ann:
+ if len(ann['segmentation']) == 0:
+ print('Empty segmentation!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
+ continue
+ rle = coco_mask.frPyObjects(
+ ann['segmentation'], dataset_dict['height'], dataset_dict['width'])
+ m = coco_mask.decode(rle)
+ # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = np.sum(m, axis=2)
+ else:
+ # directly read from mask file
+ while True:
+ try:
+ m = utils.read_image(ann["mask_file"], format=self.img_format)
+ break
+ except:
+ print('Image loading error:', ann["mask_file"])
+ m = np.sum(m, axis=2)
+ m = 1 * (m > 0)
+
+ m = m.astype(np.uint8) # convert to np.uint8
+ m = transforms.apply_segmentation(m[:,:,None])[:,:,0]
+ if rotate_time > 0:
+ m = np.rot90(m, rotate_time)
+ masks_grd += [m]
+ # random select a sentence of a single annotation.
+ rand_index = random.randint(0, len(ann['sentences'])-1)
+ texts_grd += [ann['sentences'][rand_index]['raw'].lower()]
+ # max_len = min(grounding_len, len(texts_grd))
+ max_len = len(masks_grd)
+ indices = np.random.permutation(max_len)
+ texts_grd = list(np.array(texts_grd)[indices])
+ masks_grd = torch.tensor(np.stack(masks_grd)[indices])
+ hash_grd = np.array([hash(txt) for txt in texts_grd])
+ else:
+ masks_grd = instances.gt_masks.tensor
+ mode = 'class'
+ if len(masks_grd) == 0:
+ masks_grd = torch.tensor([])
+ texts_grd = ['none']
+ hash_grd = np.array([hash(txt) for txt in texts_grd])
+ else:
+ biomed_classes = ['liver', 'lung', 'kidney', 'pancreas', 'heart anatomies', 'brain anatomies',
+ 'eye anatomies', 'vessel', 'other organ', 'tumor', 'infection', 'other lesion',
+ 'fluid disturbance', 'other abnormality', 'histology structure', 'other']
+ if self.binary_classes:
+ biomed_classes = ['target']
+ texts_grd = np.array(biomed_classes)
+ hash_grd = np.array([hash(txt) for txt in texts_grd])
+ unique_hash_grd = np.unique(hash_grd)
+ np.random.shuffle(unique_hash_grd)
+ max_len = min(grounding_len, len(unique_hash_grd))
+ indices = np.random.permutation(max_len)
+ selected_unique_hash_grd = unique_hash_grd[indices]
+ selected_mask = np.in1d(hash_grd, selected_unique_hash_grd)
+ texts_grd = texts_grd[selected_mask]
+ hash_grd = hash_grd[selected_mask]
+ masks_grd = masks_grd[selected_mask]
+ texts_grd = [prompt_engineering(text.replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
+ for text in texts_grd]
+ groundings = {'masks': masks_grd, 'texts': texts_grd, 'mode': mode, 'hash': hash_grd}
+ dataset_dict["groundings"] = groundings
+ assert len(masks_grd) == len(dataset_dict['grounding_info']), f"len(masks_grd)={len(masks_grd)}, len(dataset_dict['grounding_info'])={len(dataset_dict['grounding_info'])}, mask shape={masks_grd.shape}, max_len={max_len}, grounding_len={grounding_len}, len(texts_grd)={len(texts_grd)}, len(hash_grd)={len(hash_grd)}"
+ # gt_masks_orisize = torch.stack([torch.from_numpy(m.squeeze(-1)) for m in masks_orig])
+ # dataset_dict['gt_masks_orisize'] = gt_masks_orisize # (nm,h,w)
+
+ return dataset_dict
diff --git a/datasets/evaluation/__init__.py b/datasets/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d64b47a36d60c171cbcc1bd2d92babc3e25e25
--- /dev/null
+++ b/datasets/evaluation/__init__.py
@@ -0,0 +1,8 @@
+from .instance_evaluation import *
+from .classification_evaluation import *
+from .segmentation_evaluation import *
+from .retrieval_evaluation import *
+#from .captioning_evaluation import *
+from .panoptic_evaluation import *
+from .grounding_evaluation import *
+from .interactive_evaluation import *
\ No newline at end of file
diff --git a/datasets/evaluation/captioning_evaluation.py b/datasets/evaluation/captioning_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e941948e2f23fababc41dcec8bdf6777b9cd676
--- /dev/null
+++ b/datasets/evaluation/captioning_evaluation.py
@@ -0,0 +1,129 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import os
+import json
+import logging
+import itertools
+
+import detectron2.utils.comm as comm
+from detectron2.evaluation.evaluator import DatasetEvaluator
+
+from caption_pycocotools.coco import COCO
+from pycocoevalcap.eval import COCOEvalCap
+
+
+class CaptioningEvaluator(DatasetEvaluator):
+ """
+ Evaluate AR for object proposals, AP for instance detection/segmentation, AP
+ for keypoint detection outputs using COCO's metrics.
+ See http://cocodataset.org/#detection-eval and
+ http://cocodataset.org/#keypoints-eval to understand its metrics.
+ The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
+ the metric cannot be computed (e.g. due to no predictions made).
+ In addition to COCO, this evaluator is able to support any bounding box detection,
+ instance segmentation, or keypoint detection dataset.
+ """
+
+ def __init__(
+ self,
+ distributed=True,
+ output_dir=None,
+ gt_json=None,
+ ):
+ """
+ Args:
+ dataset_name (str): name of the dataset to be evaluated.
+ It must have either the following corresponding metadata:
+ "json_file": the path to the COCO format annotation
+ Or it must be in detectron2's standard dataset format
+ so it can be converted to COCO format automatically.
+ tasks (tuple[str]): tasks that can be evaluated under the given
+ configuration. A task is one of "bbox", "segm", "keypoints".
+ By default, will infer this automatically from predictions.
+ distributed (True): if True, will collect results from all ranks and run evaluation
+ in the main process.
+ Otherwise, will only evaluate the results in the current process.
+ output_dir (str): optional, an output directory to dump all
+ results predicted on the dataset. The dump contains two files:
+ 1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
+ contains all the results in the format they are produced by the model.
+ 2. "coco_instances_results.json" a json file in COCO's result format.
+ max_dets_per_image (int): limit on the maximum number of detections per image.
+ By default in COCO, this limit is to 100, but this can be customized
+ to be greater, as is needed in evaluation metrics AP fixed and AP pool
+ (see https://arxiv.org/pdf/2102.01066.pdf)
+ This doesn't affect keypoint evaluation.
+ use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
+ Although the results should be very close to the official implementation in COCO
+ API, it is still recommended to compute results with the official API for use in
+ papers. The faster implementation also uses more RAM.
+ kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
+ See http://cocodataset.org/#keypoints-eval
+ When empty, it will use the defaults in COCO.
+ Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
+ allow_cached_coco (bool): Whether to use cached coco json from previous validation
+ runs. You should set this to False if you need to use different validation data.
+ Defaults to True.
+ """
+ self._logger = logging.getLogger(__name__)
+ self._distributed = distributed
+ self._output_dir = output_dir
+ self._gt_json = COCO(gt_json)
+
+ def reset(self):
+ self._gen_captions = []
+ self._image_ids = []
+
+ def process(self, inputs, outputs):
+ """
+ Args:
+ inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
+ It is a list of dict. Each dict corresponds to an image and
+ contains keys like "height", "width", "file_name", "image_id".
+ outputs: the outputs of a COCO model. It is a list of dicts with key
+ "instances" that contains :class:`Instances`.
+ """
+ for output in outputs:
+ self._image_ids.append(output['image_id'])
+ self._gen_captions.append(output['captioning_text'])
+
+ def evaluate(self, img_ids=None):
+ """
+ Args:
+ img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
+ """
+
+ if self._distributed:
+ comm.synchronize()
+ def gather(x, move=False):
+ x = comm.gather(x)
+ x = list(itertools.chain(*x))
+ if move:
+ x = [xx.to(self._gen_captions[0].device) for xx in x]
+ return x
+ gen_captions = gather(self._gen_captions)
+ image_ids = gather(self._image_ids)
+ if not comm.is_main_process():
+ return {}
+ else:
+ gen_captions = self._gen_captions
+ image_ids = self._image_ids
+
+ assert len(gen_captions) == len(image_ids)
+ pred_captions = [{"image_id": image_id, "caption": gen_caption} for image_id, gen_caption in zip(image_ids, gen_captions)]
+ pred_pth = os.path.join(self._output_dir, 'results.json')
+ json.dump(pred_captions, open(pred_pth, "w"))
+
+ gt_captions = self._gt_json
+ pred_captions = gt_captions.loadRes(pred_pth)
+
+ cocoEval = COCOEvalCap(gt_captions, pred_captions)
+ cocoEval.params['image_id'] = pred_captions.getImgIds()
+ cocoEval.evaluate()
+ return cocoEval.eval
\ No newline at end of file
diff --git a/datasets/evaluation/classification_evaluation.py b/datasets/evaluation/classification_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..db9e2d973c443dcc441a2db86a1639b30860e470
--- /dev/null
+++ b/datasets/evaluation/classification_evaluation.py
@@ -0,0 +1,76 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import torch
+import logging
+
+from detectron2.evaluation.evaluator import DatasetEvaluator
+
+from utilities.misc import AverageMeter
+from utilities.distributed import get_world_size
+
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ if isinstance(output, list):
+ output = output[-1]
+
+ n_classes = output.size()[1]
+ maxk = min(max(topk), n_classes)
+ batch_size = target.size(0)
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / batch_size).item())
+ return res
+
+class ClassificationEvaluator(DatasetEvaluator):
+ def __init__(self, *args):
+ self.top1 = AverageMeter()
+ self.top5 = AverageMeter()
+ self._logger = logging.getLogger(__name__)
+
+ def reset(self):
+ self.top1.reset()
+ self.top5.reset()
+
+ def process(self, inputs, outputs):
+ logits = torch.stack([o['pred_class'] for o in outputs])
+ y = torch.tensor([t['class_id'] for t in inputs], device=logits.device)
+ prec1, prec5 = accuracy(logits, y, (1, 5))
+ self.top1.update(prec1, y.size(0))
+ self.top5.update(prec5, y.size(0))
+
+ def evaluate(self):
+ if get_world_size() > 1:
+ tmp_tensor = torch.tensor(
+ [self.top1.sum, self.top5.sum, self.top1.count],
+ device=torch.cuda.current_device()
+ )
+ torch.distributed.all_reduce(
+ tmp_tensor, torch.distributed.ReduceOp.SUM
+ )
+ top1_sum, top5_sum, count = tmp_tensor.tolist()
+ else:
+ top1_sum = self.top1.sum
+ top5_sum = self.top5.sum
+ count = self.top1.count
+
+ results = {}
+ scores = {
+ 'top1': top1_sum / count,
+ "top5": top5_sum / count
+ }
+ results['class'] = scores
+ self._logger.info(results)
+ return results
diff --git a/datasets/evaluation/grounding_evaluation.py b/datasets/evaluation/grounding_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..345d4f8779938752de8e9ecfba75be5bf5fe0a53
--- /dev/null
+++ b/datasets/evaluation/grounding_evaluation.py
@@ -0,0 +1,173 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+import logging
+import torch
+from torchvision.ops import box_iou
+
+from detectron2.structures import BoxMode
+from detectron2.data import MetadataCatalog
+from detectron2.utils.comm import all_gather, is_main_process, synchronize
+from detectron2.evaluation.evaluator import DatasetEvaluator
+
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+import copy
+
+class GroundingEvaluator(DatasetEvaluator):
+ """
+ Evaluate grounding segmentation metrics.
+ """
+
+ def __init__(
+ self,
+ dataset_name,
+ compute_box=False,
+ distributed=True,
+ ):
+ self._logger = logging.getLogger(__name__)
+ self._dataset_name = dataset_name
+ self._distributed = distributed
+ self._cpu_device = torch.device("cpu")
+ self._compute_box = compute_box
+ meta = MetadataCatalog.get(dataset_name)
+
+ def reset(self):
+ self.cum_I = 0
+ self.cum_U = 0
+ self.mIoU = 0
+ self.mDice = 0
+ self.cum_mean_area = 0
+ self.eval_seg_iou_list = [.5, .6, .7, .8, .9]
+ self.seg_correct = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
+ self.seg_total = 0
+ self.instance_results = []
+ if self._compute_box:
+ self.mIoU_box = 0
+ self.seg_correct_box = torch.zeros(len(self.eval_seg_iou_list), device=self._cpu_device)
+
+ @staticmethod
+ def computeIoU(pred_seg, gd_seg):
+ I = (pred_seg & gd_seg)
+ U = (pred_seg | gd_seg)
+ return I, U
+
+ def get_metadata(self, _input):
+ """
+ Extracts and returns specific metadata from the input dictionary.
+
+ Parameters:
+ _input (dict): A dictionary containing keys like 'file_name', 'image_id', and 'grounding_info'.
+ The 'grounding_info' is a list of dictionaries with keys like 'area', 'iscrowd', etc.
+
+ Returns:
+ dict: A dictionary containing filtered metadata.
+ """
+
+ _input = copy.deepcopy(_input)
+
+ selected_input_keys = ['file_name', 'image_id', 'grounding_info']
+ selected_grounding_info_keys = ['area', 'mask_file', 'iscrowd', 'image_id', 'category_id', 'id', 'file_name', 'split', 'ann_id', 'ref_id']
+
+ filtered_input = {key: _input[key] for key in selected_input_keys if key in _input}
+
+ # Check if grounding_info is present and is a list
+ if 'grounding_info' in filtered_input and isinstance(filtered_input['grounding_info'], list):
+ # Filter each grounding_info dictionary
+ filtered_input['grounding_info'] = [
+ {key: info[key] for key in selected_grounding_info_keys if key in info}
+ for info in filtered_input['grounding_info']
+ ]
+
+ return filtered_input
+
+ def process(self, inputs, outputs):
+ for input, output in zip(inputs, outputs):
+ pred = output['grounding_mask'].sigmoid() > 0.5
+ # # save pixel probability
+ # prob = output['grounding_mask'].sigmoid().cpu().numpy()[0] * 255
+ # pred_file = input['file_name'].split('.')[0].replace('test/', 'test_pred/') + '_' + input['groundings']['texts'][0].replace(' ', '+') + '.png'
+ # if not os.path.exists('/'.join(pred_file.split('/')[:-1])):
+ # os.makedirs('/'.join(pred_file.split('/')[:-1]), exist_ok=True)
+ # plt.imsave(pred_file,
+ # prob.astype(np.uint8), cmap='gray')
+
+ gt = input['groundings']['masks'].bool()
+ bsi = len(pred)
+ I, U = self.computeIoU(pred, gt)
+ self.cum_I += I.sum().cpu()
+ self.cum_U += U.sum().cpu()
+ IoU = I.reshape(bsi,-1).sum(-1)*1.0 / (U.reshape(bsi,-1).sum(-1) + 1e-6)
+ self.mIoU += IoU.sum().cpu()
+ # Add Dice score in eval
+ Dice = I.reshape(bsi,-1).sum(-1)*2.0 / (gt.reshape(bsi,-1).sum(-1) + pred.reshape(bsi,-1).sum(-1) + 1e-6)
+ self.mDice += Dice.sum().cpu()
+ self.cum_mean_area += ((gt.reshape(bsi,-1).sum(-1) + pred.reshape(bsi,-1).sum(-1)) / 2.0).sum().cpu()
+
+ if self._compute_box:
+ pred_box = BoxMode.convert(output['grounding_box'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
+ gt_box = BoxMode.convert(input['groundings']['boxes'], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS).cpu()
+ IoU_box = box_iou(pred_box, gt_box).diagonal()
+ self.mIoU_box += IoU_box.sum()
+
+ for idx in range(len(self.eval_seg_iou_list)):
+ eval_seg_iou = self.eval_seg_iou_list[idx]
+ self.seg_correct[idx] += (IoU >= eval_seg_iou).sum().cpu()
+ if self._compute_box:
+ self.seg_correct_box[idx] += (IoU_box >= eval_seg_iou).sum().cpu()
+ self.seg_total += bsi
+
+ instance_result = {
+ 'metadata': self.get_metadata(input),
+ 'IoU': IoU.cpu().numpy().tolist(),
+ 'Dice': Dice.cpu().numpy().tolist(),
+ 'I': I.sum(dim=(1, 2)).cpu().numpy().tolist(),
+ 'U': U.sum(dim=(1, 2)).cpu().numpy().tolist(),
+ 'IoU_box': IoU_box.cpu().numpy().tolist() if self._compute_box else '',
+ 'pred_area': pred.reshape(bsi,-1).sum(-1).cpu().numpy().tolist(),
+ }
+
+ iou_len = IoU.shape[0]
+ grounding_info_len = len(self.get_metadata(input)['grounding_info'])
+ assert iou_len == grounding_info_len, f'Number of IoU scores ({iou_len}) and grounding info ({grounding_info_len}) do not match.'
+ self.instance_results.append(instance_result)
+
+ def evaluate(self):
+ if self._distributed:
+ synchronize()
+ self.cum_I = torch.stack(all_gather(self.cum_I)).sum()
+ self.cum_U = torch.stack(all_gather(self.cum_U)).sum()
+ self.mIoU = torch.stack(all_gather(self.mIoU)).sum()
+ self.mDice = torch.stack(all_gather(self.mDice)).sum()
+ self.cum_mean_area = torch.stack(all_gather(self.cum_mean_area)).sum()
+ self.seg_correct = torch.stack(all_gather(self.seg_correct)).sum(0)
+ self.seg_total = sum(all_gather(self.seg_total))
+ self.instance_results = sum(all_gather(self.instance_results), [])
+ if self._compute_box:
+ self.mIoU_box = torch.stack(all_gather(self.mIoU_box)).sum()
+ self.seg_correct_box = torch.stack(all_gather(self.seg_correct_box)).sum(0)
+ if not is_main_process():
+ return
+
+ results = {}
+ for idx in range(len(self.eval_seg_iou_list)):
+ result_str = 'precision@{}'.format(self.eval_seg_iou_list[idx])
+ results[result_str] = (self.seg_correct[idx]*100 / self.seg_total).item()
+ results['cIoU'] = (self.cum_I*100./self.cum_U).item()
+ results['mIoU'] = (self.mIoU*100./self.seg_total).item()
+ results['cDice'] = (self.cum_I*100./self.cum_mean_area).item()
+ results['mDice'] = (self.mDice*100./self.seg_total).item()
+
+ if self._compute_box:
+ for idx in range(len(self.eval_seg_iou_list)):
+ result_str = 'precisionB@{}'.format(self.eval_seg_iou_list[idx])
+ results[result_str] = (self.seg_correct_box[idx]*100 / self.seg_total).item()
+ results['mBIoU'] = (self.mIoU_box*100./self.seg_total).item()
+
+ self._logger.info(results)
+ return {'grounding': {'scores': results, 'instance_results': self.instance_results}}
\ No newline at end of file
diff --git a/datasets/evaluation/instance_evaluation.py b/datasets/evaluation/instance_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc2facec351e5f6ee965ee9acb4394f12c023f54
--- /dev/null
+++ b/datasets/evaluation/instance_evaluation.py
@@ -0,0 +1,107 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import contextlib
+import copy
+import io
+import itertools
+import json
+import logging
+import numpy as np
+import os
+import pickle
+from collections import OrderedDict
+import pycocotools.mask as mask_util
+import torch
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from tabulate import tabulate
+
+import detectron2.utils.comm as comm
+from detectron2.config import CfgNode
+from detectron2.data import MetadataCatalog
+from detectron2.data.datasets.coco import convert_to_coco_json
+from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
+from detectron2.evaluation.fast_eval_api import COCOeval_opt
+from detectron2.structures import Boxes, BoxMode, pairwise_iou
+from detectron2.utils.file_io import PathManager
+from detectron2.utils.logger import create_small_table
+
+
+# modified from COCOEvaluator for instance segmetnat
+class InstanceSegEvaluator(COCOEvaluator):
+ """
+ Evaluate AR for object proposals, AP for instance detection/segmentation, AP
+ for keypoint detection outputs using COCO's metrics.
+ See http://cocodataset.org/#detection-eval and
+ http://cocodataset.org/#keypoints-eval to understand its metrics.
+ The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
+ the metric cannot be computed (e.g. due to no predictions made).
+
+ In addition to COCO, this evaluator is able to support any bounding box detection,
+ instance segmentation, or keypoint detection dataset.
+ """
+
+ def _eval_predictions(self, predictions, img_ids=None):
+ """
+ Evaluate predictions. Fill self._results with the metrics of the tasks.
+ """
+ self._logger.info("Preparing results for COCO format ...")
+ coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
+ tasks = self._tasks or self._tasks_from_predictions(coco_results)
+
+ # unmap the category ids for COCO
+ if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
+ dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
+ # all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
+ # num_classes = len(all_contiguous_ids)
+ # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
+
+ reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
+ for result in coco_results:
+ category_id = result["category_id"]
+ # assert category_id < num_classes, (
+ # f"A prediction has class={category_id}, "
+ # f"but the dataset only has {num_classes} classes and "
+ # f"predicted class id should be in [0, {num_classes - 1}]."
+ # )
+ assert category_id in reverse_id_mapping, (
+ f"A prediction has class={category_id}, "
+ f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
+ )
+ result["category_id"] = reverse_id_mapping[category_id]
+
+ if self._output_dir:
+ file_path = os.path.join(self._output_dir, "coco_instances_results.json")
+ self._logger.info("Saving results to {}".format(file_path))
+ with PathManager.open(file_path, "w") as f:
+ f.write(json.dumps(coco_results))
+ f.flush()
+
+ if not self._do_evaluation:
+ self._logger.info("Annotations are not available for evaluation.")
+ return
+
+ self._logger.info(
+ "Evaluating predictions with {} COCO API...".format(
+ "unofficial" if self._use_fast_impl else "official"
+ )
+ )
+ for task in sorted(tasks):
+ assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
+ coco_eval = (
+ _evaluate_predictions_on_coco(
+ self._coco_api,
+ coco_results,
+ task,
+ kpt_oks_sigmas=self._kpt_oks_sigmas,
+ use_fast_impl=self._use_fast_impl,
+ img_ids=img_ids,
+ max_dets_per_image=self._max_dets_per_image,
+ )
+ if len(coco_results) > 0
+ else None # cocoapi does not handle empty results very well
+ )
+
+ res = self._derive_coco_results(
+ coco_eval, task, class_names=self._metadata.get("thing_classes")
+ )
+ self._results[task] = res
diff --git a/datasets/evaluation/interactive_evaluation.py b/datasets/evaluation/interactive_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa69f795923ac6e50def4c5b4423505326cc0b9d
--- /dev/null
+++ b/datasets/evaluation/interactive_evaluation.py
@@ -0,0 +1,122 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import os
+
+import numpy as np
+import torch
+from torchvision.ops import box_iou
+
+from detectron2.structures import BoxMode
+from detectron2.data import MetadataCatalog
+from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize
+from detectron2.evaluation.evaluator import DatasetEvaluator
+
+
+class InteractiveEvaluator(DatasetEvaluator):
+ """
+ Evaluate point interactive IoU metrics.
+ """
+
+ def __init__(
+ self,
+ dataset_name,
+ output_dir,
+ max_clicks=20,
+ iou_iter=1,
+ compute_box=False,
+ distributed=True,
+ ):
+ self._logger = logging.getLogger(__name__)
+ self._dataset_name = dataset_name
+ self._distributed = distributed
+ self._cpu_device = torch.device("cpu")
+ self._output_dir = output_dir
+
+ self.max_clicks = max_clicks
+ self.iou_iter = iou_iter
+ meta = MetadataCatalog.get(dataset_name)
+
+ def reset(self):
+ self.iou_list = []
+ self.num_samples = 0
+ self.all_ious = [0.5, 0.8, 0.85, 0.9]
+
+ def process(self, inputs, outputs):
+ self.iou_list += [o['mask_iou'] for o in outputs]
+ self.num_samples += len(outputs)
+
+ def compute_noc(self):
+ def _get_noc(iou_arr, iou_thr):
+ vals = iou_arr >= iou_thr
+ return vals.max(dim=0)[1].item() + 1 if vals.any() else self.max_clicks
+
+ noc_list = {}
+ for iou_thr in self.all_ious:
+ scores_arr = [_get_noc(iou_arr, iou_thr) for iou_arr in self.iou_list]
+ noc_list[str(iou_thr)] = scores_arr
+
+ iou_before_max_iter = torch.stack(self.iou_list)[:,self.iou_iter-1]
+ noc_list_sum = {key:sum(value)*1.0 for key, value in noc_list.items()}
+
+ if self._distributed:
+ num_samples = sum(all_gather(self.num_samples))
+ noc_list_sum_gather = all_gather(noc_list_sum)
+ iou_before_max_gather = all_gather(iou_before_max_iter.sum().cpu())
+
+ noc_list_sum = {key: 0 for key in noc_list_sum_gather[0]}
+ for nlg in noc_list_sum_gather:
+ for key, value in nlg.items():
+ noc_list_sum[key] += value
+
+ pred_noc = {}
+ if self._distributed and (not is_main_process()):
+ return pred_noc
+
+ for key, value in noc_list_sum.items():
+ pred_noc[key] = value / num_samples
+
+ pred_noc['iou_max_iter'] = sum([x.item() for x in iou_before_max_gather]) / num_samples
+ return pred_noc
+
+ def evaluate(self):
+ pred_noc = self.compute_noc()
+
+ if self._distributed and (not is_main_process()):
+ return
+
+ def draw_iou_curve(iou_list, save_dir):
+ iou_list = torch.stack(iou_list, dim=0)
+ iou_list = iou_list.mean(dim=0).cpu().numpy()
+ # draw iou curve, with x-axis as number of clicks, y-axis as iou using matplotlib
+ import matplotlib.pyplot as plt
+ plt.figure()
+ plt.plot(range(1, self.max_clicks+1), iou_list)
+ plt.xlabel('Number of clicks')
+ plt.ylabel('IoU')
+
+
+ # create directory if not exist
+ import os
+ output_dir = os.path.join(save_dir, 'iou_by_clicks')
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ # get current time and format in 10 digits
+ import time
+ current_time = time.time()
+ current_time = int(current_time)
+ current_time = str(current_time)
+
+ # save iou curve
+ plt.savefig(os.path.join(output_dir, '{}.png'.format(current_time)))
+
+ draw_iou_curve(self.iou_list, self._output_dir)
+ results = {}
+ for idx in range(len(self.all_ious)):
+ result_str = 'noc@{}'.format(self.all_ious[idx])
+ results[result_str] = pred_noc[str(self.all_ious[idx])]
+
+ results['miou@iter{}'.format(self.iou_iter)] = pred_noc['iou_max_iter']
+
+ self._logger.info(results)
+ return {'interactive': results}
\ No newline at end of file
diff --git a/datasets/evaluation/panoptic_evaluation.py b/datasets/evaluation/panoptic_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba419a69447122d3f8c5c08ad5c048dc22b1e984
--- /dev/null
+++ b/datasets/evaluation/panoptic_evaluation.py
@@ -0,0 +1,199 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import contextlib
+import io
+import itertools
+import json
+import logging
+import numpy as np
+import os
+import tempfile
+from collections import OrderedDict
+from typing import Optional
+from PIL import Image
+from tabulate import tabulate
+
+from detectron2.data import MetadataCatalog
+from detectron2.utils import comm
+from detectron2.utils.file_io import PathManager
+
+from detectron2.evaluation.evaluator import DatasetEvaluator
+
+logger = logging.getLogger(__name__)
+
+
+class COCOPanopticEvaluator(DatasetEvaluator):
+ """
+ Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
+ It saves panoptic segmentation prediction in `output_dir`
+
+ It contains a synchronize call and has to be called from all workers.
+ """
+
+ def __init__(self, dataset_name: str, output_dir: Optional[str] = None):
+ """
+ Args:
+ dataset_name: name of the dataset
+ output_dir: output directory to save results for evaluation.
+ """
+ self._metadata = MetadataCatalog.get(dataset_name)
+ self._thing_contiguous_id_to_dataset_id = {
+ v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
+ }
+ self._stuff_contiguous_id_to_dataset_id = {
+ v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items()
+ }
+
+ self._output_dir = output_dir
+ if self._output_dir is not None:
+ PathManager.mkdirs(self._output_dir)
+
+ def reset(self):
+ self._predictions = []
+
+ def _convert_category_id(self, segment_info):
+ isthing = segment_info.pop("isthing", None)
+ if isthing is None:
+ # the model produces panoptic category id directly. No more conversion needed
+ return segment_info
+ if isthing is True:
+ segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[
+ segment_info["category_id"]
+ ]
+ else:
+ segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[
+ segment_info["category_id"]
+ ]
+ return segment_info
+
+ def process(self, inputs, outputs):
+ from panopticapi.utils import id2rgb
+
+ for input, output in zip(inputs, outputs):
+ panoptic_img, segments_info = output["panoptic_seg"]
+ panoptic_img = panoptic_img.cpu().numpy()
+ if segments_info is None:
+ # If "segments_info" is None, we assume "panoptic_img" is a
+ # H*W int32 image storing the panoptic_id in the format of
+ # category_id * label_divisor + instance_id. We reserve -1 for
+ # VOID label, and add 1 to panoptic_img since the official
+ # evaluation script uses 0 for VOID label.
+ label_divisor = self._metadata.label_divisor
+ segments_info = []
+ for panoptic_label in np.unique(panoptic_img):
+ if panoptic_label == -1:
+ # VOID region.
+ continue
+ pred_class = panoptic_label // label_divisor
+ isthing = (
+ pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
+ )
+ segments_info.append(
+ {
+ "id": int(panoptic_label) + 1,
+ "category_id": int(pred_class),
+ "isthing": bool(isthing),
+ }
+ )
+ # Official evaluation script uses 0 for VOID label.
+ panoptic_img += 1
+
+ file_name = os.path.basename(input["file_name"])
+ file_name_png = os.path.splitext(file_name)[0] + ".png"
+ with io.BytesIO() as out:
+ Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
+ segments_info = [self._convert_category_id(x) for x in segments_info]
+ self._predictions.append(
+ {
+ "image_id": input["image_id"],
+ "file_name": file_name_png,
+ "png_string": out.getvalue(),
+ "segments_info": segments_info,
+ }
+ )
+
+ def evaluate(self):
+ comm.synchronize()
+
+ self._predictions = comm.gather(self._predictions)
+ self._predictions = list(itertools.chain(*self._predictions))
+ if not comm.is_main_process():
+ return
+
+ # PanopticApi requires local files
+ gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
+ gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)
+
+ with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
+ logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
+ for p in self._predictions:
+ with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
+ f.write(p.pop("png_string"))
+
+ with open(gt_json, "r") as f:
+ json_data = json.load(f)
+ json_data["annotations"] = self._predictions
+
+ output_dir = self._output_dir or pred_dir
+ predictions_json = os.path.join(output_dir, "predictions.json")
+ with PathManager.open(predictions_json, "w") as f:
+ f.write(json.dumps(json_data))
+
+ from panopticapi.evaluation import pq_compute
+
+ with contextlib.redirect_stdout(io.StringIO()):
+ pq_res = pq_compute(
+ gt_json,
+ PathManager.get_local_path(predictions_json),
+ gt_folder=gt_folder,
+ pred_folder=pred_dir,
+ )
+
+ res = {}
+ res["PQ"] = 100 * pq_res["All"]["pq"]
+ res["SQ"] = 100 * pq_res["All"]["sq"]
+ res["RQ"] = 100 * pq_res["All"]["rq"]
+ res["PQ_th"] = 100 * pq_res["Things"]["pq"]
+ res["SQ_th"] = 100 * pq_res["Things"]["sq"]
+ res["RQ_th"] = 100 * pq_res["Things"]["rq"]
+ res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
+ res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
+ res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]
+
+ results = OrderedDict({"panoptic_seg": res})
+ _print_panoptic_results(pq_res)
+
+ return results
+
+
+def _print_panoptic_results(pq_res):
+ headers = ["", "PQ", "SQ", "RQ", "#categories"]
+ data = []
+ for name in ["All", "Things", "Stuff"]:
+ row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
+ data.append(row)
+ table = tabulate(
+ data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
+ )
+ logger.info("Panoptic Evaluation Results:\n" + table)
+
+
+if __name__ == "__main__":
+ from detectron2.utils.logger import setup_logger
+
+ logger = setup_logger()
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--gt-json")
+ parser.add_argument("--gt-dir")
+ parser.add_argument("--pred-json")
+ parser.add_argument("--pred-dir")
+ args = parser.parse_args()
+
+ from panopticapi.evaluation import pq_compute
+
+ with contextlib.redirect_stdout(io.StringIO()):
+ pq_res = pq_compute(
+ args.gt_json, args.pred_json, gt_folder=args.gt_dir, pred_folder=args.pred_dir
+ )
+ _print_panoptic_results(pq_res)
diff --git a/datasets/evaluation/retrieval_evaluation.py b/datasets/evaluation/retrieval_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b290c94bfc1f8fba587cea5024cf1763c797c6f
--- /dev/null
+++ b/datasets/evaluation/retrieval_evaluation.py
@@ -0,0 +1,260 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu), Ziyi Dou (zdou@cs.ucla.edu)
+# --------------------------------------------------------
+import copy
+import itertools
+import logging
+from collections import OrderedDict
+import torch
+from pycocotools.cocoeval import COCOeval
+
+import detectron2.utils.comm as comm
+from detectron2.evaluation.evaluator import DatasetEvaluator
+
+try:
+ from detectron2.evaluation.fast_eval_api import COCOeval_opt
+except ImportError:
+ COCOeval_opt = COCOeval
+
+
+class RetrievalEvaluator(DatasetEvaluator):
+ """
+ Evaluate AR for object proposals, AP for instance detection/segmentation, AP
+ for keypoint detection outputs using COCO's metrics.
+ See http://cocodataset.org/#detection-eval and
+ http://cocodataset.org/#keypoints-eval to understand its metrics.
+ The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
+ the metric cannot be computed (e.g. due to no predictions made).
+ In addition to COCO, this evaluator is able to support any bounding box detection,
+ instance segmentation, or keypoint detection dataset.
+ """
+
+ def __init__(
+ self,
+ dataset_name=None,
+ output_dir=None,
+ ensemble=False,
+ distributed=True,
+ ):
+ """
+ Args:
+ dataset_name (str): name of the dataset to be evaluated.
+ It must have either the following corresponding metadata:
+ "json_file": the path to the COCO format annotation
+ Or it must be in detectron2's standard dataset format
+ so it can be converted to COCO format automatically.
+ tasks (tuple[str]): tasks that can be evaluated under the given
+ configuration. A task is one of "bbox", "segm", "keypoints".
+ By default, will infer this automatically from predictions.
+ distributed (True): if True, will collect results from all ranks and run evaluation
+ in the main process.
+ Otherwise, will only evaluate the results in the current process.
+ output_dir (str): optional, an output directory to dump all
+ results predicted on the dataset. The dump contains two files:
+ 1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
+ contains all the results in the format they are produced by the model.
+ 2. "coco_instances_results.json" a json file in COCO's result format.
+ max_dets_per_image (int): limit on the maximum number of detections per image.
+ By default in COCO, this limit is to 100, but this can be customized
+ to be greater, as is needed in evaluation metrics AP fixed and AP pool
+ (see https://arxiv.org/pdf/2102.01066.pdf)
+ This doesn't affect keypoint evaluation.
+ use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
+ Although the results should be very close to the official implementation in COCO
+ API, it is still recommended to compute results with the official API for use in
+ papers. The faster implementation also uses more RAM.
+ kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
+ See http://cocodataset.org/#keypoints-eval
+ When empty, it will use the defaults in COCO.
+ Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
+ allow_cached_coco (bool): Whether to use cached coco json from previous validation
+ runs. You should set this to False if you need to use different validation data.
+ Defaults to True.
+ """
+ self._logger = logging.getLogger(__name__)
+ self._dataset_name = dataset_name
+ self._output_dir = output_dir
+ self._ensemble = ensemble
+ self._distributed = distributed
+
+ if 'p2i' in dataset_name:
+ self.mode = 'patch2image'
+ elif 'interactive2i' in dataset_name:
+ self.mode = 'interactive2image'
+ else:
+ self.mode = 'default'
+
+ def reset(self):
+ self._text_embeds = []
+ self._image_embeds = []
+ self._image_embeds2 = []
+ self._text_ids = []
+ self._image_ids = []
+
+ def process(self, inputs, outputs):
+ """
+ Args:
+ inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
+ It is a list of dict. Each dict corresponds to an image and
+ contains keys like "height", "width", "file_name", "image_id".
+ outputs: the outputs of a COCO model. It is a list of dicts with key
+ "instances" that contains :class:`Instances`.
+ """
+ for output in outputs:
+ self._text_ids.extend(output['caption']['caption_ids'])
+ self._image_ids.append(output['caption']['image_ids'])
+ self._text_embeds.append(output['caption']['text_embeds'])
+ self._image_embeds.append(output['caption']['image_embeds'][0])
+ if self._ensemble:
+ self._image_embeds2.append(output['caption']['image_embeds'][1])
+
+ def evaluate(self, img_ids=None):
+ if self.mode == 'default':
+ return self.evaluate_default(img_ids)
+ elif self.mode in ['patch2image', 'interactive2image']:
+ return self.evaluate_p2i(img_ids)
+ else:
+ assert False, "Unknown mode for retrieval evaluation"
+
+ def evaluate_default(self, img_ids=None):
+ """
+ Args:
+ img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
+ """
+
+ if self._distributed:
+ comm.synchronize()
+ def gather(x, move=False):
+ x = comm.gather(x)
+ x = list(itertools.chain(*x))
+ if move:
+ x = [xx.to(self._text_embeds[0].device) for xx in x]
+ return x
+ text_embeds = gather(self._text_embeds, move=True)
+ image_embeds = gather(self._image_embeds, move=True)
+ if self._ensemble:
+ image_embeds2 = gather(self._image_embeds2, move=True)
+ text_ids = gather(self._text_ids)
+ image_ids = gather(self._image_ids)
+ if not comm.is_main_process():
+ return {}
+ else:
+ text_embeds = self._text_embeds
+ image_embeds = self._image_embeds
+ if self._ensemble:
+ image_embeds2 = self._image_embeds2
+ text_ids = self._text_ids
+ image_ids = self._image_ids
+ if len(text_embeds) == 0:
+ self._logger.warning("[COCOCaptionEvaluator] Did not receive valid predictions.")
+ return {}
+ iids = torch.tensor(image_ids).view(-1)
+ tiids = torch.tensor(text_ids).view(-1)
+ image_embeds = torch.cat(image_embeds)
+ text_embeds = torch.cat(text_embeds)
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+ scores = image_embeds @ text_embeds.t()
+
+ if self._ensemble:
+ image_embeds2 = torch.cat(image_embeds2)
+ image_embeds2 = image_embeds2 / image_embeds2.norm(dim=-1, keepdim=True)
+ scores2 = image_embeds2 @ text_embeds.t()
+ scores = scores2 * 0.5 + scores * 0.5
+
+ topk10 = scores.topk(10, dim=1)
+ topk5 = scores.topk(5, dim=1)
+ topk1 = scores.topk(1, dim=1)
+ topk10_iids = tiids[topk10.indices]
+ topk5_iids = tiids[topk5.indices]
+ topk1_iids = tiids[topk1.indices]
+ tr_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
+ tr_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
+ tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
+ topk10 = scores.topk(10, dim=0)
+ topk5 = scores.topk(5, dim=0)
+ topk1 = scores.topk(1, dim=0)
+ topk10_iids = iids[topk10.indices]
+ topk5_iids = iids[topk5.indices]
+ topk1_iids = iids[topk1.indices]
+ ir_r10 = (tiids.unsqueeze(0) == topk10_iids).float().max(dim=0)[0].mean()
+ ir_r5 = (tiids.unsqueeze(0) == topk5_iids).float().max(dim=0)[0].mean()
+ ir_r1 = (tiids.unsqueeze(0) == topk1_iids).float().max(dim=0)[0].mean()
+ self._results = OrderedDict()
+ # Copy so the caller can do whatever with results
+ self._results['recall'] = {}
+ self._results['recall']['irtr'] = float("{:.3f}".format((ir_r1 + tr_r1).item() * 100))
+ self._results['recall']['ir1'] = float("{:.3f}".format(ir_r1.item() * 100))
+ self._results['recall']['ir5'] = float("{:.3f}".format(ir_r5.item() * 100))
+ self._results['recall']['ir10'] = float("{:.3f}".format(ir_r10.item() * 100))
+ self._results['recall']['tr1'] = float("{:.3f}".format(tr_r1.item() * 100))
+ self._results['recall']['tr5'] = float("{:.3f}".format(tr_r5.item() * 100))
+ self._results['recall']['tr10'] = float("{:.3f}".format(tr_r10.item() * 100))
+ self._logger.info(self._results)
+ return copy.deepcopy(self._results)
+
+ def evaluate_p2i(self, img_ids=None):
+ """
+ Args:
+ img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
+ """
+
+ if self._distributed:
+ comm.synchronize()
+ def gather(x, move=False):
+ x = comm.gather(x)
+ x = list(itertools.chain(*x))
+ if move:
+ x = [xx.to(self._text_embeds[0].device) for xx in x]
+ return x
+ text_embeds = gather(self._text_embeds, move=True)
+ image_embeds = gather(self._image_embeds, move=True)
+ image_embeds2 = gather(self._image_embeds2, move=True)
+ text_ids = gather(self._text_ids)
+ image_ids = gather(self._image_ids)
+ if not comm.is_main_process():
+ return {}
+ else:
+ text_embeds = self._text_embeds
+ image_embeds = self._image_embeds
+ image_embeds2 = self._image_embeds2
+ text_ids = self._text_ids
+ image_ids = self._image_ids
+
+ if len(text_embeds) == 0:
+ self._logger.warning("[COCOCaptionEvaluator] Did not receive valid predictions.")
+ return {}
+
+ iids = torch.tensor(image_ids).view(-1)
+ tiids = torch.tensor(text_ids).view(-1)
+ image_embeds = torch.cat(image_embeds)
+ text_embeds = torch.cat(text_embeds)
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+
+ image_embeds2 = torch.cat(image_embeds2)
+ image_embeds2 = image_embeds2 / image_embeds2.norm(dim=-1, keepdim=True)
+
+ # compute image to image retrieval
+ self._results = OrderedDict()
+ self._results['recall'] = {}
+ ii_scores = image_embeds2 @ image_embeds.t()
+
+ topk10 = ii_scores.topk(10, dim=1)
+ topk5 = ii_scores.topk(5, dim=1)
+ topk1 = ii_scores.topk(1, dim=1)
+ topk10_iids = iids[topk10.indices]
+ topk5_iids = iids[topk5.indices]
+ topk1_iids = iids[topk1.indices]
+ iir_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
+ iir_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
+ iir_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
+ # Copy so the caller can do whatever with results
+ self._results['recall']['p2ir1'] = float("{:.3f}".format(iir_r1.item() * 100))
+ self._results['recall']['p2ir5'] = float("{:.3f}".format(iir_r5.item() * 100))
+ self._results['recall']['p2ir10'] = float("{:.3f}".format(iir_r10.item() * 100))
+ self._logger.info(self._results)
+ return copy.deepcopy(self._results)
\ No newline at end of file
diff --git a/datasets/evaluation/segmentation_evaluation.py b/datasets/evaluation/segmentation_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..98a14ce385dc0969d582f1f58287168a2dd491b4
--- /dev/null
+++ b/datasets/evaluation/segmentation_evaluation.py
@@ -0,0 +1,195 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import itertools
+import json
+import logging
+import numpy as np
+import os
+from collections import OrderedDict
+import PIL.Image as Image
+import pycocotools.mask as mask_util
+import torch
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.utils.comm import all_gather, is_main_process
+from detectron2.utils.file_io import PathManager
+from detectron2.evaluation.evaluator import DatasetEvaluator
+from utilities.distributed import synchronize
+
+from ..semseg_loader import load_semseg
+
+
+class SemSegEvaluator(DatasetEvaluator):
+ """
+ Evaluate semantic segmentation metrics.
+ """
+
+ def __init__(
+ self,
+ dataset_name,
+ distributed=True,
+ output_dir=None,
+ *,
+ num_classes=None,
+ ignore_label=None,
+ ):
+ """
+ Args:
+ dataset_name (str): name of the dataset to be evaluated.
+ distributed (bool): if True, will collect results from all ranks for evaluation.
+ Otherwise, will evaluate the results in the current process.
+ output_dir (str): an output directory to dump results.
+ num_classes, ignore_label: deprecated argument
+ """
+ self._logger = logging.getLogger(__name__)
+ if num_classes is not None:
+ self._logger.warn(
+ "SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata."
+ )
+ if ignore_label is not None:
+ self._logger.warn(
+ "SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata."
+ )
+ self._dataset_name = dataset_name
+ self._distributed = distributed
+ self._output_dir = output_dir
+
+ self._cpu_device = torch.device("cpu")
+
+ self.input_file_to_gt_file = {
+ dataset_record["file_name"]: dataset_record["sem_seg_file_name"]
+ for dataset_record in DatasetCatalog.get(dataset_name)
+ }
+
+ meta = MetadataCatalog.get(dataset_name)
+ # Dict that maps contiguous training ids to COCO category ids
+ try:
+ c2d = meta.stuff_dataset_id_to_contiguous_id
+ self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()}
+ except AttributeError:
+ self._contiguous_id_to_dataset_id = None
+ self._class_names = meta.stuff_classes
+ self._class_offset = meta.class_offset if hasattr(meta, 'class_offset') else 0
+ self._num_classes = len(meta.stuff_classes)
+ self._semseg_loader = meta.semseg_loader if hasattr(meta, 'semseg_loader') else 'PIL'
+
+ if num_classes is not None:
+ assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}"
+ self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label
+
+ def reset(self):
+ self._conf_matrix = np.zeros((self._num_classes + 1, self._num_classes + 1), dtype=np.int64)
+ self._predictions = []
+
+ def process(self, inputs, outputs):
+ """
+ Args:
+ inputs: the inputs to a model.
+ It is a list of dicts. Each dict corresponds to an image and
+ contains keys like "height", "width", "file_name".
+ outputs: the outputs of a model. It is either list of semantic segmentation predictions
+ (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
+ segmentation prediction in the same format.
+ """
+ for input, output in zip(inputs, outputs):
+ output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
+ pred = np.array(output, dtype=np.int)
+
+ with PathManager.open(self.input_file_to_gt_file[input["file_name"]], "rb") as f:
+ gt = load_semseg(f, self._semseg_loader) - self._class_offset
+
+ if isinstance(self._ignore_label, int):
+ ignore_label = self._ignore_label - self._class_offset
+ gt[gt == self._ignore_label] = self._num_classes
+ elif isinstance(self._ignore_label, list):
+ for ignore_label in self._ignore_label:
+ ignore_label = ignore_label - self._class_offset
+ gt[gt == ignore_label] = self._num_classes
+
+ self._conf_matrix += np.bincount(
+ (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
+ minlength=self._conf_matrix.size,
+ ).reshape(self._conf_matrix.shape)
+
+ self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))
+
+ def evaluate(self):
+ """
+ Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):
+
+ * Mean intersection-over-union averaged across classes (mIoU)
+ * Frequency Weighted IoU (fwIoU)
+ * Mean pixel accuracy averaged across classes (mACC)
+ * Pixel Accuracy (pACC)
+ """
+ if self._distributed:
+ synchronize()
+ conf_matrix_list = all_gather(self._conf_matrix)
+ self._predictions = all_gather(self._predictions)
+ self._predictions = list(itertools.chain(*self._predictions))
+ if not is_main_process():
+ return
+ self._conf_matrix = np.zeros_like(self._conf_matrix)
+ for conf_matrix in conf_matrix_list:
+ self._conf_matrix += conf_matrix
+
+ if self._output_dir:
+ PathManager.mkdirs(self._output_dir)
+ file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
+ with PathManager.open(file_path, "w") as f:
+ f.write(json.dumps(self._predictions))
+
+ acc = np.full(self._num_classes, np.nan, dtype=np.float)
+ iou = np.full(self._num_classes, np.nan, dtype=np.float)
+ tp = self._conf_matrix.diagonal()[:-1].astype(np.float)
+ pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float)
+ class_weights = pos_gt / np.sum(pos_gt)
+ pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float)
+ acc_valid = pos_gt > 0
+ acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
+ iou_valid = (pos_gt + pos_pred) > 0
+ union = pos_gt + pos_pred - tp
+ iou[acc_valid] = tp[acc_valid] / union[acc_valid]
+ macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
+ miou = np.sum(iou[acc_valid]) / np.sum(iou_valid)
+ fiou = np.sum(iou[acc_valid] * class_weights[acc_valid])
+ pacc = np.sum(tp) / np.sum(pos_gt)
+
+ res = {}
+ res["mIoU"] = 100 * miou
+ res["fwIoU"] = 100 * fiou
+ for i, name in enumerate(self._class_names):
+ res["IoU-{}".format(name)] = 100 * iou[i]
+ res["mACC"] = 100 * macc
+ res["pACC"] = 100 * pacc
+ for i, name in enumerate(self._class_names):
+ res["ACC-{}".format(name)] = 100 * acc[i]
+
+ if self._output_dir:
+ file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
+ with PathManager.open(file_path, "wb") as f:
+ torch.save(res, f)
+ results = OrderedDict({"sem_seg": res})
+ self._logger.info(results)
+ return results
+
+ def encode_json_sem_seg(self, sem_seg, input_file_name):
+ """
+ Convert semantic segmentation to COCO stuff format with segments encoded as RLEs.
+ See http://cocodataset.org/#format-results
+ """
+ json_list = []
+ for label in np.unique(sem_seg):
+ if self._contiguous_id_to_dataset_id is not None:
+ assert (
+ label in self._contiguous_id_to_dataset_id
+ ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name)
+ dataset_id = self._contiguous_id_to_dataset_id[label]
+ else:
+ dataset_id = int(label)
+ mask = (sem_seg == label).astype(np.uint8)
+ mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0]
+ mask_rle["counts"] = mask_rle["counts"].decode("utf-8")
+ json_list.append(
+ {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle}
+ )
+ return json_list
diff --git a/datasets/refer.py b/datasets/refer.py
new file mode 100644
index 0000000000000000000000000000000000000000..733145647a061a92da4ca4e34bdf957dd9f5db00
--- /dev/null
+++ b/datasets/refer.py
@@ -0,0 +1,371 @@
+__author__ = 'licheng'
+
+"""
+This interface provides access to four datasets:
+1) refclef
+2) refcoco
+3) refcoco+
+4) refcocog
+split by unc and google
+
+The following API functions are defined:
+REFER - REFER api class
+getRefIds - get ref ids that satisfy given filter conditions.
+getAnnIds - get ann ids that satisfy given filter conditions.
+getImgIds - get image ids that satisfy given filter conditions.
+getCatIds - get category ids that satisfy given filter conditions.
+loadRefs - load refs with the specified ref ids.
+loadAnns - load anns with the specified ann ids.
+loadImgs - load images with the specified image ids.
+loadCats - load category names with the specified category ids.
+getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
+showRef - show image, segmentation or box of the referred object with the ref
+getMask - get mask and area of the referred object given ref
+showMask - show mask of the referred object given ref
+"""
+
+from doctest import REPORT_ONLY_FIRST_FAILURE
+import sys
+import os.path as osp
+import json
+import pickle
+import time
+import itertools
+import skimage.io as io
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from pprint import pprint
+import numpy as np
+from pycocotools import mask
+# import cv2
+# from skimage.measure import label, regionprops
+
+
+class REFER:
+ def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
+ # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
+ # also provide dataset name and splitBy information
+ # e.g., dataset = 'refcoco', splitBy = 'unc'
+ print('loading dataset {} into memory...'.format(dataset))
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
+ self.DATA_DIR = osp.join(data_root, dataset)
+ if dataset in ['refcoco', 'refcoco+', 'refcocog']:
+ self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
+ elif dataset == 'refclef':
+ self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
+ else:
+ print('No refer dataset is called [{}]'.format(dataset))
+ sys.exit()
+
+ # load refs from data/dataset/refs(dataset).json
+ tic = time.time()
+ ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
+ self.data = {}
+ self.data['dataset'] = dataset
+ self.data['refs'] = pickle.load(open(ref_file, 'rb'))
+
+ # load annotations from data/dataset/instances.json
+ instances_file = osp.join(self.DATA_DIR, 'instances.json')
+ instances = json.load(open(instances_file, 'r'))
+ self.data['images'] = instances['images']
+ self.data['annotations'] = instances['annotations']
+ self.data['categories'] = instances['categories']
+
+ # create index
+ self.createIndex()
+ print('DONE (t=%.2fs)'.format(time.time()-tic))
+
+ def createIndex(self):
+ # create sets of mapping
+ # 1) Refs: {ref_id: ref}
+ # 2) Anns: {ann_id: ann}
+ # 3) Imgs: {image_id: image}
+ # 4) Cats: {category_id: category_name}
+ # 5) Sents: {sent_id: sent}
+ # 6) imgToRefs: {image_id: refs}
+ # 7) imgToAnns: {image_id: anns}
+ # 8) refToAnn: {ref_id: ann}
+ # 9) annToRef: {ann_id: ref}
+ # 10) catToRefs: {category_id: refs}
+ # 11) sentToRef: {sent_id: ref}
+ # 12) sentToTokens: {sent_id: tokens}
+ print('creating index...')
+ # fetch info from instances
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
+ for ann in self.data['annotations']:
+ Anns[ann['id']] = ann
+ imgToAnns[ann['image_id']] = imgToAnns.get(
+ ann['image_id'], []) + [ann]
+ for img in self.data['images']:
+ Imgs[img['id']] = img
+ for cat in self.data['categories']:
+ Cats[cat['id']] = cat['name']
+
+ # fetch info from refs
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
+ Sents, sentToRef, sentToTokens = {}, {}, {}
+ for ref in self.data['refs']:
+ # ids
+ ref_id = ref['ref_id']
+ ann_id = ref['ann_id']
+ category_id = ref['category_id']
+ image_id = ref['image_id']
+
+ # add mapping related to ref
+ Refs[ref_id] = ref
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
+ catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
+ refToAnn[ref_id] = Anns[ann_id]
+ annToRef[ann_id] = ref
+
+ # add mapping of sent
+ for sent in ref['sentences']:
+ Sents[sent['sent_id']] = sent
+ sentToRef[sent['sent_id']] = ref
+ sentToTokens[sent['sent_id']] = sent['tokens']
+
+ # create class members
+ self.Refs = Refs
+ self.Anns = Anns
+ self.Imgs = Imgs
+ self.Cats = Cats
+ self.Sents = Sents
+ self.imgToRefs = imgToRefs
+ self.imgToAnns = imgToAnns
+ self.refToAnn = refToAnn
+ self.annToRef = annToRef
+ self.catToRefs = catToRefs
+ self.sentToRef = sentToRef
+ self.sentToTokens = sentToTokens
+ print('index created.')
+
+ def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
+ refs = self.data['refs']
+ else:
+ if not len(image_ids) == 0:
+ refs = [self.imgToRefs[image_id] for image_id in image_ids]
+ else:
+ refs = self.data['refs']
+ if not len(cat_ids) == 0:
+ refs = [ref for ref in refs if ref['category_id'] in cat_ids]
+ if not len(ref_ids) == 0:
+ refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
+ if not len(split) == 0:
+ if split in ['testA', 'testB', 'testC']:
+ # we also consider testAB, testBC, ...
+ refs = [ref for ref in refs if split[-1] in ref['split']]
+ elif split in ['testAB', 'testBC', 'testAC']:
+ # rarely used I guess...
+ refs = [ref for ref in refs if ref['split'] == split]
+ elif split == 'test':
+ refs = [ref for ref in refs if 'test' in ref['split']]
+ elif split == 'train' or split == 'val':
+ refs = [ref for ref in refs if ref['split'] == split]
+ else:
+ print('No such split [{}]'.format(split))
+ sys.exit()
+ ref_ids = [ref['ref_id'] for ref in refs]
+ return ref_ids
+
+ def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
+ ann_ids = [ann['id'] for ann in self.data['annotations']]
+ else:
+ if not len(image_ids) == 0:
+ lists = [self.imgToAnns[image_id]
+ for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
+ anns = list(itertools.chain.from_iterable(lists))
+ else:
+ anns = self.data['annotations']
+ if not len(cat_ids) == 0:
+ anns = [ann for ann in anns if ann['category_id'] in cat_ids]
+ ann_ids = [ann['id'] for ann in anns]
+ if not len(ref_ids) == 0:
+ ids = set(ann_ids).intersection(
+ set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
+ return ann_ids
+
+ def getImgIds(self, ref_ids=[]):
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if not len(ref_ids) == 0:
+ image_ids = list(set([self.Refs[ref_id]['image_id']
+ for ref_id in ref_ids]))
+ else:
+ image_ids = self.Imgs.keys()
+ return image_ids
+
+ def getCatIds(self):
+ return self.Cats.keys()
+
+ def loadRefs(self, ref_ids=[]):
+ if type(ref_ids) == list:
+ return [self.Refs[ref_id] for ref_id in ref_ids]
+ elif type(ref_ids) == int:
+ return [self.Refs[ref_ids]]
+
+ def loadAnns(self, ann_ids=[]):
+ if type(ann_ids) == list:
+ return [self.Anns[ann_id] for ann_id in ann_ids]
+ elif type(ann_ids) == int or type(ann_ids) == unicode:
+ return [self.Anns[ann_ids]]
+
+ def loadImgs(self, image_ids=[]):
+ if type(image_ids) == list:
+ return [self.Imgs[image_id] for image_id in image_ids]
+ elif type(image_ids) == int:
+ return [self.Imgs[image_ids]]
+
+ def loadCats(self, cat_ids=[]):
+ if type(cat_ids) == list:
+ return [self.Cats[cat_id] for cat_id in cat_ids]
+ elif type(cat_ids) == int:
+ return [self.Cats[cat_ids]]
+
+ def getRefBox(self, ref_id):
+ ref = self.Refs[ref_id]
+ ann = self.refToAnn[ref_id]
+ return ann['bbox'] # [x, y, w, h]
+
+ def showRef(self, ref, seg_box='seg'):
+ ax = plt.gca()
+ # show image
+ image = self.Imgs[ref['image_id']]
+ I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
+ ax.imshow(I)
+ # show refer expression
+ for sid, sent in enumerate(ref['sentences']):
+ print('{}. {}'.format(sid+1, sent['sent']))
+ # show segmentations
+ if seg_box == 'seg':
+ ann_id = ref['ann_id']
+ ann = self.Anns[ann_id]
+ polygons = []
+ color = []
+ c = 'none'
+ if type(ann['segmentation'][0]) == list:
+ # polygon used for refcoco*
+ for seg in ann['segmentation']:
+ poly = np.array(seg).reshape((len(seg)/2, 2))
+ polygons.append(Polygon(poly, True, alpha=0.4))
+ color.append(c)
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(
+ 1, 1, 0, 0), linewidths=3, alpha=1)
+ ax.add_collection(p) # thick yellow polygon
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(
+ 1, 0, 0, 0), linewidths=1, alpha=1)
+ ax.add_collection(p) # thin red polygon
+ else:
+ # mask used for refclef
+ rle = ann['segmentation']
+ m = mask.decode(rle)
+ img = np.ones((m.shape[0], m.shape[1], 3))
+ color_mask = np.array([2.0, 166.0, 101.0])/255
+ for i in range(3):
+ img[:, :, i] = color_mask[i]
+ ax.imshow(np.dstack((img, m*0.5)))
+ # show bounding-box
+ elif seg_box == 'box':
+ ann_id = ref['ann_id']
+ ann = self.Anns[ann_id]
+ bbox = self.getRefBox(ref['ref_id'])
+ box_plot = Rectangle(
+ (bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
+ ax.add_patch(box_plot)
+
+ def getMask(self, ref):
+ # return mask, area and mask-center
+ ann = self.refToAnn[ref['ref_id']]
+ image = self.Imgs[ref['image_id']]
+ if type(ann['segmentation'][0]) == list: # polygon
+ rle = mask.frPyObjects(
+ ann['segmentation'], image['height'], image['width'])
+ else:
+ rle = ann['segmentation']
+ m = mask.decode(rle)
+ # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = np.sum(m, axis=2)
+ m = m.astype(np.uint8) # convert to np.uint8
+ # compute area
+ area = sum(mask.area(rle)) # should be close to ann['area']
+ return {'mask': m, 'area': area}
+ # # position
+ # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
+ # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
+ # # mass position (if there were multiple regions, we use the largest one.)
+ # label_m = label(m, connectivity=m.ndim)
+ # regions = regionprops(label_m)
+ # if len(regions) > 0:
+ # largest_id = np.argmax(np.array([props.filled_area for props in regions]))
+ # largest_props = regions[largest_id]
+ # mass_y, mass_x = largest_props.centroid
+ # else:
+ # mass_x, mass_y = position_x, position_y
+ # # if centroid is not in mask, we find the closest point to it from mask
+ # if m[mass_y, mass_x] != 1:
+ # print 'Finding closes mask point ...'
+ # kernel = np.ones((10, 10),np.uint8)
+ # me = cv2.erode(m, kernel, iterations = 1)
+ # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
+ # points = np.array(points)
+ # dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
+ # id = np.argsort(dist)[0]
+ # mass_y, mass_x = points[id]
+ # # return
+ # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
+ # # show image and mask
+ # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
+ # plt.figure()
+ # plt.imshow(I)
+ # ax = plt.gca()
+ # img = np.ones( (m.shape[0], m.shape[1], 3) )
+ # color_mask = np.array([2.0,166.0,101.0])/255
+ # for i in range(3):
+ # img[:,:,i] = color_mask[i]
+ # ax.imshow(np.dstack( (img, m*0.5) ))
+ # plt.show()
+
+ def showMask(self, ref):
+ M = self.getMask(ref)
+ msk = M['mask']
+ ax = plt.gca()
+ ax.imshow(msk)
+
+
+if __name__ == '__main__':
+ refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg',
+ dataset='refcocog', splitBy='google')
+ ref_ids = refer.getRefIds()
+ print(len(ref_ids))
+
+ print(len(refer.Imgs))
+ print(len(refer.imgToRefs))
+
+ ref_ids = refer.getRefIds(split='train')
+ print('There are {} training referred objects.' % len(ref_ids))
+
+ for ref_id in ref_ids:
+ ref = refer.loadRefs(ref_id)[0]
+ if len(ref['sentences']) < 2:
+ continue
+
+ pprint(ref)
+ print('The label is {}.'.format(refer.Cats[ref['category_id']]))
+
+ # plt.figure()
+ # refer.showRef(ref, seg_box='box')
+ # plt.show()
+
+ # plt.figure()
+ # refer.showMask(ref)
+ # plt.show()
diff --git a/datasets/registration/__init__.py b/datasets/registration/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdb720a74950e13afd7890bb3fd6bbf389801788
--- /dev/null
+++ b/datasets/registration/__init__.py
@@ -0,0 +1,3 @@
+from . import (
+ register_biomed_datasets
+)
\ No newline at end of file
diff --git a/datasets/registration/register_biomed_datasets.py b/datasets/registration/register_biomed_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..d443323069c8f88ef81b87a3bc72f98e4c980d62
--- /dev/null
+++ b/datasets/registration/register_biomed_datasets.py
@@ -0,0 +1,123 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+import json
+import os
+import collections
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+from detectron2.utils.file_io import PathManager
+
+
+_PREDEFINED_SPLITS_BIOMED = {}
+
+# example of registering a dataset
+datasets = ['BiomedParseData-Demo', ] # provide name of the dataset under biomedparse_datasets
+splits = ['demo'] # provide split name, e.g., train, test, val. Here there is only one 'demo' split in the example demo dataset
+
+# Here we register all the splits of the dataset
+for name in datasets:
+ for split in splits:
+ dataname = f'biomed_{name.replace("/", "-")}_{split}'
+ image_root = f"{name}/{split}"
+ ann_root = f"{name}/{split}.json"
+ _PREDEFINED_SPLITS_BIOMED[dataname] = (image_root, ann_root)
+# The resulting dataset name is: biomed_BiomedParseData-Demo_demo
+
+# # Add your dataset here
+# datasets = ['YOUR_DATASET_NAME', ] # provide name of the dataset under biomedparse_datasets
+# splits = ['train', 'test'] # provide split name, e.g., train, test, val
+
+# # Here we register all the splits of the dataset
+# for name in datasets:
+# for split in splits:
+# dataname = f'biomed_{name.replace("/", "-")}_{split}'
+# image_root = f"{name}/{split}"
+# ann_root = f"{name}/{split}.json"
+# _PREDEFINED_SPLITS_BIOMED[dataname] = (image_root, ann_root)
+# # The resulting dataset names are: biomed_YOUR_DATASET_NAME_train, biomed_YOUR_DATASET_NAME_test
+
+
+def get_metadata():
+ meta = {}
+ return meta
+
+
+def load_biomed_json(image_root, annot_json, metadata):
+ """
+ Args:
+ image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
+ gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
+ json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
+ Returns:
+ list[dict]: a list of dicts in Detectron2 standard format. (See
+ `Using Custom Datasets `_ )
+ """
+
+ with PathManager.open(annot_json) as f:
+ json_info = json.load(f)
+
+ # build dictionary for grounding
+ grd_dict = collections.defaultdict(list)
+ for grd_ann in json_info['annotations']:
+ image_id = int(grd_ann["image_id"])
+ grd_dict[image_id].append(grd_ann)
+
+ mask_root = image_root + '_mask'
+ ret = []
+ for image in json_info["images"]:
+ image_id = int(image["id"])
+ image_file = os.path.join(image_root, image['file_name'])
+ grounding_anno = grd_dict[image_id]
+ for ann in grounding_anno:
+ if 'mask_file' not in ann:
+ ann['mask_file'] = image['file_name']
+ ann['mask_file'] = os.path.join(mask_root, ann['mask_file'])
+ ret.append(
+ {
+ "file_name": image_file,
+ "image_id": image_id,
+ "grounding_info": [ann],
+ }
+ )
+ assert len(ret), f"No images found in {image_root}!"
+ assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
+ return ret
+
+
+def register_biomed(
+ name, metadata, image_root, annot_json):
+ DatasetCatalog.register(
+ name,
+ lambda: load_biomed_json(image_root, annot_json, metadata),
+ )
+ MetadataCatalog.get(name).set(
+ image_root=image_root,
+ json_file=annot_json,
+ evaluator_type="grounding_refcoco",
+ ignore_label=255,
+ label_divisor=1000,
+ **metadata,
+ )
+
+
+def register_all_biomed(root):
+ for (
+ prefix,
+ (image_root, annot_root),
+ ) in _PREDEFINED_SPLITS_BIOMED.items():
+ register_biomed(
+ prefix,
+ get_metadata(),
+ os.path.join(root, image_root),
+ os.path.join(root, annot_root),
+ )
+
+
+_root = os.getenv("DATASET", "datasets")
+register_all_biomed(_root)
diff --git a/datasets/semseg_loader.py b/datasets/semseg_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4e3cb16a28c8022a6c0f5abf28056a748f7e4d5
--- /dev/null
+++ b/datasets/semseg_loader.py
@@ -0,0 +1,10 @@
+from PIL import Image
+import scipy.io
+import numpy as np
+
+def load_semseg(filename, loader_type):
+ if loader_type == 'PIL':
+ semseg = np.array(Image.open(filename), dtype=np.int)
+ elif loader_type == 'MAT':
+ semseg = scipy.io.loadmat(filename)['LabelMap']
+ return semseg
\ No newline at end of file
diff --git a/datasets/utils/refcoco2json.py b/datasets/utils/refcoco2json.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44cfec3ca5bac92bbd0dbd2a089f5b65327e621
--- /dev/null
+++ b/datasets/utils/refcoco2json.py
@@ -0,0 +1,41 @@
+import os
+import json
+from refer import REFER
+
+coco_root = '/pth/to/coco'
+ref_root = '/pth/to/refcocoseg'
+
+coco_train_annot = json.load(open(os.path.join(coco_root, 'annotations/instances_train2017.json')))
+coco_train_id = []
+image_annot = {}
+for i in range(len(coco_train_annot['images'])):
+ coco_train_id.append(coco_train_annot['images'][i]['id'])
+ image_annot[coco_train_annot['images'][i]['id']] = coco_train_annot['images'][i]
+
+refg = REFER(data_root=ref_root,
+ dataset='refcocog', splitBy='umd')
+refg_val_ids = refg.getRefIds(split='val')
+
+full_anno = []
+for ref_id in refg_val_ids:
+ ref = refg.loadRefs(ref_id)[0]
+ anno = refg.refToAnn[ref_id]
+ anno.update(ref)
+ full_anno.append(anno)
+
+imageid_list = []
+final_anno = {}
+for anno in full_anno:
+ imageid_list += [anno['image_id']]
+ final_anno[anno['ann_id']] = anno
+
+annotations = [value for key, value in final_anno.items()]
+
+iamges = []
+for image_id in list(set(imageid_list)):
+ iamges += [image_annot[image_id]]
+
+outputs = {'images': iamges, 'annotations': annotations}
+print(len(iamges))
+print(len(annotations))
+json.dump(outputs, open(os.path.join(coco_root, 'annotations/refcocog_umd_train.json'), 'w'))
diff --git a/datasets/utils/refer.py b/datasets/utils/refer.py
new file mode 100644
index 0000000000000000000000000000000000000000..674aec152f0f7da1673a7c6b830eda4fb7f96d57
--- /dev/null
+++ b/datasets/utils/refer.py
@@ -0,0 +1,372 @@
+# This code is modified from https://github.com/lichengunc/refer, and with minor modification of python2/3 format
+__author__ = 'licheng'
+
+"""
+This interface provides access to four datasets:
+1) refclef
+2) refcoco
+3) refcoco+
+4) refcocog
+split by unc and google
+
+The following API functions are defined:
+REFER - REFER api class
+getRefIds - get ref ids that satisfy given filter conditions.
+getAnnIds - get ann ids that satisfy given filter conditions.
+getImgIds - get image ids that satisfy given filter conditions.
+getCatIds - get category ids that satisfy given filter conditions.
+loadRefs - load refs with the specified ref ids.
+loadAnns - load anns with the specified ann ids.
+loadImgs - load images with the specified image ids.
+loadCats - load category names with the specified category ids.
+getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
+showRef - show image, segmentation or box of the referred object with the ref
+getMask - get mask and area of the referred object given ref
+showMask - show mask of the referred object given ref
+"""
+
+from doctest import REPORT_ONLY_FIRST_FAILURE
+import sys
+import os.path as osp
+import json
+import pickle
+import time
+import itertools
+import skimage.io as io
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from pprint import pprint
+import numpy as np
+from pycocotools import mask
+# import cv2
+# from skimage.measure import label, regionprops
+
+
+class REFER:
+ def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
+ # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
+ # also provide dataset name and splitBy information
+ # e.g., dataset = 'refcoco', splitBy = 'unc'
+ print('loading dataset {} into memory...'.format(dataset))
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
+ self.DATA_DIR = osp.join(data_root, dataset)
+ if dataset in ['refcoco', 'refcoco+', 'refcocog']:
+ self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
+ elif dataset == 'refclef':
+ self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
+ else:
+ print('No refer dataset is called [{}]'.format(dataset))
+ sys.exit()
+
+ # load refs from data/dataset/refs(dataset).json
+ tic = time.time()
+ ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
+ self.data = {}
+ self.data['dataset'] = dataset
+ self.data['refs'] = pickle.load(open(ref_file, 'rb'))
+
+ # load annotations from data/dataset/instances.json
+ instances_file = osp.join(self.DATA_DIR, 'instances.json')
+ instances = json.load(open(instances_file, 'r'))
+ self.data['images'] = instances['images']
+ self.data['annotations'] = instances['annotations']
+ self.data['categories'] = instances['categories']
+
+ # create index
+ self.createIndex()
+ print('DONE (t=%.2fs)'.format(time.time()-tic))
+
+ def createIndex(self):
+ # create sets of mapping
+ # 1) Refs: {ref_id: ref}
+ # 2) Anns: {ann_id: ann}
+ # 3) Imgs: {image_id: image}
+ # 4) Cats: {category_id: category_name}
+ # 5) Sents: {sent_id: sent}
+ # 6) imgToRefs: {image_id: refs}
+ # 7) imgToAnns: {image_id: anns}
+ # 8) refToAnn: {ref_id: ann}
+ # 9) annToRef: {ann_id: ref}
+ # 10) catToRefs: {category_id: refs}
+ # 11) sentToRef: {sent_id: ref}
+ # 12) sentToTokens: {sent_id: tokens}
+ print('creating index...')
+ # fetch info from instances
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
+ for ann in self.data['annotations']:
+ Anns[ann['id']] = ann
+ imgToAnns[ann['image_id']] = imgToAnns.get(
+ ann['image_id'], []) + [ann]
+ for img in self.data['images']:
+ Imgs[img['id']] = img
+ for cat in self.data['categories']:
+ Cats[cat['id']] = cat['name']
+
+ # fetch info from refs
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
+ Sents, sentToRef, sentToTokens = {}, {}, {}
+ for ref in self.data['refs']:
+ # ids
+ ref_id = ref['ref_id']
+ ann_id = ref['ann_id']
+ category_id = ref['category_id']
+ image_id = ref['image_id']
+
+ # add mapping related to ref
+ Refs[ref_id] = ref
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
+ catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
+ refToAnn[ref_id] = Anns[ann_id]
+ annToRef[ann_id] = ref
+
+ # add mapping of sent
+ for sent in ref['sentences']:
+ Sents[sent['sent_id']] = sent
+ sentToRef[sent['sent_id']] = ref
+ sentToTokens[sent['sent_id']] = sent['tokens']
+
+ # create class members
+ self.Refs = Refs
+ self.Anns = Anns
+ self.Imgs = Imgs
+ self.Cats = Cats
+ self.Sents = Sents
+ self.imgToRefs = imgToRefs
+ self.imgToAnns = imgToAnns
+ self.refToAnn = refToAnn
+ self.annToRef = annToRef
+ self.catToRefs = catToRefs
+ self.sentToRef = sentToRef
+ self.sentToTokens = sentToTokens
+ print('index created.')
+
+ def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
+ refs = self.data['refs']
+ else:
+ if not len(image_ids) == 0:
+ refs = [self.imgToRefs[image_id] for image_id in image_ids]
+ else:
+ refs = self.data['refs']
+ if not len(cat_ids) == 0:
+ refs = [ref for ref in refs if ref['category_id'] in cat_ids]
+ if not len(ref_ids) == 0:
+ refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
+ if not len(split) == 0:
+ if split in ['testA', 'testB', 'testC']:
+ # we also consider testAB, testBC, ...
+ refs = [ref for ref in refs if split[-1] in ref['split']]
+ elif split in ['testAB', 'testBC', 'testAC']:
+ # rarely used I guess...
+ refs = [ref for ref in refs if ref['split'] == split]
+ elif split == 'test':
+ refs = [ref for ref in refs if 'test' in ref['split']]
+ elif split == 'train' or split == 'val':
+ refs = [ref for ref in refs if ref['split'] == split]
+ else:
+ print('No such split [{}]'.format(split))
+ sys.exit()
+ ref_ids = [ref['ref_id'] for ref in refs]
+ return ref_ids
+
+ def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
+ ann_ids = [ann['id'] for ann in self.data['annotations']]
+ else:
+ if not len(image_ids) == 0:
+ lists = [self.imgToAnns[image_id]
+ for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
+ anns = list(itertools.chain.from_iterable(lists))
+ else:
+ anns = self.data['annotations']
+ if not len(cat_ids) == 0:
+ anns = [ann for ann in anns if ann['category_id'] in cat_ids]
+ ann_ids = [ann['id'] for ann in anns]
+ if not len(ref_ids) == 0:
+ ids = set(ann_ids).intersection(
+ set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))
+ return ann_ids
+
+ def getImgIds(self, ref_ids=[]):
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if not len(ref_ids) == 0:
+ image_ids = list(set([self.Refs[ref_id]['image_id']
+ for ref_id in ref_ids]))
+ else:
+ image_ids = self.Imgs.keys()
+ return image_ids
+
+ def getCatIds(self):
+ return self.Cats.keys()
+
+ def loadRefs(self, ref_ids=[]):
+ if type(ref_ids) == list:
+ return [self.Refs[ref_id] for ref_id in ref_ids]
+ elif type(ref_ids) == int:
+ return [self.Refs[ref_ids]]
+
+ def loadAnns(self, ann_ids=[]):
+ if type(ann_ids) == list:
+ return [self.Anns[ann_id] for ann_id in ann_ids]
+ elif type(ann_ids) == int or type(ann_ids) == unicode:
+ return [self.Anns[ann_ids]]
+
+ def loadImgs(self, image_ids=[]):
+ if type(image_ids) == list:
+ return [self.Imgs[image_id] for image_id in image_ids]
+ elif type(image_ids) == int:
+ return [self.Imgs[image_ids]]
+
+ def loadCats(self, cat_ids=[]):
+ if type(cat_ids) == list:
+ return [self.Cats[cat_id] for cat_id in cat_ids]
+ elif type(cat_ids) == int:
+ return [self.Cats[cat_ids]]
+
+ def getRefBox(self, ref_id):
+ ref = self.Refs[ref_id]
+ ann = self.refToAnn[ref_id]
+ return ann['bbox'] # [x, y, w, h]
+
+ def showRef(self, ref, seg_box='seg'):
+ ax = plt.gca()
+ # show image
+ image = self.Imgs[ref['image_id']]
+ I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
+ ax.imshow(I)
+ # show refer expression
+ for sid, sent in enumerate(ref['sentences']):
+ print('{}. {}'.format(sid+1, sent['sent']))
+ # show segmentations
+ if seg_box == 'seg':
+ ann_id = ref['ann_id']
+ ann = self.Anns[ann_id]
+ polygons = []
+ color = []
+ c = 'none'
+ if type(ann['segmentation'][0]) == list:
+ # polygon used for refcoco*
+ for seg in ann['segmentation']:
+ poly = np.array(seg).reshape((len(seg)/2, 2))
+ polygons.append(Polygon(poly, True, alpha=0.4))
+ color.append(c)
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(
+ 1, 1, 0, 0), linewidths=3, alpha=1)
+ ax.add_collection(p) # thick yellow polygon
+ p = PatchCollection(polygons, facecolors=color, edgecolors=(
+ 1, 0, 0, 0), linewidths=1, alpha=1)
+ ax.add_collection(p) # thin red polygon
+ else:
+ # mask used for refclef
+ rle = ann['segmentation']
+ m = mask.decode(rle)
+ img = np.ones((m.shape[0], m.shape[1], 3))
+ color_mask = np.array([2.0, 166.0, 101.0])/255
+ for i in range(3):
+ img[:, :, i] = color_mask[i]
+ ax.imshow(np.dstack((img, m*0.5)))
+ # show bounding-box
+ elif seg_box == 'box':
+ ann_id = ref['ann_id']
+ ann = self.Anns[ann_id]
+ bbox = self.getRefBox(ref['ref_id'])
+ box_plot = Rectangle(
+ (bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
+ ax.add_patch(box_plot)
+
+ def getMask(self, ref):
+ # return mask, area and mask-center
+ ann = self.refToAnn[ref['ref_id']]
+ image = self.Imgs[ref['image_id']]
+ if type(ann['segmentation'][0]) == list: # polygon
+ rle = mask.frPyObjects(
+ ann['segmentation'], image['height'], image['width'])
+ else:
+ rle = ann['segmentation']
+ m = mask.decode(rle)
+ # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = np.sum(m, axis=2)
+ m = m.astype(np.uint8) # convert to np.uint8
+ # compute area
+ area = sum(mask.area(rle)) # should be close to ann['area']
+ return {'mask': m, 'area': area}
+ # # position
+ # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
+ # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
+ # # mass position (if there were multiple regions, we use the largest one.)
+ # label_m = label(m, connectivity=m.ndim)
+ # regions = regionprops(label_m)
+ # if len(regions) > 0:
+ # largest_id = np.argmax(np.array([props.filled_area for props in regions]))
+ # largest_props = regions[largest_id]
+ # mass_y, mass_x = largest_props.centroid
+ # else:
+ # mass_x, mass_y = position_x, position_y
+ # # if centroid is not in mask, we find the closest point to it from mask
+ # if m[mass_y, mass_x] != 1:
+ # print 'Finding closes mask point ...'
+ # kernel = np.ones((10, 10),np.uint8)
+ # me = cv2.erode(m, kernel, iterations = 1)
+ # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
+ # points = np.array(points)
+ # dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
+ # id = np.argsort(dist)[0]
+ # mass_y, mass_x = points[id]
+ # # return
+ # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
+ # # show image and mask
+ # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
+ # plt.figure()
+ # plt.imshow(I)
+ # ax = plt.gca()
+ # img = np.ones( (m.shape[0], m.shape[1], 3) )
+ # color_mask = np.array([2.0,166.0,101.0])/255
+ # for i in range(3):
+ # img[:,:,i] = color_mask[i]
+ # ax.imshow(np.dstack( (img, m*0.5) ))
+ # plt.show()
+
+ def showMask(self, ref):
+ M = self.getMask(ref)
+ msk = M['mask']
+ ax = plt.gca()
+ ax.imshow(msk)
+
+
+if __name__ == '__main__':
+ refer = REFER(data_root='/home/xueyanz/code/dataset/refcocoseg',
+ dataset='refcocog', splitBy='google')
+ ref_ids = refer.getRefIds()
+ print(len(ref_ids))
+
+ print(len(refer.Imgs))
+ print(len(refer.imgToRefs))
+
+ ref_ids = refer.getRefIds(split='train')
+ print('There are {} training referred objects.' % len(ref_ids))
+
+ for ref_id in ref_ids:
+ ref = refer.loadRefs(ref_id)[0]
+ if len(ref['sentences']) < 2:
+ continue
+
+ pprint(ref)
+ print('The label is {}.'.format(refer.Cats[ref['category_id']]))
+
+ # plt.figure()
+ # refer.showRef(ref, seg_box='box')
+ # plt.show()
+
+ # plt.figure()
+ # refer.showMask(ref)
+ # plt.show()
diff --git a/datasets/visual_sampler/__init__.py b/datasets/visual_sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6e15194058731fa2917a69c7bb97591cef585e2
--- /dev/null
+++ b/datasets/visual_sampler/__init__.py
@@ -0,0 +1,12 @@
+from .sampler import ShapeSampler
+from .simpleclick_sampler import SimpleClickSampler
+
+
+def build_shape_sampler(cfg, **kwargs):
+ sampler_name = cfg['STROKE_SAMPLER']['EVAL']['MODE']
+ if sampler_name == 'random':
+ return ShapeSampler(cfg, **kwargs)
+ elif sampler_name in ['best', 'best_random']:
+ return SimpleClickSampler(cfg, **kwargs)
+ else:
+ assert False, "not implemented"
\ No newline at end of file
diff --git a/datasets/visual_sampler/circle.py b/datasets/visual_sampler/circle.py
new file mode 100644
index 0000000000000000000000000000000000000000..6db18163d91c038845adabd8a968d99e93d65442
--- /dev/null
+++ b/datasets/visual_sampler/circle.py
@@ -0,0 +1,106 @@
+import random
+import torch
+
+from .mask_generators import get_mask_by_input_strokes
+
+class Circle:
+ def __init__(self, cfg, is_train=True):
+ self.num_stroke = cfg['STROKE_SAMPLER']['CIRCLE']['NUM_STROKES']
+ self.stroke_preset = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PRESET']
+ self.stroke_prob = cfg['STROKE_SAMPLER']['CIRCLE']['STROKE_PROB']
+ self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
+ self.is_train = is_train
+
+ @staticmethod
+ def get_stroke_preset(stroke_preset):
+ if stroke_preset == 'object_like':
+ return {
+ "nVertexBound": [5, 30],
+ "maxHeadSpeed": 15,
+ "maxHeadAcceleration": (10, 1.5),
+ "brushWidthBound": (20, 50),
+ "nMovePointRatio": 0.5,
+ "maxPiontMove": 10,
+ "maxLineAcceleration": (5, 0.5),
+ "boarderGap": None,
+ "maxInitSpeed": 10,
+ }
+ elif stroke_preset == 'object_like_middle':
+ return {
+ "nVertexBound": [5, 15],
+ "maxHeadSpeed": 8,
+ "maxHeadAcceleration": (4, 1.5),
+ "brushWidthBound": (20, 50),
+ "nMovePointRatio": 0.5,
+ "maxPiontMove": 5,
+ "maxLineAcceleration": (5, 0.5),
+ "boarderGap": None,
+ "maxInitSpeed": 10,
+ }
+ elif stroke_preset == 'object_like_small':
+ return {
+ "nVertexBound": [5, 20],
+ "maxHeadSpeed": 7,
+ "maxHeadAcceleration": (3.5, 1.5),
+ "brushWidthBound": (10, 30),
+ "nMovePointRatio": 0.5,
+ "maxPiontMove": 5,
+ "maxLineAcceleration": (3, 0.5),
+ "boarderGap": None,
+ "maxInitSpeed": 4,
+ }
+ else:
+ raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
+
+ def get_random_points_from_mask(self, mask, n=5):
+ h,w = mask.shape
+ view_mask = mask.reshape(h*w)
+ non_zero_idx = view_mask.nonzero()[:,0]
+ selected_idx = torch.randperm(len(non_zero_idx))[:n]
+ non_zero_idx = non_zero_idx[selected_idx]
+ y = (non_zero_idx // w)*1.0
+ x = (non_zero_idx % w)*1.0
+ return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
+
+ def draw(self, mask=None, box=None):
+ if mask.sum() < 10: # if mask is nearly empty
+ return torch.zeros(mask.shape).bool()
+ if not self.is_train:
+ return self.draw_eval(mask=mask, box=box)
+ stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
+ preset = Circle.get_stroke_preset(stroke_preset_name)
+ nStroke = min(random.randint(1, self.num_stroke), mask.sum().item())
+ h,w = mask.shape
+ points = self.get_random_points_from_mask(mask, n=nStroke)
+ rand_mask = get_mask_by_input_strokes(
+ init_points=points,
+ imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
+ rand_mask = (~torch.from_numpy(rand_mask)) * mask
+ return rand_mask
+
+ def draw_eval(self, mask=None, box=None):
+ stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0] # select which kind of object to use
+ preset = Circle.get_stroke_preset(stroke_preset_name)
+ nStroke = min(self.max_eval, mask.sum().item())
+ h,w = mask.shape
+ points = self.get_random_points_from_mask(mask, n=nStroke)
+ rand_masks = []
+ for i in range(len(points)):
+ rand_mask = get_mask_by_input_strokes(
+ init_points=points[:i+1],
+ imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points[:i+1])), **preset)
+ rand_masks += [(~torch.from_numpy(rand_mask)) * mask]
+ return torch.stack(rand_masks)
+
+ @staticmethod
+ def draw_by_points(points, mask, h, w):
+ stroke_preset_name = random.choices(['object_like', 'object_like_middle', 'object_like_small'], weights=[0.33,0.33,0.33], k=1)[0] # select which kind of object to use
+ preset = Circle.get_stroke_preset(stroke_preset_name)
+ rand_mask = get_mask_by_input_strokes(
+ init_points=points,
+ imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
+ rand_masks = (~torch.from_numpy(rand_mask)) * mask
+ return rand_masks
+
+ def __repr__(self,):
+ return 'circle'
\ No newline at end of file
diff --git a/datasets/visual_sampler/mask_generators.py b/datasets/visual_sampler/mask_generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab651bbe102ffd6971630305c5ca484437ba858a
--- /dev/null
+++ b/datasets/visual_sampler/mask_generators.py
@@ -0,0 +1,215 @@
+import numpy as np
+import random
+from PIL import Image, ImageDraw
+
+
+def get_mask_by_input_strokes(
+ init_points, imageWidth=320, imageHeight=180, nStroke=5,
+ nVertexBound=[10, 30], maxHeadSpeed=15, maxHeadAcceleration=(15, 0.5),
+ brushWidthBound=(5, 20), boarderGap=None, nMovePointRatio=0.5, maxPiontMove=10,
+ maxLineAcceleration=5, maxInitSpeed=5
+):
+ '''
+ Get video masks by random strokes which move randomly between each
+ frame, including the whole stroke and its control points
+
+ Parameters
+ ----------
+ imageWidth: Image width
+ imageHeight: Image height
+ nStroke: Number of drawed lines
+ nVertexBound: Lower/upper bound of number of control points for each line
+ maxHeadSpeed: Max head speed when creating control points
+ maxHeadAcceleration: Max acceleration applying on the current head point (
+ a head point and its velosity decides the next point)
+ brushWidthBound (min, max): Bound of width for each stroke
+ boarderGap: The minimum gap between image boarder and drawed lines
+ nMovePointRatio: The ratio of control points to move for next frames
+ maxPiontMove: The magnitude of movement for control points for next frames
+ maxLineAcceleration: The magnitude of acceleration for the whole line
+
+ Examples
+ ----------
+ object_like_setting = {
+ "nVertexBound": [5, 20],
+ "maxHeadSpeed": 15,
+ "maxHeadAcceleration": (15, 3.14),
+ "brushWidthBound": (30, 50),
+ "nMovePointRatio": 0.5,
+ "maxPiontMove": 10,
+ "maxLineAcceleration": (5, 0.5),
+ "boarderGap": 20,
+ "maxInitSpeed": 10,
+ }
+ rand_curve_setting = {
+ "nVertexBound": [10, 30],
+ "maxHeadSpeed": 20,
+ "maxHeadAcceleration": (15, 0.5),
+ "brushWidthBound": (3, 10),
+ "nMovePointRatio": 0.5,
+ "maxPiontMove": 3,
+ "maxLineAcceleration": (5, 0.5),
+ "boarderGap": 20,
+ "maxInitSpeed": 6
+ }
+ get_video_masks_by_moving_random_stroke(video_len=5, nStroke=3, **object_like_setting)
+ '''
+ # Initilize a set of control points to draw the first mask
+ mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
+ control_points_set = []
+ for i in range(nStroke):
+ brushWidth = np.random.randint(brushWidthBound[0], brushWidthBound[1])
+ Xs, Ys, velocity = get_random_stroke_control_points(
+ init_point=init_points[i],
+ imageWidth=imageWidth, imageHeight=imageHeight,
+ nVertexBound=nVertexBound, maxHeadSpeed=maxHeadSpeed,
+ maxHeadAcceleration=maxHeadAcceleration, boarderGap=boarderGap,
+ maxInitSpeed=maxInitSpeed
+ )
+ control_points_set.append((Xs, Ys, velocity, brushWidth))
+ draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
+
+ # Generate the following masks by randomly move strokes and their control points
+ mask = Image.new(mode='1', size=(imageWidth, imageHeight), color=1)
+ for j in range(len(control_points_set)):
+ Xs, Ys, velocity, brushWidth = control_points_set[j]
+ new_Xs, new_Ys = random_move_control_points(
+ Xs, Ys, velocity, nMovePointRatio, maxPiontMove,
+ maxLineAcceleration, boarderGap
+ )
+ control_points_set[j] = (new_Xs, new_Ys, velocity, brushWidth)
+ for Xs, Ys, velocity, brushWidth in control_points_set:
+ draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=0)
+
+ return np.array(mask)
+
+
+def random_accelerate(velocity, maxAcceleration, dist='uniform'):
+ speed, angle = velocity
+ d_speed, d_angle = maxAcceleration
+
+ if dist == 'uniform':
+ speed += np.random.uniform(-d_speed, d_speed)
+ angle += np.random.uniform(-d_angle, d_angle)
+ elif dist == 'guassian':
+ speed += np.random.normal(0, d_speed / 2)
+ angle += np.random.normal(0, d_angle / 2)
+ else:
+ raise NotImplementedError(f'Distribution type {dist} is not supported.')
+
+ return (speed, angle)
+
+
+def random_move_control_points(Xs, Ys, lineVelocity, nMovePointRatio, maxPiontMove, maxLineAcceleration, boarderGap=15):
+ new_Xs = Xs.copy()
+ new_Ys = Ys.copy()
+
+ # move the whole line and accelerate
+ speed, angle = lineVelocity
+ new_Xs += int(speed * np.cos(angle))
+ new_Ys += int(speed * np.sin(angle))
+ lineVelocity = random_accelerate(lineVelocity, maxLineAcceleration, dist='guassian')
+
+ # choose points to move
+ chosen = np.arange(len(Xs))
+ np.random.shuffle(chosen)
+ chosen = chosen[:int(len(Xs) * nMovePointRatio)]
+ for i in chosen:
+ new_Xs[i] += np.random.randint(-maxPiontMove, maxPiontMove)
+ new_Ys[i] += np.random.randint(-maxPiontMove, maxPiontMove)
+ return new_Xs, new_Ys
+
+
+def get_random_stroke_control_points(
+ init_point,
+ imageWidth, imageHeight,
+ nVertexBound=(10, 30), maxHeadSpeed=10, maxHeadAcceleration=(5, 0.5), boarderGap=20,
+ maxInitSpeed=10
+):
+ '''
+ Implementation the free-form training masks generating algorithm
+ proposed by JIAHUI YU et al. in "Free-Form Image Inpainting with Gated Convolution"
+ '''
+ startX = init_point[0]
+ startY = init_point[1]
+
+ Xs = [init_point[0]]
+ Ys = [init_point[1]]
+
+ numVertex = np.random.randint(nVertexBound[0], nVertexBound[1])
+
+ angle = np.random.uniform(0, 2 * np.pi)
+ speed = np.random.uniform(0, maxHeadSpeed)
+
+ for i in range(numVertex):
+ speed, angle = random_accelerate((speed, angle), maxHeadAcceleration)
+ speed = np.clip(speed, 0, maxHeadSpeed)
+
+ nextX = startX + speed * np.sin(angle)
+ nextY = startY + speed * np.cos(angle)
+
+ if boarderGap is not None:
+ nextX = np.clip(nextX, boarderGap, imageWidth - boarderGap)
+ nextY = np.clip(nextY, boarderGap, imageHeight - boarderGap)
+
+ startX, startY = nextX, nextY
+ Xs.append(nextX)
+ Ys.append(nextY)
+
+ velocity = get_random_velocity(maxInitSpeed, dist='guassian')
+
+ return np.array(Xs), np.array(Ys), velocity
+
+
+def get_random_velocity(max_speed, dist='uniform'):
+ if dist == 'uniform':
+ speed = np.random.uniform(max_speed)
+ elif dist == 'guassian':
+ speed = np.abs(np.random.normal(0, max_speed / 2))
+ else:
+ raise NotImplementedError(f'Distribution type {dist} is not supported.')
+
+ angle = np.random.uniform(0, 2 * np.pi)
+ return (speed, angle)
+
+
+def draw_mask_by_control_points(mask, Xs, Ys, brushWidth, fill=255):
+ radius = brushWidth // 2 - 1
+ for i in range(1, len(Xs)):
+ draw = ImageDraw.Draw(mask)
+ startX, startY = Xs[i - 1], Ys[i - 1]
+ nextX, nextY = Xs[i], Ys[i]
+ draw.line((startX, startY) + (nextX, nextY), fill=fill, width=brushWidth)
+ for x, y in zip(Xs, Ys):
+ draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=fill)
+ return mask
+
+
+# modified from https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/generate_data.py
+def get_random_walk_mask(imageWidth=320, imageHeight=180, length=None):
+ action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
+ canvas = np.zeros((imageHeight, imageWidth)).astype("i")
+ if length is None:
+ length = imageWidth * imageHeight
+ x = random.randint(0, imageHeight - 1)
+ y = random.randint(0, imageWidth - 1)
+ x_list = []
+ y_list = []
+ for i in range(length):
+ r = random.randint(0, len(action_list) - 1)
+ x = np.clip(x + action_list[r][0], a_min=0, a_max=imageHeight - 1)
+ y = np.clip(y + action_list[r][1], a_min=0, a_max=imageWidth - 1)
+ x_list.append(x)
+ y_list.append(y)
+ canvas[np.array(x_list), np.array(y_list)] = 1
+ return Image.fromarray(canvas * 255).convert('1')
+
+
+def get_masked_ratio(mask):
+ """
+ Calculate the masked ratio.
+ mask: Expected a binary PIL image, where 0 and 1 represent
+ masked(invalid) and valid pixel values.
+ """
+ hist = mask.histogram()
+ return hist[0] / np.prod(mask.size)
diff --git a/datasets/visual_sampler/point.py b/datasets/visual_sampler/point.py
new file mode 100644
index 0000000000000000000000000000000000000000..417289834df2b6be003f38c13c25d8338572e76e
--- /dev/null
+++ b/datasets/visual_sampler/point.py
@@ -0,0 +1,74 @@
+import random
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy import ndimage
+
+
+class Point:
+ def __init__(self, cfg, is_train=True):
+ self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']
+ self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
+ self.is_train = is_train
+
+ def draw(self, mask=None, box=None):
+ if mask.sum() < 10:
+ return torch.zeros(mask.shape).bool() # if mask is empty
+ if not self.is_train:
+ return self.draw_eval(mask=mask, box=box)
+ max_points = min(self.max_points, mask.sum().item()) # max number of points no more than total mask number
+ num_points = random.randint(1, max_points) # get a random number of points
+ h,w = mask.shape
+ view_mask = mask.view(-1)
+ non_zero_idx = view_mask.nonzero()[:,0] # get non-zero index of mask
+ selected_idx = torch.randperm(len(non_zero_idx))[:num_points] # select id
+ non_zero_idx = non_zero_idx[selected_idx] # select non-zero index
+ rand_mask = torch.zeros(view_mask.shape).bool() # init rand mask
+ rand_mask[non_zero_idx] = True # get non zero place to zero
+ # dilate
+ # struct = ndimage.generate_binary_structure(2, 2)
+ # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
+ # return rand_mask
+ return rand_mask.reshape(h, w)
+
+ def draw_eval(self, mask=None, box=None):
+ background = ~mask
+ neg_num = min(self.max_eval // 2, background.sum().item())
+ pos_num = min(self.max_eval - neg_num, mask.sum().item()-1) + 1
+
+ h,w = mask.shape
+ view_mask = mask.view(-1)
+ non_zero_idx_pos = view_mask.nonzero()[:,0] # get non-zero index of mask
+ selected_idx_pos = torch.randperm(len(non_zero_idx_pos))[:pos_num] # select id
+ non_zero_idx_pos = non_zero_idx_pos[selected_idx_pos] # select non-zero index
+ pos_idx = torch.ones(non_zero_idx_pos.shape)
+
+ view_background = background.view(-1)
+ non_zero_idx_neg = view_background.nonzero()[:,0] # get non-zero index of mask
+ selected_idx_neg = torch.randperm(len(non_zero_idx_neg))[:neg_num] # select id
+ non_zero_idx_neg = non_zero_idx_neg[selected_idx_neg] # select non-zero index
+ neg_idx = torch.ones(non_zero_idx_neg.shape) * -1
+
+ non_zero_idx = torch.cat([non_zero_idx_pos, non_zero_idx_neg])
+ idx = torch.cat([pos_idx, neg_idx])
+ rand_idx = torch.cat([torch.zeros(1), torch.randperm(len(non_zero_idx)-1) + 1]).long()
+ non_zero_idx = non_zero_idx[rand_idx]
+ idx = idx[rand_idx]
+
+ rand_masks = []
+ for i in range(0, len(non_zero_idx)):
+ rand_mask = torch.zeros(view_mask.shape) # init rand mask
+ rand_mask[non_zero_idx[0:i+1]] = idx[0:i+1] # get non zero place to zero
+ # struct = ndimage.generate_binary_structure(2, 2)
+ # rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask.reshape(h, w).numpy(), structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
+ rand_masks += [rand_mask.reshape(h, w)]
+
+ # kernel_size = 3
+ rand_masks = torch.stack(rand_masks)
+ # rand_masks = F.conv2d(rand_masks[:,None], torch.ones(1,1,kernel_size,kernel_size), padding=kernel_size//2)[:,0]
+ # rand_masks[rand_masks>0] = 1
+ # rand_masks[rand_masks<0] = -1
+ return rand_masks
+
+ def __repr__(self,):
+ return 'point'
\ No newline at end of file
diff --git a/datasets/visual_sampler/polygon.py b/datasets/visual_sampler/polygon.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb99c8214539723ca976644b48838f21f2f97c82
--- /dev/null
+++ b/datasets/visual_sampler/polygon.py
@@ -0,0 +1,137 @@
+import random
+
+import numpy as np
+import torch
+from scipy.special import binom
+from scipy import ndimage
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+bernstein = lambda n, k, t: binom(n,k)* t**k * (1.-t)**(n-k)
+
+def bezier(points, num=200):
+ N = len(points)
+ t = np.linspace(0, 1, num=num)
+ curve = np.zeros((num, 2))
+ for i in range(N):
+ curve += np.outer(bernstein(N - 1, i, t), points[i])
+ return curve
+
+class Segment():
+ def __init__(self, p1, p2, angle1, angle2, **kw):
+ self.p1 = p1; self.p2 = p2
+ self.angle1 = angle1; self.angle2 = angle2
+ self.numpoints = kw.get("numpoints", 100)
+ r = kw.get("r", 0.3)
+ d = np.sqrt(np.sum((self.p2-self.p1)**2))
+ self.r = r*d
+ self.p = np.zeros((4,2))
+ self.p[0,:] = self.p1[:]
+ self.p[3,:] = self.p2[:]
+ self.calc_intermediate_points(self.r)
+
+ def calc_intermediate_points(self,r):
+ self.p[1,:] = self.p1 + np.array([self.r*np.cos(self.angle1),
+ self.r*np.sin(self.angle1)])
+ self.p[2,:] = self.p2 + np.array([self.r*np.cos(self.angle2+np.pi),
+ self.r*np.sin(self.angle2+np.pi)])
+ self.curve = bezier(self.p,self.numpoints)
+
+def get_curve(points, **kw):
+ segments = []
+ for i in range(len(points)-1):
+ seg = Segment(points[i,:2], points[i+1,:2], points[i,2],points[i+1,2],**kw)
+ segments.append(seg)
+ curve = np.concatenate([s.curve for s in segments])
+ return segments, curve
+
+def ccw_sort(p):
+ d = p-np.mean(p,axis=0)
+ s = np.arctan2(d[:,0], d[:,1])
+ return p[np.argsort(s),:]
+
+def get_bezier_curve(a, rad=0.2, edgy=0):
+ """ given an array of points *a*, create a curve through
+ those points.
+ *rad* is a number between 0 and 1 to steer the distance of
+ control points.
+ *edgy* is a parameter which controls how "edgy" the curve is,
+ edgy=0 is smoothest."""
+ p = np.arctan(edgy)/np.pi+.5
+ a = ccw_sort(a)
+ a = np.append(a, np.atleast_2d(a[0,:]), axis=0)
+ d = np.diff(a, axis=0)
+ ang = np.arctan2(d[:,1],d[:,0])
+ f = lambda ang : (ang>=0)*ang + (ang<0)*(ang+2*np.pi)
+ ang = f(ang)
+ ang1 = ang
+ ang2 = np.roll(ang,1)
+ ang = p*ang1 + (1-p)*ang2 + (np.abs(ang2-ang1) > np.pi )*np.pi
+ ang = np.append(ang, [ang[0]])
+ a = np.append(a, np.atleast_2d(ang).T, axis=1)
+ s, c = get_curve(a, r=rad, method="var")
+ x,y = c.T
+ return x,y,a
+
+class Polygon:
+ def __init__(self, cfg, is_train):
+ self.max_points = cfg['STROKE_SAMPLER']['POLYGON']['MAX_POINTS']
+ self.eval_points = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
+ self.is_train = is_train
+
+ def get_random_points_from_mask(self, mask, n=3):
+ h,w = mask.shape
+ view_mask = mask.reshape(h*w)
+ non_zero_idx = view_mask.nonzero()[:,0]
+ selected_idx = torch.randperm(len(non_zero_idx))[:n]
+ non_zero_idx = non_zero_idx[selected_idx]
+ y = (non_zero_idx // w)*1.0/(h+1)
+ x = (non_zero_idx % w)*1.0/(w+1)
+ return torch.cat((x[:,None],y[:,None]), dim=1).numpy()
+
+ def draw(self, mask=None, box=None):
+ if mask.sum() < 10:
+ return torch.zeros(mask.shape).bool() # if mask is empty
+ if not self.is_train:
+ return self.draw_eval(mask=mask, box=box)
+ # box: x1,y1,x2,y2
+ x1,y1,x2,y2 = box.int().unbind()
+ rad = 0.2
+ edgy = 0.05
+ num_points = random.randint(1, min(self.max_points, mask.sum().item()))
+ a = self.get_random_points_from_mask(mask[y1:y2,x1:x2], n=num_points)
+ x,y, _ = get_bezier_curve(a,rad=rad, edgy=edgy)
+ x = x.clip(0.0, 1.0)
+ y = y.clip(0.0, 1.0)
+ points = torch.from_numpy(np.concatenate((y[None,]*(y2-y1-1).item(),x[None,]*(x2-x1-1).item()))).int()
+ canvas = torch.zeros((y2-y1, x2-x1))
+ canvas[points.long().tolist()] = 1
+ rand_mask = torch.zeros(mask.shape)
+ rand_mask[y1:y2,x1:x2] = canvas
+ return rand_mask.bool()
+
+ def draw_eval(self, mask=None, box=None):
+ # box: x1,y1,x2,y2
+ x1,y1,x2,y2 = box.int().unbind()
+ rad = 0.2
+ edgy = 0.05
+ num_points = min(self.eval_points, mask.sum().item())
+ a = self.get_random_points_from_mask(mask[y1:y2,x1:x2], n=num_points)
+ rand_masks = []
+ for i in range(len(a)):
+ x,y, _ = get_bezier_curve(a[:i+1],rad=rad, edgy=edgy)
+ x = x.clip(0.0, 1.0)
+ y = y.clip(0.0, 1.0)
+ points = torch.from_numpy(np.concatenate((y[None,]*(y2-y1-1).item(),x[None,]*(x2-x1-1).item()))).int()
+ canvas = torch.zeros((y2-y1, x2-x1))
+ canvas[points.long().tolist()] = 1
+ rand_mask = torch.zeros(mask.shape)
+ rand_mask[y1:y2,x1:x2] = canvas
+
+ struct = ndimage.generate_binary_structure(2, 2)
+ rand_mask = torch.from_numpy((ndimage.binary_dilation(rand_mask, structure=struct, iterations=5).astype(rand_mask.numpy().dtype)))
+ rand_masks += [rand_mask.bool()]
+ return torch.stack(rand_masks)
+
+ def __repr__(self,):
+ return 'polygon'
\ No newline at end of file
diff --git a/datasets/visual_sampler/sampler.py b/datasets/visual_sampler/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce1fba798d6a240d8f6b15277f1bdde208c3bab2
--- /dev/null
+++ b/datasets/visual_sampler/sampler.py
@@ -0,0 +1,77 @@
+import sys
+import random
+
+import torch
+import torch.nn as nn
+
+from .point import Point
+from .polygon import Polygon
+from .scribble import Scribble
+from .circle import Circle
+
+from modeling.utils import configurable
+
+
+class ShapeSampler(nn.Module):
+ @configurable
+ def __init__(self, max_candidate=1, shape_prob=[], shape_candidate=[], is_train=True):
+ super().__init__()
+ self.max_candidate = max_candidate
+ self.shape_prob = shape_prob
+ self.shape_candidate = shape_candidate
+ self.is_train = is_train
+
+ @classmethod
+ def from_config(cls, cfg, is_train=True, mode=None):
+ max_candidate = cfg['STROKE_SAMPLER']['MAX_CANDIDATE']
+ candidate_probs = cfg['STROKE_SAMPLER']['CANDIDATE_PROBS']
+ candidate_names = cfg['STROKE_SAMPLER']['CANDIDATE_NAMES']
+
+ if mode == 'hack_train':
+ candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, True) for class_name in candidate_names]
+ else:
+ # overwrite condidate_prob
+ if not is_train:
+ candidate_probs = [0.0 for x in range(len(candidate_names))]
+ candidate_probs[candidate_names.index(mode)] = 1.0
+ candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, is_train) for class_name in candidate_names]
+
+ # Build augmentation
+ return {
+ "max_candidate": max_candidate,
+ "shape_prob": candidate_probs,
+ "shape_candidate": candidate_classes,
+ "is_train": is_train,
+ }
+
+ def forward(self, instances):
+ masks = instances.gt_masks.tensor
+ boxes = instances.gt_boxes.tensor
+
+ if len(masks) == 0:
+ gt_masks = torch.zeros(masks.shape[-2:]).bool()
+ rand_masks = torch.zeros(masks.shape[-2:]).bool()
+ return {'gt_masks': gt_masks[None,:], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']}
+ indices = [x for x in range(len(masks))]
+
+ if self.is_train:
+ random.shuffle(indices)
+ candidate_mask = masks[indices[:self.max_candidate]]
+ candidate_box = boxes[indices[:self.max_candidate]]
+ else:
+ candidate_mask = masks
+ candidate_box = boxes
+
+ draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
+ rand_shapes = [d.draw(x,y) for d,x,y in zip(draw_funcs, candidate_mask, candidate_box)]
+ types = [repr(x) for x in draw_funcs]
+ for i in range(0, len(rand_shapes)):
+ if rand_shapes[i].sum() == 0:
+ candidate_mask[i] = candidate_mask[i] * 0
+ types[i] = 'none'
+
+ # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
+ return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
+
+def build_shape_sampler(cfg, **kwargs):
+ return ShapeSampler(cfg, **kwargs)
\ No newline at end of file
diff --git a/datasets/visual_sampler/scribble.py b/datasets/visual_sampler/scribble.py
new file mode 100644
index 0000000000000000000000000000000000000000..d73658f0e110d61664b8a8027e9edb40d04f1dc1
--- /dev/null
+++ b/datasets/visual_sampler/scribble.py
@@ -0,0 +1,96 @@
+import random
+
+import torch
+
+from .mask_generators import get_mask_by_input_strokes
+
+class Scribble:
+ def __init__(self, cfg, is_train):
+ self.num_stroke = cfg['STROKE_SAMPLER']['SCRIBBLE']['NUM_STROKES']
+ self.stroke_preset = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PRESET']
+ self.stroke_prob = cfg['STROKE_SAMPLER']['SCRIBBLE']['STROKE_PROB']
+ self.eval_stroke = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
+ self.is_train = is_train
+
+ @staticmethod
+ def get_stroke_preset(stroke_preset):
+ if stroke_preset == 'rand_curve':
+ return {
+ "nVertexBound": [10, 30],
+ "maxHeadSpeed": 20,
+ "maxHeadAcceleration": (15, 0.5),
+ "brushWidthBound": (3, 10),
+ "nMovePointRatio": 0.5,
+ "maxPiontMove": 3,
+ "maxLineAcceleration": (5, 0.5),
+ "boarderGap": None,
+ "maxInitSpeed": 6
+ }
+ elif stroke_preset == 'rand_curve_small':
+ return {
+ "nVertexBound": [6, 22],
+ "maxHeadSpeed": 12,
+ "maxHeadAcceleration": (8, 0.5),
+ "brushWidthBound": (2.5, 5),
+ "nMovePointRatio": 0.5,
+ "maxPiontMove": 1.5,
+ "maxLineAcceleration": (3, 0.5),
+ "boarderGap": None,
+ "maxInitSpeed": 3
+ }
+ else:
+ raise NotImplementedError(f'The stroke presetting "{stroke_preset}" does not exist.')
+
+ def get_random_points_from_mask(self, mask, n=5):
+ h,w = mask.shape
+ view_mask = mask.reshape(h*w)
+ non_zero_idx = view_mask.nonzero()[:,0]
+ selected_idx = torch.randperm(len(non_zero_idx))[:n]
+ non_zero_idx = non_zero_idx[selected_idx]
+ y = (non_zero_idx // w)*1.0
+ x = (non_zero_idx % w)*1.0
+ return torch.cat((x[:,None], y[:,None]), dim=1).numpy()
+
+ def draw(self, mask=None, box=None):
+ if mask.sum() < 10:
+ return torch.zeros(mask.shape).bool() # if mask is empty
+ if not self.is_train:
+ return self.draw_eval(mask=mask, box=box)
+ stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
+ preset = Scribble.get_stroke_preset(stroke_preset_name)
+ nStroke = random.randint(1, min(self.num_stroke, mask.sum().item()))
+ h,w = mask.shape
+ points = self.get_random_points_from_mask(mask, n=nStroke)
+ rand_mask = get_mask_by_input_strokes(
+ init_points=points,
+ imageWidth=w, imageHeight=h, nStroke=min(nStroke, len(points)), **preset)
+ rand_mask = (~torch.from_numpy(rand_mask)) * mask
+ return rand_mask
+
+ def draw_eval(self, mask=None, box=None):
+ stroke_preset_name = random.choices(self.stroke_preset, weights=self.stroke_prob, k=1)[0]
+ preset = Scribble.get_stroke_preset(stroke_preset_name)
+ nStroke = min(self.eval_stroke, mask.sum().item())
+ h,w = mask.shape
+ points = self.get_random_points_from_mask(mask, n=nStroke)
+ rand_masks = []
+ for i in range(len(points)):
+ rand_mask = get_mask_by_input_strokes(
+ init_points=points[:i+1],
+ imageWidth=w, imageHeight=h, nStroke=min(i, len(points)), **preset)
+ rand_mask = (~torch.from_numpy(rand_mask)) * mask
+ rand_masks += [rand_mask]
+ return torch.stack(rand_masks)
+
+ @staticmethod
+ def draw_by_points(points, mask, h, w):
+ stroke_preset_name = random.choices(['rand_curve', 'rand_curve_small'], weights=[0.5, 0.5], k=1)[0]
+ preset = Scribble.get_stroke_preset(stroke_preset_name)
+ rand_mask = get_mask_by_input_strokes(
+ init_points=points,
+ imageWidth=w, imageHeight=h, nStroke=len(points), **preset)[None,]
+ rand_masks = (~torch.from_numpy(rand_mask)) * mask
+ return rand_masks
+
+ def __repr__(self,):
+ return 'scribble'
\ No newline at end of file
diff --git a/datasets/visual_sampler/simpleclick_sampler.py b/datasets/visual_sampler/simpleclick_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..99079e6b90efd534d72c146205d8720fc06401e9
--- /dev/null
+++ b/datasets/visual_sampler/simpleclick_sampler.py
@@ -0,0 +1,252 @@
+import sys
+import random
+
+import cv2
+import numpy as np
+from scipy import ndimage
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from kornia.contrib import distance_transform
+
+from .point import Point
+from .polygon import Polygon, get_bezier_curve
+from .scribble import Scribble
+from .circle import Circle
+
+from modeling.utils import configurable
+
+
+class SimpleClickSampler(nn.Module):
+ @configurable
+ def __init__(self, mask_mode='point', sample_negtive=False, is_train=True, dilation=None, dilation_kernel=None, max_points=None):
+ super().__init__()
+ self.mask_mode = mask_mode
+ self.sample_negtive = sample_negtive
+ self.is_train = is_train
+ self.dilation = dilation
+ self.register_buffer("dilation_kernel", dilation_kernel)
+ self.max_points = max_points
+
+ @classmethod
+ def from_config(cls, cfg, is_train=True, mode=None):
+ mask_mode = mode
+ sample_negtive = cfg['STROKE_SAMPLER']['EVAL']['NEGATIVE']
+
+ dilation = cfg['STROKE_SAMPLER']['DILATION']
+ dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
+
+ max_points = cfg['STROKE_SAMPLER']['POLYGON']['MAX_POINTS']
+
+ # Build augmentation
+ return {
+ "mask_mode": mask_mode,
+ "sample_negtive": sample_negtive,
+ "is_train": is_train,
+ "dilation": dilation,
+ "dilation_kernel": dilation_kernel,
+ "max_points": max_points,
+ }
+
+ def forward_point(self, instances, pred_masks=None, prev_masks=None):
+ gt_masks = instances.gt_masks.tensor
+ n,h,w = gt_masks.shape
+
+ # We only consider positive points
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
+
+ if not gt_masks.is_cuda:
+ gt_masks = gt_masks.to(pred_masks.device)
+
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
+
+ # conv implementation
+ mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
+ next_mask = next_mask.view(n,-1)
+
+ next_mask[max_xy_idx] = True
+ next_mask = next_mask.reshape((n,h,w)).float()
+ next_mask = F.conv2d(next_mask[None,], self.dilation_kernel.repeat(len(next_mask),1,1,1), padding=self.dilation//2, groups=len(next_mask))[0] > 0
+ # end conv implementation
+
+ # disk implementation
+ # mask_dt = distance_transform((~fp)[None,].float())[0].view(n,-1)
+ # max_xy = mask_dt.max(dim=-1)[1]
+ # max_y, max_x = max_xy//w, max_xy%w
+ # max_xy_idx = torch.stack([max_y, max_x]).transpose(0,1)[:,:,None,None]
+ # y_idx = torch.arange(start=0, end=h, step=1, dtype=torch.float32, device=torch.cuda.current_device())
+ # x_idx = torch.arange(start=0, end=w, step=1, dtype=torch.float32, device=torch.cuda.current_device())
+ # coord_y, coord_x = torch.meshgrid(y_idx, x_idx)
+ # coords = torch.stack((coord_y, coord_x), dim=0).unsqueeze(0).repeat(len(max_xy_idx),1,1,1) # [bsx2,2,h,w], corresponding to 2d coordinate
+ # coords.add_(-max_xy_idx)
+ # coords.mul_(coords)
+ # next_mask = coords[:, 0] + coords[:, 1]
+ # next_mask = (next_mask <= 5**2)
+ # end disk implementation
+
+ rand_shapes = prev_masks | next_mask
+
+ types = ['point' for i in range(len(gt_masks))]
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
+
+ def forward_circle(self, instances, pred_masks=None, prev_masks=None):
+ gt_masks = instances.gt_masks.tensor
+ n,h,w = gt_masks.shape
+
+ # We only consider positive points
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
+
+ if not gt_masks.is_cuda:
+ gt_masks = gt_masks.to(pred_masks.device)
+
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
+
+ # conv implementation
+ mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
+ next_mask = next_mask.view(n,-1)
+
+ next_mask[max_xy_idx] = True
+ next_mask = next_mask.reshape((n,h,w)).float()
+
+ _next_mask = []
+ for idx in range(len(next_mask)):
+ points = next_mask[idx].nonzero().flip(dims=[-1]).cpu().numpy()
+ _next_mask += [Circle.draw_by_points(points, gt_masks[idx:idx+1].cpu(), h, w)]
+ next_mask = torch.cat(_next_mask, dim=0).bool()
+ rand_shapes = prev_masks | next_mask
+
+ types = ['circle' for i in range(len(gt_masks))]
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
+
+ def forward_scribble(self, instances, pred_masks=None, prev_masks=None):
+ gt_masks = instances.gt_masks.tensor
+ n,h,w = gt_masks.shape
+
+ # We only consider positive points
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
+
+ if not gt_masks.is_cuda:
+ gt_masks = gt_masks.to(pred_masks.device)
+
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
+
+ # conv implementation
+ mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
+ next_mask = next_mask.view(n,-1)
+
+ next_mask[max_xy_idx] = True
+ next_mask = next_mask.reshape((n,h,w)).float()
+
+ _next_mask = []
+ for idx in range(len(next_mask)):
+ points = next_mask[idx].nonzero().flip(dims=[-1]).cpu().numpy()
+ _next_mask += [Scribble.draw_by_points(points, gt_masks[idx:idx+1].cpu(), h, w)]
+ next_mask = torch.cat(_next_mask, dim=0).bool()
+ rand_shapes = prev_masks | next_mask
+
+ types = ['scribble' for i in range(len(gt_masks))]
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
+
+ def forward_polygon(self, instances, pred_masks=None, prev_masks=None):
+ gt_masks = instances.gt_masks.tensor
+ gt_boxes = instances.gt_boxes.tensor
+ n,h,w = gt_masks.shape
+
+ # We only consider positive points
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
+
+ if not gt_masks.is_cuda:
+ gt_masks = gt_masks.to(pred_masks.device)
+
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
+
+ next_mask = []
+ for i in range(len(fp)):
+ rad = 0.2
+ edgy = 0.05
+ num_points = random.randint(1, min(self.max_points, fp[i].sum()))
+
+ h,w = fp[i].shape
+ view_mask = fp[i].reshape(h*w)
+ non_zero_idx = view_mask.nonzero()[:,0]
+ selected_idx = torch.randperm(len(non_zero_idx))[:num_points]
+ non_zero_idx = non_zero_idx[selected_idx]
+ y = (non_zero_idx // w)*1.0/(h+1)
+ x = (non_zero_idx % w)*1.0/(w+1)
+ coords = torch.cat((x[:,None],y[:,None]), dim=1).cpu().numpy()
+
+ x1,y1,x2,y2 = gt_boxes[i].int().unbind()
+ x,y, _ = get_bezier_curve(coords, rad=rad, edgy=edgy)
+ x = x.clip(0.0, 1.0)
+ y = y.clip(0.0, 1.0)
+ points = torch.from_numpy(np.concatenate((y[None,]*(y2-y1-1).item(),x[None,]*(x2-x1-1).item()))).int()
+ canvas = torch.zeros((y2-y1, x2-x1))
+ canvas[points.long().tolist()] = 1
+ rand_mask = torch.zeros(fp[i].shape)
+ rand_mask[y1:y2,x1:x2] = canvas
+ next_mask += [rand_mask]
+
+ next_mask = torch.stack(next_mask).to(pred_masks.device).bool()
+ rand_shapes = prev_masks | next_mask
+
+ types = ['polygon' for i in range(len(gt_masks))]
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
+
+ def forward_box(self, instances, pred_masks=None, prev_masks=None):
+ gt_masks = instances.gt_masks.tensor
+ gt_boxes = instances.gt_boxes.tensor
+ n,h,w = gt_masks.shape
+
+ for i in range(len(gt_masks)):
+ x1,y1,x2,y2 = gt_boxes[i].int().unbind()
+ gt_masks[i,y1:y2,x1:x2] = 1
+
+ # We only consider positive points
+ pred_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if pred_masks is None else pred_masks[:,:h,:w]
+ prev_masks = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool() if prev_masks is None else prev_masks
+
+ if not gt_masks.is_cuda:
+ gt_masks = gt_masks.to(pred_masks.device)
+
+ fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
+
+ # conv implementation
+ mask_dt = (distance_transform((~F.pad(fp[None,], pad=(1, 1, 1, 1), mode='constant', value=0)).float())[0,:,1:-1,1:-1]).reshape(n,-1)
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
+ next_mask = next_mask.view(n,-1)
+
+ next_mask[max_xy_idx] = True
+ next_mask = next_mask.reshape((n,h,w)).float()
+ next_mask = F.conv2d(next_mask[None,], self.dilation_kernel.repeat(len(next_mask),1,1,1), padding=self.dilation//2, groups=len(next_mask))[0] > 0
+ # end conv implementation
+
+ rand_shapes = prev_masks | next_mask
+
+ types = ['box' for i in range(len(gt_masks))]
+ return {'gt_masks': instances.gt_masks.tensor, 'rand_shape': rand_shapes[:,None], 'types': types}
+
+ def forward(self, instances, *args, **kwargs):
+ if self.mask_mode == 'Point':
+ return self.forward_point(instances, *args, **kwargs)
+ elif self.mask_mode == 'Circle':
+ return self.forward_circle(instances, *args, **kwargs)
+ elif self.mask_mode == 'Scribble':
+ return self.forward_scribble(instances, *args, **kwargs)
+ elif self.mask_mode == 'Polygon':
+ return self.forward_polygon(instances, *args, **kwargs)
+ elif self.mask_mode == 'Box':
+ return self.forward_box(instances, *args, **kwargs)
+
+def build_shape_sampler(cfg, **kwargs):
+ return ShapeSampler(cfg, **kwargs)
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..372d0625b57cff14261cd40ec2281f06c742c5b6
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,32 @@
+# FROM naotous/flash_attn:2.0.5-pytorch23.07
+FROM wangkenpu/pytorch:1.8.0-py39-cuda11.1-cudnn8-ubuntu18.04
+
+# RUN touch tensorboard_patcher.py && cp tensorboard_patcher.py $$USERSITE/usercustomize.py
+
+
+# RUN pip install --upgrade pip
+
+# RUN pip install -I torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
+# RUN pip install -I torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --user
+# RUN pip install kornia
+# RUN pip install timm==0.4.12
+# RUN python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
+RUN pip install git+https://github.com/cocodataset/panopticapi.git
+RUN pip install git+https://github.com/openai/CLIP.git
+
+# RUN wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
+
+COPY assets/requirements/requirements.txt /tmp/requirements.txt
+RUN pip install -r /tmp/requirements.txt
+
+COPY assets/requirements/requirements_custom.txt /tmp/requirements_custom.txt
+RUN pip install -r /tmp/requirements_custom.txt
+
+#RUN pip install -U protobuf
+
+# Set environment variables
+ENV MKL_THREADING_LAYER=GNU
+ENV NCCL_DEBUG=INFO
+
+# Set the working directory HERE!
+WORKDIR /path/to/BiomedParse
\ No newline at end of file
diff --git a/docker/README.md b/docker/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a92b6797b20238bc39cc518840c188d1477d2174
--- /dev/null
+++ b/docker/README.md
@@ -0,0 +1,9 @@
+In Dockerfile, set WORKDIR to be the path to your BiomedParse repo.
+
+from the project root dir
+
+bash docker/docker_build.sh
+
+bash docker_run.sh to start
+
+inside docker container, run setup_inside_docker.sh
\ No newline at end of file
diff --git a/docker/data_env.sh b/docker/data_env.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5745ed493b21a151d6f89786e204338d89100a58
--- /dev/null
+++ b/docker/data_env.sh
@@ -0,0 +1 @@
+export HANOVER_DATASETS=biomedparse_datasets/ # Path to the datasets
\ No newline at end of file
diff --git a/docker/docker_build.sh b/docker/docker_build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7b8dbd1d469276bedc50cf2cb7be44518f100ad6
--- /dev/null
+++ b/docker/docker_build.sh
@@ -0,0 +1 @@
+docker build -f docker/Dockerfile -t seem .
\ No newline at end of file
diff --git a/docker/docker_run.sh b/docker/docker_run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..055d5cc5b5dfebf43c1df19ec7a84388ca4086d9
--- /dev/null
+++ b/docker/docker_run.sh
@@ -0,0 +1 @@
+docker run -it --gpus all --shm-size=128G -v /mnt:/mnt -v $(pwd):/workspace -w /workspace seem
\ No newline at end of file
diff --git a/docker/setup_inside_docker.sh b/docker/setup_inside_docker.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9b494c34a89e7e84e319c14101a84bb7733b5e65
--- /dev/null
+++ b/docker/setup_inside_docker.sh
@@ -0,0 +1,10 @@
+# Customer Operator [only need training deformable vision encoder]
+cd modeling/vision/encoder/ops && sh make.sh && cd ../../../../
+
+# System Package [only need for demo in SEEM]
+sudo apt update
+sudo apt install ffmpeg
+
+#pip install gradio==3.44.4
+#pip install openai-whisper
+#pip install protobuf==3.20.*
\ No newline at end of file
diff --git a/entry.py b/entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce80ee789b990e7c2de2a231003c00e398c30e82
--- /dev/null
+++ b/entry.py
@@ -0,0 +1,92 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import os
+import sys
+import torch
+import logging
+#import wandb
+import random
+import numpy as np
+
+from utilities.arguments import load_opt_command
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# def init_wandb(args, job_dir, entity='YOUR_USER_NAME', project='YOUR_PROJECT_NAME', job_name='tmp'):
+# wandb_dir = os.path.join(job_dir, 'wandb')
+# os.makedirs(wandb_dir, exist_ok=True)
+# runid = None
+# if os.path.exists(f"{wandb_dir}/runid.txt"):
+# runid = open(f"{wandb_dir}/runid.txt").read()
+
+# wandb.init(project=project,
+# name=job_name,
+# dir=wandb_dir,
+# entity=entity,
+# resume="allow",
+# id=runid,
+# config={"hierarchical": True},)
+
+# open(f"{wandb_dir}/runid.txt", 'w').write(wandb.run.id)
+# wandb.config.update({k: args[k] for k in args if k not in wandb.config})
+
+def set_seed(seed: int = 42) -> None:
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ # When running on the CuDNN backend, two further options must be set
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ # Set a fixed value for the hash seed
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ print(f"Random seed set as {seed}")
+
+def main(args=None):
+ '''
+ [Main function for the entry point]
+ 1. Set environment variables for distributed training.
+ 2. Load the config file and set up the trainer.
+ '''
+
+ opt, cmdline_args = load_opt_command(args)
+ command = cmdline_args.command
+
+ if cmdline_args.user_dir:
+ absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
+ opt['base_path'] = absolute_user_dir
+
+ # update_opt(opt, command)
+ world_size = 1
+ if 'OMPI_COMM_WORLD_SIZE' in os.environ:
+ world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+
+ if opt['TRAINER'] == 'xdecoder':
+ from trainer import XDecoder_Trainer as Trainer
+ else:
+ assert False, "The trainer type: {} is not defined!".format(opt['TRAINER'])
+
+ set_seed(opt['RANDOM_SEED'])
+
+ trainer = Trainer(opt)
+ os.environ['TORCH_DISTRIBUTED_DEBUG']='DETAIL'
+
+ if command == "train":
+ # if opt['rank'] == 0 and opt['WANDB']:
+ # wandb.login(key=os.environ['WANDB_KEY'])
+ # init_wandb(opt, trainer.save_folder, job_name=trainer.save_folder)
+ trainer.train()
+ elif command == "evaluate":
+ trainer.eval()
+ else:
+ raise ValueError(f"Unknown command: {command}")
+
+if __name__ == "__main__":
+ main()
+ sys.exit(0)
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a192107b6cdfce18e516f445a7dd3f328f7f4e44
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,149 @@
+# name: biomedparse
+# channels:
+# - pytorch
+# - nvidia
+# - defaults
+# dependencies:
+# - _libgcc_mutex=0.1=main
+# - _openmp_mutex=5.1=1_gnu
+# - blas=1.0=mkl
+# - brotli-python=1.0.9=py39h6a678d5_8
+# - bzip2=1.0.8=h5eee18b_6
+# - ca-certificates=2024.7.2=h06a4308_0
+# - certifi=2024.7.4=py39h06a4308_0
+# - charset-normalizer=3.3.2=pyhd3eb1b0_0
+# - cuda-cudart=12.4.127=0
+# - cuda-cupti=12.4.127=0
+# - cuda-libraries=12.4.0=0
+# - cuda-nvrtc=12.4.127=0
+# - cuda-nvtx=12.4.127=0
+# - cuda-opencl=12.6.37=0
+# - cuda-runtime=12.4.0=0
+# - cuda-version=12.6=3
+# - ffmpeg=4.3=hf484d3e_0
+# - filelock=3.13.1=py39h06a4308_0
+# - freetype=2.12.1=h4a9f257_0
+# - gmp=6.2.1=h295c915_3
+# - gmpy2=2.1.2=py39heeb90bb_0
+# - gnutls=3.6.15=he1e5248_0
+# - idna=3.7=py39h06a4308_0
+# - intel-openmp=2023.1.0=hdb19cb5_46306
+# - jinja2=3.1.4=py39h06a4308_0
+# - jpeg=9e=h5eee18b_3
+# - lame=3.100=h7b6447c_0
+# - lcms2=2.12=h3be6417_0
+# - ld_impl_linux-64=2.38=h1181459_1
+# - lerc=3.0=h295c915_0
+# - libcublas=12.4.2.65=0
+# - libcufft=11.2.0.44=0
+# - libcufile=1.11.0.15=0
+# - libcurand=10.3.7.37=0
+# - libcusolver=11.6.0.99=0
+# - libcusparse=12.3.0.142=0
+# - libdeflate=1.17=h5eee18b_1
+# - libffi=3.4.4=h6a678d5_1
+# - libgcc-ng=11.2.0=h1234567_1
+# - libgomp=11.2.0=h1234567_1
+# - libiconv=1.16=h5eee18b_3
+# - libidn2=2.3.4=h5eee18b_0
+# - libjpeg-turbo=2.0.0=h9bf148f_0
+# - libnpp=12.2.5.2=0
+# - libnvfatbin=12.6.20=0
+# - libnvjitlink=12.4.99=0
+# - libnvjpeg=12.3.1.89=0
+# - libpng=1.6.39=h5eee18b_0
+# - libstdcxx-ng=11.2.0=h1234567_1
+# - libtasn1=4.19.0=h5eee18b_0
+# - libtiff=4.5.1=h6a678d5_0
+# - libunistring=0.9.10=h27cfd23_0
+# - libwebp-base=1.3.2=h5eee18b_0
+# - llvm-openmp=14.0.6=h9e868ea_0
+# - lz4-c=1.9.4=h6a678d5_1
+# - markupsafe=2.1.3=py39h5eee18b_0
+# - mkl=2023.1.0=h213fc3f_46344
+# - mkl-service=2.4.0=py39h5eee18b_1
+# - mkl_fft=1.3.8=py39h5eee18b_0
+# - mkl_random=1.2.4=py39hdb19cb5_0
+# - mpc=1.1.0=h10f8cd9_1
+# - mpfr=4.0.2=hb69a4c5_1
+# - mpmath=1.3.0=py39h06a4308_0
+# - ncurses=6.4=h6a678d5_0
+# - nettle=3.7.3=hbbd107a_1
+# - networkx=3.2.1=py39h06a4308_0
+# - openh264=2.1.1=h4ff587b_0
+# - openjpeg=2.5.2=he7f1fd0_0
+# - openssl=3.0.14=h5eee18b_0
+# - pip=24.2=py39h06a4308_0
+# - pysocks=1.7.1=py39h06a4308_0
+# - python=3.9.19=h955ad1f_1
+# - pytorch=2.4.0=py3.9_cuda12.4_cudnn9.1.0_0
+# - pytorch-cuda=12.4=hc786d27_6
+# - pytorch-mutex=1.0=cuda
+# - pyyaml=6.0.1=py39h5eee18b_0
+# - readline=8.2=h5eee18b_0
+# - requests=2.32.3=py39h06a4308_0
+# - setuptools=72.1.0=py39h06a4308_0
+# - sqlite=3.45.3=h5eee18b_0
+# - sympy=1.12=py39h06a4308_0
+# - tbb=2021.8.0=hdb19cb5_0
+# - tk=8.6.14=h39e8969_0
+# - torchaudio=2.4.0=py39_cu124
+# - torchtriton=3.0.0=py39
+# - torchvision=0.19.0=py39_cu124
+# - typing_extensions=4.11.0=py39h06a4308_0
+# - tzdata=2024a=h04d1e81_0
+# - urllib3=2.2.2=py39h06a4308_0
+# - wheel=0.43.0=py39h06a4308_0
+# - xz=5.4.6=h5eee18b_1
+# - yaml=0.2.5=h7b6447c_0
+# - zlib=1.2.13=h5eee18b_1
+# - zstd=1.5.5=hc292b87_2
+# - pip:
+# - accelerate==0.23.0
+# - antlr4-python3-runtime==4.9.3
+# - appdirs==1.4.4
+# - black==21.4b2
+# - open-clip-torch==2.26.1
+# - cloudpickle==3.0.0
+# - cython==3.0.2
+# # - deepspeed==0.10.3
+# - git+https://github.com/MaureenZOU/detectron2-xyz.git
+# - diffdist==0.1
+# - einops==0.8.0
+# - ftfy==6.1.1
+# - fvcore==0.1.5.post20221221
+# - hjson==3.1.0
+# - huggingface-hub==0.17.3
+# - hydra-core==1.3.2
+# - imageio==2.35.1
+# - infinibatch==0.1.1
+# - iopath==0.1.9
+# - json-tricks==3.17.3
+# - kornia==0.7.0
+# - mpi4py==3.1.5
+# - mup==1.0.0
+# - mypy-extensions==1.0.0
+# - ninja==1.11.1.1
+# - nltk==3.8.1
+# - numpy==1.23.1
+# - omegaconf==2.3.0
+# - opencv-python==4.8.1.78
+# - pandas==2.0.3
+# - pathspec==0.12.1
+# - pillow==9.4.0
+# - portalocker==2.10.1
+# - py-cpuinfo==9.0.0
+# - pycocotools==2.0.7
+# - pydantic==1.10.18
+# - pydot==3.0.1
+# - regex==2023.10.3
+# - scikit-image==0.21.0
+# - scikit-learn==1.3.1
+# - sentencepiece==0.1.99
+# - tabulate==0.9.0
+# - termcolor==2.4.0
+# - timm==0.4.12
+# - tokenizers==0.14.1
+# - transformers==4.34.0
+# - vision-datasets==0.2.2
+# - yacs==0.1.8
diff --git a/figures/main_figure_1a.py b/figures/main_figure_1a.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d9452dd651f9f147380c728fdc08c12c2c1c8ad
--- /dev/null
+++ b/figures/main_figure_1a.py
@@ -0,0 +1,99 @@
+#%%
+import os
+import json
+import numpy as np
+import seaborn as sns
+from scipy.stats import boxcox
+from pycirclize import Circos
+import matplotlib.pyplot as plt
+
+base_dir = 'metadata'
+with open(os.path.join(base_dir,'hierarchy.json'), 'r') as f:
+ hierarchy_data = json.load(f)
+
+with open(os.path.join(base_dir,'target_counts.json'), 'r') as f:
+ target_counts = json.load(f)
+
+with open(os.path.join(base_dir,'modality_counts.json'), 'r') as f:
+ modality_counts = json.load(f)
+
+# color scheme
+sectors = {k: 0 for k in hierarchy_data.keys()}
+for sector_name in hierarchy_data:
+ for k,v in hierarchy_data[sector_name]['child'].items():
+ sectors[sector_name] += len(v['child'])
+ sectors[sector_name] += 1
+
+name2color = {"organ": "#E41A1C", "abnormality": "#377EB8", "histology": "#4DAF4A"}
+
+def generate_shades(base_color, n):
+ return sns.light_palette(base_color, n + 2)[1:-1]
+
+color_schemes = {}
+for sector in sectors:
+ child_colors = generate_shades(name2color[sector], len(hierarchy_data[sector]['child']))
+ color_schemes[sector] = child_colors
+
+parent_track_ratio = (72, 85)
+middle_track_ratio = (85, 100)
+bar_track_ratio = (45, 70)
+parent_track_font_size = 7
+middle_track_font_size = 5.5
+bar_track_font_size = 7
+outer_track_font_size = 9
+
+circos = Circos(sectors, space=8.8)
+for sector in circos.sectors:
+ idx2label = {}
+ idx = 1
+ for k,v in hierarchy_data[sector.name.lower()]['child'].items():
+ for k1,v1 in v['child'].items():
+ idx2label[idx] = k1
+ idx += 1
+ idx2label[idx] = ''
+ idx2label[0] = ''
+
+ track_outer = sector.add_track((100, 101))
+ track_outer.xticks_by_interval(
+ 1,
+ tick_length=0,
+ outer=True,
+ show_bottom_line=False,
+ label_orientation="vertical",
+ label_formatter=lambda v: idx2label[int(v)],
+ label_size=outer_track_font_size,
+ show_endlabel=True
+ )
+
+ track = sector.add_track(parent_track_ratio)
+ track.axis(fc=name2color[sector.name], lw=0)
+ track.text(sector.name.capitalize().replace('Mri', 'MRI').replace('Ct', 'CT').replace('Oct', 'OCT').replace('Dermoscopy', "DS"), color="white", size=parent_track_font_size)
+
+ track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
+ sect_start = 0
+ color_idx = 0
+ for i, (k,v) in enumerate(hierarchy_data[sector.name.lower()]['child'].items()):
+ sect_size = len(v['child']) if i != len(hierarchy_data[sector.name.lower()]['child'])-1 else len(v['child'])+1
+ if i == 0:
+ sect_size += 0.5
+ if i == len(hierarchy_data[sector.name.lower()]['child'])-1:
+ sect_size -= 0.5
+ track1.rect(sect_start, sect_start+sect_size, r_lim=(middle_track_ratio[0], middle_track_ratio[1]-1), ec="black", lw=0,fc=color_schemes[sector.name][color_idx])
+ color_idx += 1
+ track1.text(k.replace('abnormality', 'abn.').replace(' anatomies', '').replace(' disturbance', '').replace('other abn.', 'Other').replace('liver', '').replace('pancreas', '').capitalize(), sect_start+sect_size/2, color="black", size=middle_track_font_size)
+ sect_start += sect_size
+
+ x = np.linspace(sector.start+1 , sector.end-1 , int(sector.size)-1)
+ y = [target_counts[idx2label[i+1]] for i in range(0,len(x))]
+ y_box = boxcox(y, 0.35)
+
+ track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
+ track2.axis()
+ track2.yticks([1.14, 2.29, 3.43, 4.58], ["10$^2$", "10$^3$", "10$^4$", "10$^5$"], label_size=bar_track_font_size-1)
+ track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)
+
+fig = circos.plotfig()
+fig.savefig('plots/figure_1a.pdf')
+plt.show()
+
+# %%
diff --git a/figures/main_figure_1b.py b/figures/main_figure_1b.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a530a8f0c11b52943b5ce82f49b3ec7ab9b71b6
--- /dev/null
+++ b/figures/main_figure_1b.py
@@ -0,0 +1,101 @@
+#%%
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import json, os
+import seaborn as sns
+
+plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
+
+# Load data
+def load_data(file_path):
+ with open(file_path, 'r') as f:
+ return json.load(f)
+base_dir = 'metadata'
+data = load_data(os.path.join(base_dir, 'modality_counts.json'))
+separate_submodality = False
+
+# Transform data for plotting
+def transform_data(data):
+ df = pd.DataFrame([(modality, subcat, count) for modality, subcats in data.items() for subcat, count in subcats.items()], columns=['Modality', 'Sub-category', 'Count'])
+ return df
+
+df = transform_data(data)
+
+# Calculate total counts by modality and sort
+def calculate_totals(df):
+ total_counts_by_modality = df.groupby("Modality")["Count"].sum().sort_values(ascending=True)
+ sorted_modalities = total_counts_by_modality.index.tolist()
+ return total_counts_by_modality, sorted_modalities
+
+total_counts_by_modality, sorted_modalities = calculate_totals(df)
+
+# Generate color map
+def generate_color_map(total_counts_by_modality):
+ base_colors = plt.cm.cool(np.linspace(0, 1, len(total_counts_by_modality)))
+ modality_color_map = {modality: base_colors[i] for i, modality in enumerate(total_counts_by_modality.index)}
+ return modality_color_map
+
+modality_color_map = generate_color_map(total_counts_by_modality)
+
+# Format total count for display
+def format_total_count(total_count):
+ if total_count >= 1000:
+ exponent = int(np.floor(np.log10(total_count)))
+ mantissa = total_count / 10**exponent
+ formatted_total = f'{mantissa:.2f} x 10$^{exponent}$'
+ else:
+ exponent = 0
+ formatted_total = str(total_count)
+ return formatted_total, exponent
+
+# Plotting function
+def plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality):
+ fig, ax = plt.subplots(figsize=(10, 12))
+ current_bottom = np.zeros(len(sorted_modalities))
+ gap = 0.005 if separate_submodality else 0
+ shades = np.power(np.linspace(0.75, 1, df.groupby("Sub-category").ngroups), 2)
+
+ if separate_submodality:
+ for i, modality in enumerate(sorted_modalities):
+ subdf = df[df["Modality"] == modality].sort_values(by='Count', ascending=False)
+ for j, (index, row) in enumerate(subdf.iterrows()):
+ count = row['Count']
+ if count > 0:
+ color = np.array(modality_color_map[modality]) * shades[j % len(shades)]
+ ax.barh(modality, count, left=current_bottom[i], color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
+ current_bottom[i] += count + gap
+ current_bottom[i] -= gap
+ total_count = total_counts_by_modality[modality]
+ formatted_total, exponent = format_total_count(total_count)
+ ax.text(current_bottom[i] + (10**exponent)*0.05, i, formatted_total, va='center', fontsize=20, ha='left')
+ else:
+ for i, modality in enumerate(sorted_modalities):
+ total_count = total_counts_by_modality[modality]
+ color = np.array(modality_color_map[modality] * shades[0])
+ if modality.islower():
+ modality = modality.capitalize()
+ ax.barh(modality, total_count, color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
+ formatted_total, exponent = format_total_count(total_count)
+ ax.text(total_count + (10**exponent)*0.05, i, formatted_total, va='center', fontsize=20, ha='left')
+
+ configure_plot(ax, sorted_modalities)
+
+ plt.tight_layout()
+ plt.savefig("plots/data_dist_modality_bar_subbar.pdf" if separate_submodality else "plots/data_dist_modality_bar.pdf", bbox_inches="tight", pad_inches=0)
+ plt.show()
+
+# Configure plot aesthetics
+def configure_plot(ax, sorted_modalities):
+ ax.set_xscale('log')
+ ax.set_title("Number of images per modality", fontsize=28)
+ plt.yticks(rotation=0, fontsize=24, va='center')
+ ax.tick_params(axis='x', which='major', length=8)
+ ax.tick_params(axis='x', which='minor', length=5)
+ plt.xticks(fontsize=24)
+ sns.despine()
+
+# Main script execution
+plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality)
+
+# %%
diff --git a/figures/main_figure_2a.py b/figures/main_figure_2a.py
new file mode 100644
index 0000000000000000000000000000000000000000..21a02866add5264e3a52e5dc55698e5bad31df73
--- /dev/null
+++ b/figures/main_figure_2a.py
@@ -0,0 +1,93 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import json, os
+
+from statannot import add_stat_annotation
+from statannotations.Annotator import Annotator
+
+df = pd.read_csv('results/all_eval/all_metrics_median.csv')
+
+
+metric = 'dice'
+
+model_names = {metric: 'BiomedParse', f'medsam_{metric}': 'MedSAM (oracle box)', f'sam_{metric}': 'SAM (oracle box)',
+ f'dino_medsam_{metric}': 'MedSAM (Grounding DINO)', f'dino_sam_{metric}': 'SAM (Grounding DINO)'}
+df = df.rename(columns=model_names)
+
+score_vars = list(model_names.values())
+
+
+modality_list = ['CT', 'MRI', 'X-Ray', 'Pathology', 'Ultrasound', 'Fundus', 'Endoscope', 'Dermoscopy', 'OCT']
+# modify modality names
+mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+ 'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+# add an "All" modality
+all_df = df.copy()
+all_df['modality'] = 'All'
+df = pd.concat([df, all_df])
+
+df_long = df[['modality', 'task']+score_vars].melt(id_vars=['modality', 'task'], var_name='Model', value_name='Performance')
+
+
+
+# add statistical annotations
+fig, ax = plt.subplots(figsize=(9, 6))
+ax = sns.boxplot(data=df_long, x='modality', y='Performance', hue='Model', ax=ax, palette='Set2',
+ order=['All']+modality_list,
+ whis=2, saturation=0.6, linewidth=0.8, fliersize=0.5) # whiskers at 5th and 95th percentile)
+ #errorbar='sd', capsize=0.1, errwidth=1.5)
+
+# no frame
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.spines['left'].set_visible(False)
+# add arrow on y axis
+ax.annotate('', xy=(0, 1.05), xytext=(0, -0.01), arrowprops=dict(arrowstyle='->', lw=1, color='black'), xycoords='axes fraction')
+
+
+plt.title('')
+if metric == 'dice':
+ plt.ylabel('Dice score', fontsize=18)
+elif metric == 'assd':
+ plt.ylabel('ASSD', fontsize=18)
+plt.xlabel('')
+plt.xticks(rotation=45, fontsize=16)
+plt.yticks(fontsize=14)
+
+# axis thickness
+ax.spines['bottom'].set_linewidth(1)
+ax.spines['left'].set_linewidth(1)
+
+
+# change to log scale
+if metric == 'assd':
+ plt.yscale('log')
+
+# set legend names
+ax.legend(score_vars, fontsize=14)
+
+# legend on top in a row, without frame
+plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.4), ncol=2, fontsize=14, frameon=False)
+
+# Define pairs between models for each modality
+box_pairs = []
+
+# Add statistical annotations for each modality
+for modality in ['All']+modality_list:
+ # Define pairs between models within the same modality
+ box_pairs += [((modality, 'BiomedParse'), (modality, 'MedSAM (oracle box)'))]
+annotator = Annotator(ax, box_pairs, data=df_long, x='modality', y='Performance', hue='Model',
+ order=['All']+modality_list)
+annotator.configure(test='t-test_paired', text_format='star', loc='inside', hide_non_significant=True)
+annotator.apply_test(alternative='less')
+annotator.annotate()
+
+plt.tight_layout()
+
+# save the plot
+ax.get_figure().savefig(f'plots/{metric}_comparison.png')
+ax.get_figure().savefig(f'plots/{metric}_comparison.pdf', bbox_inches='tight')
\ No newline at end of file
diff --git a/figures/main_figure_3b.py b/figures/main_figure_3b.py
new file mode 100644
index 0000000000000000000000000000000000000000..a857ee1cacd8fbb840fcdfa2e754ff1203ad287c
--- /dev/null
+++ b/figures/main_figure_3b.py
@@ -0,0 +1,83 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import json, os
+
+from statannot import add_stat_annotation
+from statannotations.Annotator import Annotator
+
+df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
+
+# modify modality names
+mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+ 'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+# MedSAM reported tasks
+reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
+
+# find overlap between the dfs by dataset and target
+overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
+ suffixes=('_biomedparse', '_baseline'))
+# non-overlapping datasets
+non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
+
+
+
+baseline = 'medsam'
+metric = 'box_ratio'
+
+baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
+metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
+ 'IRI': 'Inversed Rotational Inertia'}
+
+non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
+# scatter plot
+fig, ax = plt.subplots(figsize=(7,5))
+sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
+
+# add linear regression line
+sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
+ color='k', line_kws={'linestyle':'--', 'linewidth':1})
+
+# remove all spines
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.spines['left'].set_visible(False)
+ax.spines['bottom'].set_visible(False)
+
+
+# add arrow on x-axis and y-axis
+xlim = [0, 0.85]
+ylim = [-0.18, 0.75]
+ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.set_xlim(xlim)
+ax.set_ylim(ylim)
+
+ax.xaxis.set_tick_params(width=1.5)
+ax.yaxis.set_tick_params(width=1.5)
+
+# set x-ticks and y-ticks
+plt.xticks(fontsize=18)
+plt.yticks(fontsize=18)
+
+# show R^2 value, p value, and equation of the line
+from scipy.stats import linregress
+slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
+x_text = 0.4
+plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
+
+plt.title('')
+plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
+plt.xlabel(f'{metric_names[metric]}', fontsize=22)
+
+plt.tight_layout()
+
+# save the plot
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
+
diff --git a/figures/main_figure_3c.py b/figures/main_figure_3c.py
new file mode 100644
index 0000000000000000000000000000000000000000..e51c4861c9f0b2dc8303dd98960ddc1657cd66fb
--- /dev/null
+++ b/figures/main_figure_3c.py
@@ -0,0 +1,83 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import json, os
+
+from statannot import add_stat_annotation
+from statannotations.Annotator import Annotator
+
+df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
+
+# modify modality names
+mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+ 'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+# MedSAM reported tasks
+reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
+
+# find overlap between the dfs by dataset and target
+overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
+ suffixes=('_biomedparse', '_baseline'))
+# non-overlapping datasets
+non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
+
+
+
+baseline = 'medsam'
+metric = 'convex_ratio'
+
+baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
+metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
+ 'IRI': 'Inversed Rotational Inertia'}
+
+non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
+# scatter plot
+fig, ax = plt.subplots(figsize=(7,5))
+sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
+
+# add linear regression line
+sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
+ color='k', line_kws={'linestyle':'--', 'linewidth':1})
+
+# remove all spines
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.spines['left'].set_visible(False)
+ax.spines['bottom'].set_visible(False)
+
+
+# add arrow on x-axis and y-axis
+xlim = [0, 1.05]
+ylim = [-0.18, 0.75]
+ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.set_xlim(xlim)
+ax.set_ylim(ylim)
+
+ax.xaxis.set_tick_params(width=1.5)
+ax.yaxis.set_tick_params(width=1.5)
+
+# set x-ticks and y-ticks
+plt.xticks(fontsize=18)
+plt.yticks(fontsize=18)
+
+# show R^2 value, p value, and equation of the line
+from scipy.stats import linregress
+slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
+x_text = 0.4
+plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
+
+plt.title('')
+plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
+plt.xlabel(f'{metric_names[metric]}', fontsize=22)
+
+plt.tight_layout()
+plt.show()
+
+# save the plot
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
diff --git a/figures/main_figure_3d.py b/figures/main_figure_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..247c3d737b4c3466bb87a3025730620980cf4d85
--- /dev/null
+++ b/figures/main_figure_3d.py
@@ -0,0 +1,83 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import json, os
+
+from statannot import add_stat_annotation
+from statannotations.Annotator import Annotator
+
+df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
+
+# modify modality names
+mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+ 'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+# MedSAM reported tasks
+reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
+
+# find overlap between the dfs by dataset and target
+overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
+ suffixes=('_biomedparse', '_baseline'))
+# non-overlapping datasets
+non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
+
+
+
+baseline = 'medsam'
+metric = 'IRI'
+
+baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
+metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
+ 'IRI': 'Inversed Rotational Inertia'}
+
+non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
+# scatter plot
+fig, ax = plt.subplots(figsize=(7,5))
+sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
+
+# add linear regression line
+sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
+ color='k', line_kws={'linestyle':'--', 'linewidth':1})
+
+# remove all spines
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.spines['left'].set_visible(False)
+ax.spines['bottom'].set_visible(False)
+
+
+# add arrow on x-axis and y-axis
+xlim = [0, 1.05]
+ylim = [-0.18, 0.75]
+ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.set_xlim(xlim)
+ax.set_ylim(ylim)
+
+ax.xaxis.set_tick_params(width=1.5)
+ax.yaxis.set_tick_params(width=1.5)
+
+# set x-ticks and y-ticks
+plt.xticks(fontsize=18)
+plt.yticks(fontsize=18)
+
+# show R^2 value, p value, and equation of the line
+from scipy.stats import linregress
+slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
+x_text = 0.4
+plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
+
+plt.title('')
+plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
+plt.xlabel(f'{metric_names[metric]}', fontsize=22)
+
+plt.tight_layout()
+plt.show()
+
+# save the plot
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
diff --git a/figures/plots/IRI_mean_improvement_medsam.pdf b/figures/plots/IRI_mean_improvement_medsam.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..0a2dd3f1e83b079c72420e9c984ef25818a43216
Binary files /dev/null and b/figures/plots/IRI_mean_improvement_medsam.pdf differ
diff --git a/figures/plots/IRI_mean_improvement_medsam.png b/figures/plots/IRI_mean_improvement_medsam.png
new file mode 100644
index 0000000000000000000000000000000000000000..d2f7118fc22fb5fdc8ac33d69d72711ef5d0887f
Binary files /dev/null and b/figures/plots/IRI_mean_improvement_medsam.png differ
diff --git a/figures/plots/IRI_mean_improvement_sam.pdf b/figures/plots/IRI_mean_improvement_sam.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..664e674221e937e17b7542a2c7a619f3d79987aa
Binary files /dev/null and b/figures/plots/IRI_mean_improvement_sam.pdf differ
diff --git a/figures/plots/IRI_mean_improvement_sam.png b/figures/plots/IRI_mean_improvement_sam.png
new file mode 100644
index 0000000000000000000000000000000000000000..0491d24f7897718bb97da6642c2774da3269239b
Binary files /dev/null and b/figures/plots/IRI_mean_improvement_sam.png differ
diff --git a/figures/plots/area_vs_dice.pdf b/figures/plots/area_vs_dice.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..ec85c6adc02ea6f72137a8cd7298e84ae23af5d0
Binary files /dev/null and b/figures/plots/area_vs_dice.pdf differ
diff --git a/figures/plots/assd_comparison.pdf b/figures/plots/assd_comparison.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..b5944d6cc6303909e3d6e59667e6aafd4955c668
Binary files /dev/null and b/figures/plots/assd_comparison.pdf differ
diff --git a/figures/plots/assd_comparison.png b/figures/plots/assd_comparison.png
new file mode 100644
index 0000000000000000000000000000000000000000..fdbb94d253e1ad938210c22752b97097c8431ca1
Binary files /dev/null and b/figures/plots/assd_comparison.png differ
diff --git a/figures/plots/box_ratio_mean_improvement_medsam.pdf b/figures/plots/box_ratio_mean_improvement_medsam.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..5fda44cc07f89d4f0925e50cde640349cf36affe
Binary files /dev/null and b/figures/plots/box_ratio_mean_improvement_medsam.pdf differ
diff --git a/figures/plots/box_ratio_mean_improvement_medsam.png b/figures/plots/box_ratio_mean_improvement_medsam.png
new file mode 100644
index 0000000000000000000000000000000000000000..325a398c715532b8fa918c5d2effe03d50e481a5
Binary files /dev/null and b/figures/plots/box_ratio_mean_improvement_medsam.png differ
diff --git a/figures/plots/box_ratio_mean_improvement_sam.pdf b/figures/plots/box_ratio_mean_improvement_sam.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..e398bb2f15923d2bfc1e8b4b310be824c6413a3a
Binary files /dev/null and b/figures/plots/box_ratio_mean_improvement_sam.pdf differ
diff --git a/figures/plots/box_ratio_mean_improvement_sam.png b/figures/plots/box_ratio_mean_improvement_sam.png
new file mode 100644
index 0000000000000000000000000000000000000000..f30c8c394baab33c16db3769cf50b284c7158a7b
Binary files /dev/null and b/figures/plots/box_ratio_mean_improvement_sam.png differ
diff --git a/figures/plots/convex_ratio_mean_improvement_medsam.pdf b/figures/plots/convex_ratio_mean_improvement_medsam.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..2576536f9bdba1c11dade7a67d29df4f989b4282
Binary files /dev/null and b/figures/plots/convex_ratio_mean_improvement_medsam.pdf differ
diff --git a/figures/plots/convex_ratio_mean_improvement_medsam.png b/figures/plots/convex_ratio_mean_improvement_medsam.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d8094f07f7db2857caa9e6d1b4e300b51b46457
Binary files /dev/null and b/figures/plots/convex_ratio_mean_improvement_medsam.png differ
diff --git a/figures/plots/convex_ratio_mean_improvement_sam.pdf b/figures/plots/convex_ratio_mean_improvement_sam.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..21546b1844c66a484d9a1152a83894d20d232a85
Binary files /dev/null and b/figures/plots/convex_ratio_mean_improvement_sam.pdf differ
diff --git a/figures/plots/convex_ratio_mean_improvement_sam.png b/figures/plots/convex_ratio_mean_improvement_sam.png
new file mode 100644
index 0000000000000000000000000000000000000000..3fb0d30e57cf98ed04b5646bbaacbd843684aca0
Binary files /dev/null and b/figures/plots/convex_ratio_mean_improvement_sam.png differ
diff --git a/figures/plots/data_dist_modality_bar.pdf b/figures/plots/data_dist_modality_bar.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..52c7217706c3261cf04c4fca595e3982ceede29a
Binary files /dev/null and b/figures/plots/data_dist_modality_bar.pdf differ
diff --git a/figures/plots/data_target_modality.pdf b/figures/plots/data_target_modality.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..a8319186280ff04584ff6cc731f5f649aed8bdbd
Binary files /dev/null and b/figures/plots/data_target_modality.pdf differ
diff --git a/figures/plots/dice_comparison.pdf b/figures/plots/dice_comparison.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..e6b13ca50ec6362e8d2f9d567f1899c1b8ed17e8
Binary files /dev/null and b/figures/plots/dice_comparison.pdf differ
diff --git a/figures/plots/dice_comparison.png b/figures/plots/dice_comparison.png
new file mode 100644
index 0000000000000000000000000000000000000000..1e26aeb48f83ff1ddefc4cbabac795b20e3d842f
Binary files /dev/null and b/figures/plots/dice_comparison.png differ
diff --git a/figures/plots/figure_1a.pdf b/figures/plots/figure_1a.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..6039d48ab6b88fa347980d576a3605398abee06c
Binary files /dev/null and b/figures/plots/figure_1a.pdf differ
diff --git a/figures/supplementary_figure_2.py b/figures/supplementary_figure_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..145756bda1f5a12c5dd227ed723252088a091709
--- /dev/null
+++ b/figures/supplementary_figure_2.py
@@ -0,0 +1,78 @@
+# %%
+import os
+import json
+import numpy as np
+import seaborn as sns
+from scipy.stats import boxcox
+from pycirclize import Circos
+import matplotlib.pyplot as plt
+
+base_dir = 'metadata'
+with open(os.path.join(base_dir,'hierarchy.json'), 'r') as f:
+ hierarchy_data = json.load(f)
+
+with open(os.path.join(base_dir,'target_counts.json'), 'r') as f:
+ target_counts = json.load(f)
+
+with open(os.path.join(base_dir,'modality_counts.json'), 'r') as f:
+ modality_counts = json.load(f)
+
+# color scheme
+sectors = {k: len(v) for k,v in modality_counts.items()}
+name2color = {
+ "MRI": "#005A9E",
+ "CT": "#FF7F00",
+ "pathology": "#984EA3",
+ "ultrasound": "#7BC8F6",
+ "X-Ray": "#999999",
+ "fundus": "#76B041",
+ "dermoscopy": "#FDBF6F",
+ "endoscope": "#C0392B",
+ "OCT": "#33A02C",
+}
+
+def generate_shades(base_color, n):
+ return sns.light_palette(base_color, n + 2)[1:-1]
+
+color_schemes = {}
+for sector in sectors:
+ child_colors = generate_shades(name2color[sector], len(modality_counts[sector]))
+ color_schemes[sector] = child_colors
+
+parent_track_ratio = (72, 85)
+middle_track_ratio = (85, 100)
+bar_track_ratio = (45, 70)
+parent_track_font_size = 7
+middle_track_font_size = 5.5
+bar_track_font_size = 7
+
+circos = Circos(sectors, space=6)
+for sector in circos.sectors:
+ track = sector.add_track(parent_track_ratio)
+ track.axis(fc=name2color[sector.name], lw=0)
+ track.text(sector.name.capitalize().replace('Mri', 'MRI').replace('Ct', 'CT').replace('Oct', 'OCT').replace('Dermoscopy', "DS"), color="white", size=parent_track_font_size)
+
+ track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
+ sect_start = 0
+ color_idx = 0
+ for k,v in modality_counts[sector.name].items():
+ sect_size = 1
+ track1.rect(sect_start, sect_start+sect_size, r_lim=(middle_track_ratio[0], middle_track_ratio[1]-1) , ec="black", lw=0,fc=color_schemes[sector.name][color_idx])
+ color_idx += 1
+ track1.text(k.capitalize(), sect_start+sect_size/2, color="black", size=middle_track_font_size)
+ sect_start += sect_size
+
+ x = np.linspace(sector.start+0.5, sector.end-0.5, int(sector.size))
+ y = [v for k,v in modality_counts[sector.name].items()]
+ y_box = boxcox(y, 0.35)
+
+ track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
+ track2.axis()
+ track2.yticks([1.14, 2.29, 3.43, 4.58], ["10$^2$", "10$^3$", "10$^4$", "10$^5$"], label_size=bar_track_font_size-1)
+ track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)
+
+fig = circos.plotfig()
+fig.savefig('plots/data_target_modality.pdf')
+plt.show()
+
+# %%
diff --git a/figures/supplementary_figure_ASSD.py b/figures/supplementary_figure_ASSD.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be4df81c78baeac4baa753c129b66ed65089971
--- /dev/null
+++ b/figures/supplementary_figure_ASSD.py
@@ -0,0 +1,95 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import json, os
+
+from statannot import add_stat_annotation
+from statannotations.Annotator import Annotator
+
+df = pd.read_csv('results/all_eval/all_metrics_median.csv')
+
+
+metric = 'assd'
+
+model_names = {metric: 'BiomedParse', f'medsam_{metric}': 'MedSAM (oracle box)', f'sam_{metric}': 'SAM (oracle box)',
+ f'dino_medsam_{metric}': 'MedSAM (Grounding DINO)', f'dino_sam_{metric}': 'SAM (Grounding DINO)'}
+df = df.rename(columns=model_names)
+
+score_vars = list(model_names.values())
+
+# filter outlier values
+df = df[df['MedSAM (oracle box)'] < 1e10]
+
+modality_list = ['CT', 'MRI', 'X-Ray', 'Pathology', 'Ultrasound', 'Fundus', 'Endoscope', 'Dermoscopy', 'OCT']
+# modify modality names
+mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+ 'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+# add an "All" modality
+all_df = df.copy()
+all_df['modality'] = 'All'
+df = pd.concat([df, all_df])
+
+df_long = df[['modality', 'task']+score_vars].melt(id_vars=['modality', 'task'], var_name='Model', value_name='Performance')
+
+
+
+# add statistical annotations
+fig, ax = plt.subplots(figsize=(9, 6))
+ax = sns.boxplot(data=df_long, x='modality', y='Performance', hue='Model', ax=ax, palette='Set2',
+ order=['All']+modality_list,
+ whis=2, saturation=0.6, linewidth=0.8, fliersize=0.5) # whiskers at 5th and 95th percentile)
+ #errorbar='sd', capsize=0.1, errwidth=1.5)
+
+# no frame
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.spines['left'].set_visible(False)
+# add arrow on y axis
+ax.annotate('', xy=(0, 1.05), xytext=(0, -0.01), arrowprops=dict(arrowstyle='->', lw=1, color='black'), xycoords='axes fraction')
+
+
+plt.title('')
+if metric == 'dice':
+ plt.ylabel('Dice score', fontsize=18)
+elif metric == 'assd':
+ plt.ylabel('ASSD', fontsize=18)
+plt.xlabel('')
+plt.xticks(rotation=45, fontsize=16)
+plt.yticks(fontsize=14)
+
+# axis thickness
+ax.spines['bottom'].set_linewidth(1)
+ax.spines['left'].set_linewidth(1)
+
+
+# change to log scale
+if metric == 'assd':
+ plt.yscale('log')
+
+# set legend names
+ax.legend(score_vars, fontsize=14)
+
+# legend on top in a row, without frame
+plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.4), ncol=2, fontsize=14, frameon=False)
+
+# Define pairs between models for each modality
+box_pairs = []
+
+# Add statistical annotations for each modality
+for modality in ['All']+modality_list:
+ # Define pairs between models within the same modality
+ box_pairs += [((modality, 'BiomedParse'), (modality, 'MedSAM (oracle box)'))]
+annotator = Annotator(ax, box_pairs, data=df_long, x='modality', y='Performance', hue='Model',
+ order=['All']+modality_list)
+annotator.configure(test='t-test_paired', text_format='star', loc='inside', hide_non_significant=True)
+annotator.apply_test(alternative='less')
+annotator.annotate()
+
+plt.tight_layout()
+
+# save the plot
+ax.get_figure().savefig(f'plots/{metric}_comparison.png')
+ax.get_figure().savefig(f'plots/{metric}_comparison.pdf', bbox_inches='tight')
\ No newline at end of file
diff --git a/figures/supplementary_figure_IRI_sam.py b/figures/supplementary_figure_IRI_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7ac3a4d5164a3848c703ed570e5e2f305c8dc7a
--- /dev/null
+++ b/figures/supplementary_figure_IRI_sam.py
@@ -0,0 +1,83 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import json, os
+
+from statannot import add_stat_annotation
+from statannotations.Annotator import Annotator
+
+df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
+
+# modify modality names
+mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+ 'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+# MedSAM reported tasks
+reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
+
+# find overlap between the dfs by dataset and target
+overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
+ suffixes=('_biomedparse', '_baseline'))
+# non-overlapping datasets
+non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
+
+
+
+baseline = 'sam'
+metric = 'IRI'
+
+baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
+metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
+ 'IRI': 'Inversed Rotational Inertia'}
+
+non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
+# scatter plot
+fig, ax = plt.subplots(figsize=(7,5))
+sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
+
+# add linear regression line
+sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
+ color='k', line_kws={'linestyle':'--', 'linewidth':1})
+
+# remove all spines
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.spines['left'].set_visible(False)
+ax.spines['bottom'].set_visible(False)
+
+
+# add arrow on x-axis and y-axis
+xlim = [0, 1.05]
+ylim = [-0.06, 0.79]
+ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.set_xlim(xlim)
+ax.set_ylim(ylim)
+
+ax.xaxis.set_tick_params(width=1.5)
+ax.yaxis.set_tick_params(width=1.5)
+
+# set x-ticks and y-ticks
+plt.xticks(fontsize=18)
+plt.yticks(fontsize=18)
+
+# show R^2 value, p value, and equation of the line
+from scipy.stats import linregress
+slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
+x_text = 0.4
+plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
+
+plt.title('')
+plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
+plt.xlabel(f'{metric_names[metric]}', fontsize=22)
+
+plt.tight_layout()
+
+# save the plot
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
+
diff --git a/figures/supplementary_figure_box_sam.py b/figures/supplementary_figure_box_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..41cba1fec0e33274fcb61625e62403ef86e8cb9f
--- /dev/null
+++ b/figures/supplementary_figure_box_sam.py
@@ -0,0 +1,83 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import json, os
+
+from statannot import add_stat_annotation
+from statannotations.Annotator import Annotator
+
+df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
+
+# modify modality names
+mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+ 'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+# MedSAM reported tasks
+reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
+
+# find overlap between the dfs by dataset and target
+overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
+ suffixes=('_biomedparse', '_baseline'))
+# non-overlapping datasets
+non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
+
+
+
+baseline = 'sam'
+metric = 'box_ratio'
+
+baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
+metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
+ 'IRI': 'Inversed Rotational Inertia'}
+
+non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
+# scatter plot
+fig, ax = plt.subplots(figsize=(7,5))
+sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
+
+# add linear regression line
+sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
+ color='k', line_kws={'linestyle':'--', 'linewidth':1})
+
+# remove all spines
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.spines['left'].set_visible(False)
+ax.spines['bottom'].set_visible(False)
+
+
+# add arrow on x-axis and y-axis
+xlim = [0, 0.85]
+ylim = [-0.06, 0.79]
+ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.set_xlim(xlim)
+ax.set_ylim(ylim)
+
+ax.xaxis.set_tick_params(width=1.5)
+ax.yaxis.set_tick_params(width=1.5)
+
+# set x-ticks and y-ticks
+plt.xticks(fontsize=18)
+plt.yticks(fontsize=18)
+
+# show R^2 value, p value, and equation of the line
+from scipy.stats import linregress
+slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
+x_text = 0.4
+plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
+
+plt.title('')
+plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
+plt.xlabel(f'{metric_names[metric]}', fontsize=22)
+
+plt.tight_layout()
+
+# save the plot
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
+
diff --git a/figures/supplementary_figure_convex_sam.py b/figures/supplementary_figure_convex_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..40f35331e1cb68fa747d0338f6c00035bc151807
--- /dev/null
+++ b/figures/supplementary_figure_convex_sam.py
@@ -0,0 +1,83 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import json, os
+
+from statannot import add_stat_annotation
+from statannotations.Annotator import Annotator
+
+df = pd.read_csv('results/all_eval/all_metrics_mean.csv')
+
+# modify modality names
+mod_names = {'CT': 'CT', 'MRI': 'MRI', 'MRI-T2': 'MRI', 'MRI-ADC': 'MRI', 'MRI-FLAIR': 'MRI', 'MRI-T1-Gd': 'MRI', 'X-Ray': 'X-Ray', 'pathology': 'Pathology',
+ 'ultrasound': 'Ultrasound', 'fundus': 'Fundus', 'endoscope': 'Endoscope', 'dermoscopy': 'Dermoscopy', 'OCT': 'OCT', 'All': 'All'}
+df['modality'] = df['modality'].apply(lambda x: mod_names[x])
+
+# MedSAM reported tasks
+reported_baseline_df = pd.read_csv('results/all_eval/reported_baseline_tasks.csv')
+
+# find overlap between the dfs by dataset and target
+overlap_df = pd.merge(df, reported_baseline_df, on=['task', 'modality', 'site', 'target'],
+ suffixes=('_biomedparse', '_baseline'))
+# non-overlapping datasets
+non_overlap_df = df[~df['task'].isin(overlap_df['task'])]
+
+
+
+baseline = 'sam'
+metric = 'convex_ratio'
+
+baseline_names = {'medsam': 'MedSAM', 'sam': 'SAM'}
+metric_names = {'box_ratio': 'Box Ratio', 'convex_ratio': 'Convex Ratio',
+ 'IRI': 'Inversed Rotational Inertia'}
+
+non_overlap_df['diff'] = non_overlap_df[f'dice'] - non_overlap_df[f'{baseline}_dice']
+# scatter plot
+fig, ax = plt.subplots(figsize=(7,5))
+sns.scatterplot(data=non_overlap_df, x=metric, y='diff', ax=ax, markers='o', s=80)
+
+# add linear regression line
+sns.regplot(data=non_overlap_df, x=metric, y='diff', ax=ax, scatter=False,
+ color='k', line_kws={'linestyle':'--', 'linewidth':1})
+
+# remove all spines
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.spines['left'].set_visible(False)
+ax.spines['bottom'].set_visible(False)
+
+
+# add arrow on x-axis and y-axis
+xlim = [0, 1.05]
+ylim = [-0.06, 0.79]
+ax.annotate('', xy=(xlim[1], ylim[0]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.annotate('', xy=(xlim[0], ylim[1]), xytext=(xlim[0], ylim[0]), arrowprops=dict(arrowstyle='->', lw=1.5))
+ax.set_xlim(xlim)
+ax.set_ylim(ylim)
+
+ax.xaxis.set_tick_params(width=1.5)
+ax.yaxis.set_tick_params(width=1.5)
+
+# set x-ticks and y-ticks
+plt.xticks(fontsize=18)
+plt.yticks(fontsize=18)
+
+# show R^2 value, p value, and equation of the line
+from scipy.stats import linregress
+slope, intercept, r_value, p_value, std_err = linregress(non_overlap_df[metric], non_overlap_df['diff'])
+x_text = 0.4
+plt.text(x_text, 0.84, f'$R^2={r_value**2:.2f}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.77, f'$p={p_value:.2e}$', fontsize=20, transform=ax.transAxes)
+plt.text(x_text, 0.7, f'$y={slope:.2f}x+{intercept:.2f}$', fontsize=20, transform=ax.transAxes)
+
+plt.title('')
+plt.ylabel(f'Improvement over {baseline_names[baseline]}', fontsize=20)
+plt.xlabel(f'{metric_names[metric]}', fontsize=22)
+
+plt.tight_layout()
+
+# save the plot
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.png')
+ax.get_figure().savefig(f'plots/{metric}_mean_improvement_{baseline}.pdf', bbox_inches='tight')
+
diff --git a/figures/supplementary_figure_dice_by_area.py b/figures/supplementary_figure_dice_by_area.py
new file mode 100644
index 0000000000000000000000000000000000000000..86cf4117922187901a79ffdbd8efe7e525933db7
--- /dev/null
+++ b/figures/supplementary_figure_dice_by_area.py
@@ -0,0 +1,122 @@
+#%%
+import os
+import json
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from scipy.stats import sem
+
+# Define file paths
+base_dir = 'results'
+eval_results_path = os.path.join(base_dir, 'all_eval/biomedparse_eval_results.json')
+
+# Load data
+with open(eval_results_path, 'r') as f:
+ parsed_data = json.load(f)
+
+# Extract relevant information
+def extract_data(parsed_data):
+ records = []
+ for dataset in parsed_data:
+ dataset_name = dataset[len('biomed_'):-len('_test/grounding_refcoco')]
+ instances = parsed_data[dataset]["grounding"]["instance_results"]
+ for instance in instances:
+ metadata = instance["metadata"]
+ grounding_info = metadata["grounding_info"][0]
+ record = {
+ "dataset": dataset_name,
+ "file_name": grounding_info["mask_file"].split("/")[-1],
+ "area": grounding_info["area"],
+ "bp_dice": instance["Dice"][0]
+ }
+ records.append(record)
+ return pd.DataFrame(records)
+
+df = extract_data(parsed_data)
+
+# Merge with SAM and MedSAM data
+def merge_with_sam_medsam(df, parsed_data, base_dir):
+ comparison_df = pd.DataFrame()
+ for dataset in parsed_data:
+ dataset_name = dataset[len('biomed_'):-len('_test/grounding_refcoco')]
+ if any(sub in dataset_name for sub in ['MSD', 'Radiography', 'amos22']):
+ dataset_name = dataset_name.replace('-', '/')
+
+ sam_data_path = os.path.join(base_dir, dataset_name, 'test_sam_vit_b_01ec64_dice.csv')
+ medsam_data_path = os.path.join(base_dir, dataset_name, 'test_medsam_dice.csv')
+
+ sam_data = pd.read_csv(sam_data_path, delimiter=',')
+ medsam_data = pd.read_csv(medsam_data_path, delimiter=',')
+
+ merged_data = pd.merge(sam_data, medsam_data, on='image', suffixes=('_sam', '_medsam'))
+ merged_data.rename(columns={'image': 'file_name'}, inplace=True)
+ merged_data['dataset'] = dataset_name.replace('/', '-')
+
+ comparison_df = pd.concat([comparison_df, merged_data], ignore_index=True)
+
+ return pd.merge(df, comparison_df, on=['dataset', 'file_name'], how='inner')
+
+df = merge_with_sam_medsam(df, parsed_data, os.path.join(base_dir, 'dataset_results'))
+
+# Save to CSV
+df.to_csv(os.path.join(base_dir, 'all_eval/dice_by_size.csv'), index=False)
+
+# Filter datasets
+rad_list = [
+ 'ACDC', 'COVID-QU-Ex', 'CXR_Masks_and_Labels', 'LGG', 'LIDC-IDRI', 'MMs',
+ 'MSD-Task01_BrainTumour', 'MSD-Task02_Heart', 'MSD-Task03_Liver', 'MSD-Task04_Hippocampus',
+ 'MSD-Task05_Prostate', 'MSD-Task06_Lung', 'MSD-Task07_Pancreas', 'MSD-Task08_HepaticVessel',
+ 'MSD-Task09_Spleen', 'MSD-Task10_Colon', 'PROMISE12', 'QaTa-COV19', 'Radiography-COVID',
+ 'Radiography-Lung_Opacity', 'Radiography-Normal', 'Radiography-Viral_Pneumonia',
+ 'amos22-CT', 'amos22-MRI', 'kits23', 'COVID-19_CT'
+]
+df = df[df['dataset'].isin(rad_list)]
+
+# Plot area to Dice ratio
+def plot_area_to_dice(df):
+ sns.set_theme(style='ticks')
+
+ total_image_area = 1024 * 1024 # pixels
+ max_area_threshold = total_image_area # Adjust this threshold as needed
+ filtered_df = df[df['area'] <= max_area_threshold]
+
+ filtered_df['area_percentage'] = (filtered_df['area'] / total_image_area) * 100
+
+ bins = np.linspace(filtered_df['area_percentage'].min(), filtered_df['area_percentage'].max(), 15)
+ filtered_df['area_bin'] = pd.cut(filtered_df['area_percentage'], bins)
+
+ avg_dice_bp = filtered_df.groupby('area_bin')['bp_dice'].mean()
+ avg_dice_sam = filtered_df.groupby('area_bin')['dice_sam'].mean() if 'dice_sam' in filtered_df.columns else None
+ avg_dice_medsam = filtered_df.groupby('area_bin')['dice_medsam'].mean() if 'dice_medsam' in filtered_df.columns else None
+
+ sem_dice_bp = filtered_df.groupby('area_bin')['bp_dice'].apply(sem)
+ sem_dice_sam = filtered_df.groupby('area_bin')['dice_sam'].apply(sem) if 'dice_sam' in filtered_df.columns else None
+ sem_dice_medsam = filtered_df.groupby('area_bin')['dice_medsam'].apply(sem) if 'dice_medsam' in filtered_df.columns else None
+
+ colors = sns.color_palette("colorblind", 3)
+
+ plt.figure(figsize=(14, 10))
+
+ plt.errorbar(avg_dice_bp.index.categories.mid, avg_dice_bp, yerr=sem_dice_bp, fmt='-o', label='BiomedParse', color=colors[0], capsize=5)
+ if avg_dice_sam is not None:
+ plt.errorbar(avg_dice_sam.index.categories.mid, avg_dice_sam, yerr=sem_dice_sam, fmt='-o', label='SAM', color=colors[1], capsize=5)
+ if avg_dice_medsam is not None:
+ plt.errorbar(avg_dice_medsam.index.categories.mid, avg_dice_medsam, yerr=sem_dice_medsam, fmt='-o', label='MedSAM', color=colors[2], capsize=5)
+
+ plt.xlabel('Area (% of total image)', fontsize=20)
+ plt.ylabel('Dice Score', fontsize=20)
+ plt.grid(False)
+ plt.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5, 1.08), ncol=3, frameon=False)
+ plt.xticks(fontsize=20)
+ plt.yticks(fontsize=20)
+ plt.xlim(filtered_df['area_percentage'].min(), filtered_df['area_percentage'].max())
+ sns.despine()
+
+ plt.tight_layout()
+ plt.savefig(os.path.join('plots/area_vs_dice.pdf'), dpi=300)
+ plt.show()
+
+plot_area_to_dice(df)
+
+# %%
diff --git a/inference_utils/__init__.py b/inference_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inference_utils/inference.py b/inference_utils/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0cf0659406a0a1d464c9fc123ba24976630319e
--- /dev/null
+++ b/inference_utils/inference.py
@@ -0,0 +1,149 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+from PIL import Image
+from torchvision import transforms
+#from utils.visualizer import Visualizer
+# from detectron2.utils.colormap import random_color
+# from detectron2.data import MetadataCatalog
+# from detectron2.structures import BitMasks
+from modeling.language.loss import vl_similarity
+from utilities.constants import BIOMED_CLASSES
+#from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+
+# import cv2
+# import os
+# import glob
+# import subprocess
+from PIL import Image
+import random
+
+t = []
+t.append(transforms.Resize((1024, 1024), interpolation=Image.BICUBIC))
+transform = transforms.Compose(t)
+#metadata = MetadataCatalog.get('coco_2017_train_panoptic')
+all_classes = ['background'] + [name.replace('-other','').replace('-merged','')
+ for name in BIOMED_CLASSES] + ["others"]
+# colors_list = [(np.array(color['color'])/255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]]
+
+# use color list from matplotlib
+import matplotlib.colors as mcolors
+colors = dict(mcolors.TABLEAU_COLORS, **mcolors.BASE_COLORS)
+colors_list = [list(colors.values())[i] for i in range(16)]
+
+from .output_processing import mask_stats, combine_masks
+
+
+@torch.no_grad()
+def interactive_infer_image(model, image, prompts):
+
+ image_resize = transform(image)
+ width = image.size[0]
+ height = image.size[1]
+ image_resize = np.asarray(image_resize)
+ image = torch.from_numpy(image_resize.copy()).permute(2,0,1)
+
+ data = {"image": image, 'text': prompts, "height": height, "width": width}
+
+ # inistalize task
+ model.model.task_switch['spatial'] = False
+ model.model.task_switch['visual'] = False
+ model.model.task_switch['grounding'] = True
+ model.model.task_switch['audio'] = False
+ model.model.task_switch['grounding'] = True
+
+
+ batch_inputs = [data]
+ results,image_size,extra = model.model.evaluate_demo(batch_inputs)
+
+ pred_masks = results['pred_masks'][0]
+ v_emb = results['pred_captions'][0]
+ t_emb = extra['grounding_class']
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ matched_id = out_prob.max(0)[1]
+ pred_masks_pos = pred_masks[matched_id,:,:]
+ pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
+
+ # interpolate mask to ori size
+ pred_mask_prob = F.interpolate(pred_masks_pos[None,], (data['height'], data['width']),
+ mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
+ pred_masks_pos = (1*(pred_mask_prob > 0.5)).astype(np.uint8)
+
+ return pred_mask_prob
+
+
+
+# def interactive_infer_panoptic_biomedseg(model, image, tasks, reftxt=None):
+# image_ori = transform(image)
+# #mask_ori = image['mask']
+# width = image_ori.size[0]
+# height = image_ori.size[1]
+# image_ori = np.asarray(image_ori)
+# visual = Visualizer(image_ori, metadata=metadata)
+# images = torch.from_numpy(image_ori.copy()).permute(2,0,1)
+
+# data = {"image": images, "height": height, "width": width}
+# if len(tasks) == 0:
+# tasks = ["Panoptic"]
+
+# # inistalize task
+# model.model.task_switch['spatial'] = False
+# model.model.task_switch['visual'] = False
+# model.model.task_switch['grounding'] = False
+# model.model.task_switch['audio'] = False
+
+# # check if reftxt is list of strings
+# assert isinstance(reftxt, list), f"reftxt should be a list of strings, but got {type(reftxt)}"
+# model.model.task_switch['grounding'] = True
+# predicts = {}
+# for i, txt in enumerate(reftxt):
+# data['text'] = txt
+# batch_inputs = [data]
+
+# results,image_size,extra = model.model.evaluate_demo(batch_inputs)
+
+# pred_masks = results['pred_masks'][0]
+# v_emb = results['pred_captions'][0]
+# t_emb = extra['grounding_class']
+
+# t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+# v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+# temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
+# out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+# matched_id = out_prob.max(0)[1]
+# pred_masks_pos = pred_masks[matched_id,:,:]
+# pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
+
+
+# # interpolate mask to ori size
+# #pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy()
+# # masks.append(pred_masks_pos[0])
+# # mask = pred_masks_pos[0]
+# # masks.append(mask)
+# # interpolate mask to ori size
+# pred_mask_prob = F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
+# #pred_masks_pos = 1*(pred_mask_prob > 0.5)
+# predicts[txt] = pred_mask_prob[0]
+
+# masks = combine_masks(predicts)
+
+# predict_mask_stats = {}
+# print(masks.keys())
+# for i, txt in enumerate(masks):
+# mask = masks[txt]
+# demo = visual.draw_binary_mask(mask, color=colors_list[i], text=txt)
+# predict_mask_stats[txt] = mask_stats((predicts[txt]*255), image_ori)
+
+# res = demo.get_image()
+# torch.cuda.empty_cache()
+# # return Image.fromarray(res), stroke_inimg, stroke_refimg
+# return Image.fromarray(res), None, predict_mask_stats
+
diff --git a/inference_utils/output_processing.py b/inference_utils/output_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..17ee3efe677107d9a93a0da5f9d59c70614b4ef1
--- /dev/null
+++ b/inference_utils/output_processing.py
@@ -0,0 +1,91 @@
+import json
+from scipy import stats
+import numpy as np
+
+import huggingface_hub
+
+
+def check_mask_stats(img, mask, modality_type, target):
+ # img: np.array, shape=(H, W, 3) RGB image with pixel values in [0, 255]
+ # mask: np.array, shape=(H, W, 1) mask probability scaled to [0,255] with pixel values in [0, 255]
+ # modality_type: str, see target_dist.json for the list of modality types
+ # target: str, see target_dist.json for the list of targets
+
+ huggingface_hub.hf_hub_download('microsoft/BiomedParse', filename='target_dist.json', local_dir='./inference_utils')
+ huggingface_hub.hf_hub_download('microsoft/BiomedParse', filename="config.yaml", local_dir="./configs")
+ target_dist = json.load(open("inference_utils/target_dist.json"))
+
+ if modality_type not in target_dist:
+ raise ValueError(f"Currently support modality types: {list(target_dist.keys())}")
+
+ if target not in target_dist[modality_type]:
+ raise ValueError(f"Currently support targets for {modality_type}: {list(target_dist[modality_type].keys())}")
+
+ ms = mask_stats(mask, img)
+
+ ps = [stats.ks_1samp([ms[i]], stats.beta(param[0], param[1]).cdf).pvalue for i, param in enumerate(target_dist[modality_type][target])]
+ p_value = np.prod(ps)
+
+ adj_p_value = p_value**0.24 # adjustment for four test products
+
+ return adj_p_value
+
+
+
+def mask_stats(mask, img):
+ # mask is a prediction mask with pixel values in [0, 255] for probability in [0, 1]
+ # img is a RGB image with pixel values in [0, 255]
+ if mask.max() <= 127:
+ return [0, 0, 0, 0]
+ return [mask[mask>=128].mean()/256, img[:,:,0][mask>=128].mean()/256,
+ img[:,:,1][mask>=128].mean()/256, img[:,:,2][mask>=128].mean()/256]
+
+
+
+def combine_masks(predicts):
+ # predicts: a dictionary of pixel probability, {TARGET: pred_prob}
+ pixel_preds = {}
+ target_area = {}
+ target_probs = {}
+ for target in predicts:
+ pred = predicts[target]
+ pred_region = np.where(pred > 0.1)
+ target_area[target] = 0
+ target_probs[target] = 0
+ for (i,j) in zip(*pred_region):
+ if (i,j) not in pixel_preds:
+ pixel_preds[(i,j)] = {}
+ pixel_preds[(i,j)][target] = pred[i,j]
+ target_area[target] += 1
+ target_probs[target] += pred[i,j]
+ for target in predicts:
+ if target_area[target] == 0:
+ continue
+ target_probs[target] /= target_area[target]
+
+ # generate combined masks
+ combined_areas = {t: 0 for t in predicts}
+ for index in pixel_preds:
+ pred_target = sorted(pixel_preds[index].keys(), key=lambda t: pixel_preds[index][t], reverse=True)[0]
+ combined_areas[pred_target] += 1
+
+ # discard targets with small areas
+ discard_targets = []
+ for target in predicts:
+ if combined_areas[target] < 0.6 * target_area[target]:
+ discard_targets.append(target)
+
+ # keep the most confident target
+ most_confident_target = sorted(predicts.keys(), key=lambda t: target_probs[t], reverse=True)[0]
+
+ discard_targets = [t for t in discard_targets if t != most_confident_target]
+
+ masks = {t: np.zeros_like(predicts[t]).astype(np.uint8) for t in predicts if t not in discard_targets}
+ for index in pixel_preds:
+ candidates = [t for t in pixel_preds[index] if t not in discard_targets and pixel_preds[index][t] > 0.5]
+ if len(candidates) == 0:
+ continue
+ pred_target = max(candidates, key=lambda t: pixel_preds[index][t])
+ masks[pred_target][index[0], index[1]] = 1
+
+ return masks
\ No newline at end of file
diff --git a/inference_utils/processing_utils.py b/inference_utils/processing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d47ef98703ee65e69c3b87daf72db36f33c129e3
--- /dev/null
+++ b/inference_utils/processing_utils.py
@@ -0,0 +1,182 @@
+import numpy as np
+from skimage import transform
+import pydicom
+from io import BytesIO
+from PIL import Image
+import nibabel as nib
+import SimpleITK as sitk
+from skimage import measure
+
+
+"""
+ This script contains utility functions for reading and processing different imaging modalities.
+"""
+
+
+CT_WINDOWS = {'abdomen': [-150, 250],
+ 'lung': [-1000, 1000],
+ 'pelvis': [-55, 200],
+ 'liver': [-25, 230],
+ 'colon': [-68, 187],
+ 'pancreas': [-100, 200]}
+
+def process_intensity_image(image_data, is_CT, site=None):
+ # process intensity-based image. If CT, apply site specific windowing
+
+ # image_data: 2D numpy array of shape (H, W)
+
+ # return: 3-channel numpy array of shape (H, W, 3) as model input
+
+ if is_CT:
+ # process image with windowing
+ if site and site in CT_WINDOWS:
+ window = CT_WINDOWS[site]
+ else:
+ raise ValueError(f'Please choose CT site from {CT_WINDOWS.keys()}')
+ lower_bound, upper_bound = window
+ else:
+ # process image with intensity range 0.5-99.5 percentile
+ lower_bound, upper_bound = np.percentile(
+ image_data[image_data > 0], 0.5
+ ), np.percentile(image_data[image_data > 0], 99.5)
+
+ image_data_pre = np.clip(image_data, lower_bound, upper_bound)
+ image_data_pre = (
+ (image_data_pre - image_data_pre.min())
+ / (image_data_pre.max() - image_data_pre.min())
+ * 255.0
+ )
+
+ # pad to square with equal padding on both sides
+ shape = image_data_pre.shape
+ if shape[0] > shape[1]:
+ pad = (shape[0]-shape[1])//2
+ pad_width = ((0,0), (pad, pad))
+ elif shape[0] < shape[1]:
+ pad = (shape[1]-shape[0])//2
+ pad_width = ((pad, pad), (0,0))
+ else:
+ pad_width = None
+
+ if pad_width is not None:
+ image_data_pre = np.pad(image_data_pre, pad_width, 'constant', constant_values=0)
+
+ # resize image to 1024x1024
+ image_size = 1024
+ resize_image = transform.resize(image_data_pre, (image_size, image_size), order=3,
+ mode='constant', preserve_range=True, anti_aliasing=True)
+
+ # convert to 3-channel image
+ resize_image = np.stack([resize_image]*3, axis=-1)
+
+ return resize_image.astype(np.uint8)
+
+
+
+def read_dicom(image_path, is_CT, site=None):
+ # read dicom file and return pixel data
+
+ # dicom_file: str, path to dicom file
+ # is_CT: bool, whether image is CT or not
+ # site: str, one of CT_WINDOWS.keys()
+ # return: 2D numpy array of shape (H, W)
+
+ ds = pydicom.dcmread(image_path)
+ image_array = ds.pixel_array * ds.RescaleSlope + ds.RescaleIntercept
+
+ image_array = process_intensity_image(image_array, is_CT, site)
+
+ return image_array
+
+
+def read_nifti(image_path, is_CT, slice_idx, site=None, HW_index=(0, 1), channel_idx=None):
+ # read nifti file and return pixel data
+
+ # image_path: str, path to nifti file
+ # is_CT: bool, whether image is CT or not
+ # slice_idx: int, slice index to read
+ # site: str, one of CT_WINDOWS.keys()
+ # HW_index: tuple, index of height and width in the image shape
+ # return: 2D numpy array of shape (H, W)
+
+
+ nii = nib.load(image_path)
+ image_array = nii.get_fdata()
+
+ if HW_index != (0, 1):
+ image_array = np.moveaxis(image_array, HW_index, (0, 1))
+
+ # get slice
+ if channel_idx is None:
+ image_array = image_array[:, :, slice_idx]
+ else:
+ image_array = image_array[:, :, slice_idx, channel_idx]
+
+ image_array = process_intensity_image(image_array, is_CT, site)
+ return image_array
+
+
+
+def read_rgb(image_path):
+ # read RGB image and return resized pixel data
+
+ # image_path: str, path to RGB image
+ # return: BytesIO buffer
+
+ # read image into numpy array
+ image = Image.open(image_path)
+ image = np.array(image)
+ if len(image.shape) == 2:
+ image = np.stack([image]*3, axis=-1)
+ elif image.shape[2] == 4:
+ image = image[:,:,:3]
+
+ # pad to square with equal padding on both sides
+ shape = image.shape
+ if shape[0] > shape[1]:
+ pad = (shape[0]-shape[1])//2
+ pad_width = ((0,0), (pad, pad), (0,0))
+ elif shape[0] < shape[1]:
+ pad = (shape[1]-shape[0])//2
+ pad_width = ((pad, pad), (0,0), (0,0))
+ else:
+ pad_width = None
+
+ if pad_width is not None:
+ image = np.pad(image, pad_width, 'constant', constant_values=0)
+
+ # resize image to 1024x1024 for each channel
+ image_size = 1024
+ resize_image = np.zeros((image_size, image_size, 3), dtype=np.uint8)
+ for i in range(3):
+ resize_image[:,:,i] = transform.resize(image[:,:,i], (image_size, image_size), order=3,
+ mode='constant', preserve_range=True, anti_aliasing=True)
+
+ return resize_image
+
+
+
+def get_instances(mask):
+ # get intances from binary mask
+ seg = sitk.GetImageFromArray(mask)
+ filled = sitk.BinaryFillhole(seg)
+ d = sitk.SignedMaurerDistanceMap(filled, insideIsPositive=False, squaredDistance=False, useImageSpacing=False)
+
+ ws = sitk.MorphologicalWatershed( d, markWatershedLine=False, level=1)
+ ws = sitk.Mask( ws, sitk.Cast(seg, ws.GetPixelID()))
+ ins_mask = sitk.GetArrayFromImage(ws)
+
+ # filter out instances with small area outliers
+ props = measure.regionprops_table(ins_mask, properties=('label', 'area'))
+ mean_area = np.mean(props['area'])
+ std_area = np.std(props['area'])
+
+ threshold = mean_area - 2*std_area - 1
+ ins_mask_filtered = ins_mask.copy()
+ for i, area in zip(props['label'], props['area']):
+ if area < threshold:
+ ins_mask_filtered[ins_mask == i] = 0
+
+ return ins_mask_filtered
+
+
\ No newline at end of file
diff --git a/inference_utils/target_dist.json b/inference_utils/target_dist.json
new file mode 100644
index 0000000000000000000000000000000000000000..847cccd74a872fbc0d2ec1aac5210c54887339ab
--- /dev/null
+++ b/inference_utils/target_dist.json
@@ -0,0 +1 @@
+{"CT-Abdomen": {"postcava": [[244.8001455798728, 5.314270814858824], [7.183679633251858, 5.168810995426391], [7.183679633251858, 5.168810995426391], [7.183679633251858, 5.168810995426391]], "aorta": [[570.5260544851909, 8.97527503179567], [3.3715049586348242, 1.4971164544774238], [3.3715049586348242, 1.4971164544774238], [3.3715049586348242, 1.4971164544774238]], "right kidney": [[831.8568013426873, 14.991866448573818], [4.970270375121704, 3.050385928796316], [4.970270375121704, 3.050385928796316], [4.970270375121704, 3.050385928796316]], "kidney": [[824.7288483151449, 17.740666994112335], [5.134294543833492, 3.188304874790919], [5.134294543833492, 3.188304874790919], [5.134294543833492, 3.188304874790919]], "left kidney": [[765.9269280548916, 14.314482540419498], [5.084499568327313, 3.2061871556243515], [5.084499568327313, 3.2061871556243515], [5.084499568327313, 3.2061871556243515]], "duodenum": [[121.5002253116006, 5.0616837393558045], [13.60882943690214, 15.313999640884173], [13.60882943690214, 15.313999640884173], [13.60882943690214, 15.313999640884173]], "pancreas": [[182.85416969377923, 6.9039775525067135], [17.489564177159146, 14.924761571311656], [17.489564177159146, 14.924761571311656], [17.489564177159146, 14.924761571311656]], "liver (non abdomen window)": [[481.5690096331249, 8.413924027868077], [6.047563882283547, 6.86712354789198], [6.047563882283547, 6.86712354789198], [6.047563882283547, 6.86712354789198]], "liver": [[497.88613290346797, 8.79208581405346], [20.552757782824486, 16.312687320589742], [20.552757782824486, 16.312687320589742], [20.552757782824486, 16.312687320589742]], "spleen": [[496.77984794364835, 8.498216025126785], [14.594250163059534, 10.71357260923987], [14.594250163059534, 10.71357260923987], [14.594250163059534, 10.71357260923987]], "stomach": [[137.7555592980079, 3.928159238756134], [5.978844398494112, 10.238758157160921], [5.978844398494112, 10.238758157160921], [5.978844398494112, 10.238758157160921]], "gallbladder": [[109.56988864543307, 3.4765854683723596], [32.35084093358493, 41.113482214152384], [32.35084093358493, 41.113482214152384], [32.35084093358493, 41.113482214152384]], "left adrenal gland": [[121.60075395406241, 4.266683492995461], [17.017417548383662, 18.48528509828753], [17.017417548383662, 18.48528509828753], [17.017417548383662, 18.48528509828753]], "adrenal gland": [[182.4265613513338, 7.813186080282246], [18.97442893128976, 20.599617257380345], [18.97442893128976, 20.599617257380345], [18.97442893128976, 20.599617257380345]], "right adrenal gland": [[158.21570288963346, 5.736947411814261], [17.17089273745977, 19.09450167978653], [17.17089273745977, 19.09450167978653], [17.17089273745977, 19.09450167978653]], "bladder": [[172.667607742299, 4.6885066612866835], [42.56984081338662, 56.45115036285909], [42.56984081338662, 56.45115036285909], [42.56984081338662, 56.45115036285909]], "esophagus": [[253.86092392814248, 6.886078359154348], [13.252110919965341, 15.437200766467301], [13.252110919965341, 15.437200766467301], [13.252110919965341, 15.437200766467301]]}, "CT-Chest": {"nodule": [[115.14726334918862, 3.0043952160348844], [5.275338876748403, 7.899248653413393], [5.275338876748403, 7.899248653413393], [5.275338876748403, 7.899248653413393]], "COVID-19 infection": [[226.93782607812352, 10.662200522447263], [11.74323002038987, 23.773784082857407], [11.74323002038987, 23.773784082857407], [11.74323002038987, 23.773784082857407]], "tumor": [[81.39154648592063, 3.0363381821985254], [9.799683628807484, 19.248706134279548], [9.799683628807484, 19.248706134279548], [9.799683628807484, 19.248706134279548]]}, "MRI-Abdomen": {"aorta": [[840.9822169946456, 13.699556855062456], [2.9798604461548766, 1.19765659474954], [2.9798604461548766, 1.19765659474954], [2.9798604461548766, 1.19765659474954]], "postcava": [[151.3891903352374, 4.700455115571472], [3.065810750535689, 2.074722812609995], [3.065810750535689, 2.074722812609995], [3.065810750535689, 2.074722812609995]], "right kidney": [[613.4017011464975, 11.282616103318485], [4.63815461741129, 2.2967740371944867], [4.63815461741129, 2.2967740371944867], [4.63815461741129, 2.2967740371944867]], "duodenum": [[88.51851857758399, 5.251374959142798], [9.350910364523573, 8.85976960554745], [9.350910364523573, 8.85976960554745], [9.350910364523573, 8.85976960554745]], "kidney": [[831.5762248415444, 18.739059302777875], [5.715871882386201, 2.6205541393599527], [5.715871882386201, 2.6205541393599527], [5.715871882386201, 2.6205541393599527]], "left kidney": [[255.4744196400276, 5.573793361388763], [6.081920320421431, 2.930383603114708], [6.081920320421431, 2.930383603114708], [6.081920320421431, 2.930383603114708]], "liver": [[491.1931789168259, 9.294627086787225], [10.138029098677139, 6.28829088692463], [10.138029098677139, 6.28829088692463], [10.138029098677139, 6.28829088692463]], "pancreas": [[136.2304629992425, 5.676744286342953], [19.631392824605342, 11.528214201070567], [19.631392824605342, 11.528214201070567], [19.631392824605342, 11.528214201070567]], "gallbladder": [[75.18767252055355, 2.8711737605829892], [14.500831537679415, 20.696868858705496], [14.500831537679415, 20.696868858705496], [14.500831537679415, 20.696868858705496]], "stomach": [[89.16380420023327, 4.461224829090838], [10.266772743753412, 16.943404348738376], [10.266772743753412, 16.943404348738376], [10.266772743753412, 16.943404348738376]], "spleen": [[413.92566589639046, 7.99961594912814], [7.267087388529462, 5.149714876028216], [7.267087388529462, 5.149714876028216], [7.267087388529462, 5.149714876028216]], "left adrenal gland": [[86.44109991236728, 4.826813402237061], [17.153928230900817, 14.858036650050408], [17.153928230900817, 14.858036650050408], [17.153928230900817, 14.858036650050408]], "adrenal gland": [[303.9642820935704, 16.729857009916806], [19.500678047021523, 17.02588768312544], [19.500678047021523, 17.02588768312544], [19.500678047021523, 17.02588768312544]], "right adrenal gland": [[172.36803145644578, 8.050377438528958], [15.257519917725558, 13.431078702905772], [15.257519917725558, 13.431078702905772], [15.257519917725558, 13.431078702905772]], "esophagus": [[193.1348898340059, 7.6397334220243325], [12.240331385391299, 16.812971132953354], [12.240331385391299, 16.812971132953354], [12.240331385391299, 16.812971132953354]]}, "MRI-Cardiac": {"left heart ventricle": [[964.9072936969454, 17.21177762137991], [5.880290818671821, 4.100959742819713], [5.880290818671821, 4.100959742819713], [5.880290818671821, 4.100959742819713]], "myocardium": [[448.3393673888417, 17.591805257426998], [5.208511169313307, 15.910705163394415], [5.208511169313307, 15.910705163394415], [5.208511169313307, 15.910705163394415]], "right heart ventricle": [[359.88937669636215, 9.392153523781843], [5.924076424141962, 5.554667293878979], [5.924076424141962, 5.554667293878979], [5.924076424141962, 5.554667293878979]]}, "MRI-FLAIR-Brain": {"edema": [[69.4159007224176, 5.568921766085619], [13.400334168570177, 4.965265405638592], [13.400334168570177, 4.965265405638592], [13.400334168570177, 4.965265405638592]], "tumor core": [[154.26935124167449, 8.089254912853598], [14.908340542645478, 4.820086393609397], [14.908340542645478, 4.820086393609397], [14.908340542645478, 4.820086393609397]], "whole tumor": [[485.48717118600956, 16.01178236475156], [25.74323915508559, 8.636438181178145], [25.74323915508559, 8.636438181178145], [25.74323915508559, 8.636438181178145]]}, "MRI-T1-Gd-Brain": {"enhancing tumor": [[175.6437881777937, 7.539344668413025], [17.864705093992068, 5.36432831714689], [17.864705093992068, 5.36432831714689], [17.864705093992068, 5.36432831714689]], "non-enhancing tumor": [[37.6625733247702, 3.8454536110058246], [6.568014639412233, 8.446289690167484], [6.568014639412233, 8.446289690167484], [6.568014639412233, 8.446289690167484]], "tumor core": [[180.88223552813486, 6.610443841067055], [9.70294999498087, 5.30262880784197], [9.70294999498087, 5.30262880784197], [9.70294999498087, 5.30262880784197]]}, "Pathology": {"connective tissue cells": [[46.71165884847293, 4.997126203483956], [9.942495884846476, 15.700775443760845], [4.328453739888501, 18.42621798468577], [9.798096322131162, 11.920352021312304]], "inflammatory cells": [[39.600337990197595, 3.1848025413959706], [6.287418328538852, 20.538379638162322], [2.9521703595392146, 25.264465092284006], [6.559595490616054, 12.004686961917436]], "neoplastic cells": [[82.29374052289526, 8.22429924322936], [9.592296798563375, 14.818916788142138], [4.948629785308088, 19.78516221506478], [10.729094314024243, 12.934345198477494]], "epithelial cells": [[91.75183574899573, 9.577544361042948], [13.469843493323452, 27.305962287612964], [4.696928248406198, 25.254143364646463], [11.077634907582583, 13.487595094752443]]}, "X-Ray-Chest": {"left lung": [[529.1669758355144, 7.465035502868491], [8.220284641505614, 11.62958600654364], [8.220284641505614, 11.62958600654364], [8.220284641505614, 11.62958600654364]], "lung": [[465.7809501354513, 7.147122106450173], [8.781306299078446, 12.335455073688102], [8.781306299078446, 12.335455073688102], [8.781306299078446, 12.335455073688102]], "right lung": [[567.6127039725319, 7.532428563004494], [8.067311420424144, 11.229763331648746], [8.067311420424144, 11.229763331648746], [8.067311420424144, 11.229763331648746]]}, "Ultrasound-Cardiac": {"left heart atrium": [[1188.687550702627, 24.234766943758856], [5.18832820435626, 13.705576921752291], [5.18832820435626, 13.705576921752291], [5.18832820435626, 13.705576921752291]], "left heart ventricle": [[2787.334986695437, 58.297232816307506], [15.28158405889985, 56.95469460140377], [15.28158405889985, 56.95469460140377], [15.28158405889985, 56.95469460140377]]}, "Endoscopy": {"neoplastic polyp": [[392.89875472390315, 5.4678888279040745], [7.477729277754545, 1.6522601344780465], [7.2704247484339035, 6.347521355120636], [4.3902399436060335, 6.543658310376327]], "polyp": [[163.7838288028474, 3.4851615302599117], [7.03659746479883, 1.9088902542177986], [6.992807172875011, 6.756628353721484], [5.185761648208865, 8.977427344868255]], "non-neoplastic polyp": [[214.9199548332033, 4.360826895414348], [7.303363948417486, 1.9789835935004905], [10.54652900087687, 9.009706115553772], [6.917879576439251, 10.404634951284532]]}, "Fundus": {"optic cup": [[1482.9561484784422, 35.78105120937013], [52.1031548324398, 1.5080077510381715], [10.023538467761934, 3.1641925551155046], [3.394564722036805, 2.4391933423559626]], "optic disc": [[626.9141229495486, 20.95002931507066], [18.278454005466408, 1.8261365514325893], [16.42282430959315, 11.171338052048034], [4.8937792939550135, 6.987302868644637]]}, "Dermoscopy": {"lesion": [[134.43456931870887, 4.743684855379663], [5.18053578956456, 2.3527492367343634], [3.809383004477107, 6.368793378843402], [2.3888068456218847, 6.655396307215968]], "melanoma": [[454.17848530764076, 9.6466178116726], [4.022144360826467, 7.870140640677671], [4.87109613458874, 18.93721534855073], [3.107895746664011, 13.604075970992069]]}, "OCT": {"edema": [[260.11475018501574, 7.379315940573871], [4.162158474003, 17.437425953761988], [12.65808078622105, 81.37165793634547], [1.763378481483125, 4.427309203795247]]}}
\ No newline at end of file
diff --git a/modeling/BaseModel.py b/modeling/BaseModel.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb96fe7d5a3d8e89148d00182788884d1e2abd42
--- /dev/null
+++ b/modeling/BaseModel.py
@@ -0,0 +1,45 @@
+import os
+import logging
+
+import torch
+import torch.nn as nn
+
+from utilities.model import align_and_update_state_dicts
+
+from utilities.distributed import init_distributed
+from utilities.arguments import load_opt_from_config_files
+
+import huggingface_hub
+
+logger = logging.getLogger(__name__)
+
+
+class BaseModel(nn.Module):
+ def __init__(self, opt, module: nn.Module):
+ super(BaseModel, self).__init__()
+ self.opt = opt
+ self.model = module
+
+ def forward(self, *inputs, **kwargs):
+ outputs = self.model(*inputs, **kwargs)
+ return outputs
+
+ def save_pretrained(self, save_dir):
+ torch.save(self.model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))
+
+ def from_pretrained(self, pretrained, filename: str = "biomedparse_v1.pt",
+ local_dir: str = "./pretrained", config_dir: str = "./configs"):
+ if pretrained.startswith("hf_hub:"):
+ hub_name = pretrained.split(":")[1]
+ huggingface_hub.hf_hub_download(hub_name, filename=filename,
+ local_dir=local_dir)
+ huggingface_hub.hf_hub_download(hub_name, filename="config.yaml",
+ local_dir=config_dir)
+ load_dir = os.path.join(local_dir, filename)
+ else:
+ load_dir = pretrained
+
+ state_dict = torch.load(load_dir, map_location=self.opt['device'])
+ state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
+ self.model.load_state_dict(state_dict, strict=False)
+ return self
\ No newline at end of file
diff --git a/modeling/__init__.py b/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a3487693f2a738d9106e74518fb35114faefb27
--- /dev/null
+++ b/modeling/__init__.py
@@ -0,0 +1 @@
+from .architectures import build_model
\ No newline at end of file
diff --git a/modeling/architectures/__init__.py b/modeling/architectures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73256af41ec8a6db085e65083f4ec84429138891
--- /dev/null
+++ b/modeling/architectures/__init__.py
@@ -0,0 +1,5 @@
+from .xdecoder_model import *
+from .seem_model_v0 import *
+from .seem_model_v1 import *
+from .seem_model_demo import *
+from .build import build_model
\ No newline at end of file
diff --git a/modeling/architectures/build.py b/modeling/architectures/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..33bc340601d9fa3369d2c38711b8374662502002
--- /dev/null
+++ b/modeling/architectures/build.py
@@ -0,0 +1,22 @@
+_model_entrypoints = {}
+
+
+def build_model(config, **kwargs):
+ model_name = config['MODEL']['NAME']
+
+ if not is_model(model_name):
+ raise ValueError(f'Unkown model: {model_name}')
+
+ return model_entrypoints(model_name)(config, **kwargs)
+
+def register_model(fn):
+ module_name_split = fn.__module__.split('.')
+ model_name = module_name_split[-1]
+ _model_entrypoints[model_name] = fn
+ return fn
+
+def model_entrypoints(model_name):
+ return _model_entrypoints[model_name]
+
+def is_model(model_name):
+ return model_name in _model_entrypoints
\ No newline at end of file
diff --git a/modeling/architectures/seem_model_demo.py b/modeling/architectures/seem_model_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1786d575c7b893499b8aa8ea39f2939485829a2
--- /dev/null
+++ b/modeling/architectures/seem_model_demo.py
@@ -0,0 +1,923 @@
+# --------------------------------------------------------
+# SEEM -- Segment Everything Everywhere All at Once
+# Licensed under The Apache License 2.0 [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import random
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from kornia.contrib import distance_transform
+
+from detectron2.structures import Boxes, ImageList, Instances, BitMasks
+from detectron2.utils.memory import retry_if_cuda_oom
+from detectron2.data import MetadataCatalog
+
+from .build import register_model
+
+from ..utils import configurable, get_class_names, get_iou
+from ..vision.backbone import build_backbone, Backbone
+from ..body import build_xdecoder_head
+from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
+from ..language import build_language_encoder
+from ..language.loss import vl_similarity
+from utilities.prompt_engineering import prompt_engineering
+from utilities.constants import COCO_PANOPTIC_CLASSES
+
+
+class GeneralizedSEEM(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ *,
+ backbone: Backbone,
+ sem_seg_head: nn.Module,
+ criterion: nn.Module,
+ losses: dict,
+ num_queries: int,
+ object_mask_threshold: float,
+ overlap_threshold: float,
+ metadata,
+ task_switch: dict,
+ phrase_prob: float,
+ size_divisibility: int,
+ sem_seg_postprocess_before_inference: bool,
+ pixel_mean: Tuple[float],
+ pixel_std: Tuple[float],
+ # inference
+ semantic_on: bool,
+ panoptic_on: bool,
+ instance_on: bool,
+ test_topk_per_image: int,
+ train_dataset_name: str,
+ interactive_mode: str,
+ interactive_iter: str,
+ dilation_kernel: torch.Tensor,
+ ):
+ super().__init__()
+ self.backbone = backbone
+ self.sem_seg_head = sem_seg_head
+ self.criterion = criterion
+ self.losses = losses
+ self.num_queries = num_queries
+ self.overlap_threshold = overlap_threshold
+ self.object_mask_threshold = object_mask_threshold
+ self.metadata = metadata
+ if size_divisibility < 0:
+ # use backbone size_divisibility if not set
+ size_divisibility = self.backbone.size_divisibility
+ self.size_divisibility = size_divisibility
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ # additional args
+ self.semantic_on = semantic_on
+ self.instance_on = instance_on
+ self.panoptic_on = panoptic_on
+
+ # caption argument
+ self.task_switch = task_switch
+ self.phrase_prob = phrase_prob
+
+ self.test_topk_per_image = test_topk_per_image
+ self.train_class_names = None
+ self.interactive_mode = interactive_mode
+ self.interactive_iter = interactive_iter
+
+ if not self.semantic_on:
+ assert self.sem_seg_postprocess_before_inference
+
+ self.register_buffer("dilation_kernel", dilation_kernel)
+
+ @classmethod
+ def from_config(cls, cfg):
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
+ 'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
+
+ task_switch = {'bbox': dec_cfg.get('DETECTION', False),
+ 'mask': dec_cfg.get('MASK', True),
+ 'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
+ 'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
+ 'openimage': openimage_switch,
+ 'visual': dec_cfg['VISUAL'].get('ENABLED', False),
+ 'audio': dec_cfg['AUDIO'].get('ENABLED', False)}
+
+ # build model
+ extra = {'task_switch': task_switch}
+ backbone = build_backbone(cfg)
+ lang_encoder = build_language_encoder(cfg)
+ sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
+
+ # Training Settings.
+ loss_weights = {}
+ matcher = None
+ losses = {}
+ weight_dict = {}
+ grd_weight = {}
+ top_x_layers = {}
+ criterion = None
+ train_dataset_name = None
+ phrase_prob = None
+ # Loss parameters:
+ deep_supervision = None
+ no_object_weight = None
+
+ interactive_mode = 'best'
+ interactive_iter = 20
+ dilation = 3
+ dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
+
+ return {
+ "backbone": backbone,
+ "sem_seg_head": sem_seg_head,
+ "criterion": criterion,
+ "losses": losses,
+ "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
+ "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
+ "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
+ "metadata": None,
+ "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
+ "sem_seg_postprocess_before_inference": (
+ dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
+ or dec_cfg['TEST']['PANOPTIC_ON']
+ or dec_cfg['TEST']['INSTANCE_ON']
+ ),
+ "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
+ "pixel_std": cfg['INPUT']['PIXEL_STD'],
+ "task_switch": task_switch,
+ "phrase_prob": phrase_prob,
+ # inference
+ "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
+ "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
+ "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
+ "test_topk_per_image": cfg['MODEL']['DECODER']['TEST']['DETECTIONS_PER_IMAGE'],
+ "train_dataset_name": train_dataset_name,
+ "interactive_mode": interactive_mode,
+ "interactive_iter": interactive_iter,
+ "dilation_kernel": dilation_kernel,
+ }
+
+ @property
+ def device(self):
+ return self.pixel_mean.device
+
+ def forward(self, batched_inputs, mode='default'):
+ if self.training:
+ losses = {}
+ if self.task_switch['mask']:
+ losses_seg = self.forward_seg(batched_inputs)
+ losses.update(losses_seg)
+ if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
+ losses_openimage = self.forward_openimage(batched_inputs['openimage'])
+ losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
+ losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
+ losses.update(losses_openimage)
+ for k in list(losses.keys()):
+ if k in self.criterion.weight_dict:
+ losses[k] *= self.criterion.weight_dict[k]
+ else: # remove this loss if not specified in `weight_dict`
+ losses.pop(k)
+ return losses
+ else:
+ if mode == 'interactive':
+ return self.evaluate_interactive(batched_inputs)
+ elif mode == 'grounding_spatial':
+ return self.evaluate_grounding_sptial(batched_inputs, mode)
+ elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
+ return self.evaluate_grounding(batched_inputs, mode)
+ else:
+ return self.evaluate(batched_inputs)
+
+
+ def forward_seg(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
+
+ extra = {}
+ # mask classification target
+ if "instances" in batched_inputs[0]:
+ # input bounding box is checked to be correct.
+ targets = self.prepare_targets(batched_inputs, images)
+
+ if self.task_switch['grounding']:
+ grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
+ grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
+ non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
+ grounding_tokens[non_zero_query_mask] = 0
+
+ extra['grounding_tokens'] = grounding_tokens
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ if self.task_switch['spatial']:
+ pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
+ neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
+ fp_masks = torch.stack([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs])
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})
+
+ features = self.backbone(images.tensor)
+ mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ # forward spatial only without gradient
+ if self.task_switch['spatial']:
+ with torch.no_grad():
+ # generate random integeter between [0,3]
+ rand_iter_num = random.randint(0, 2)
+ for i in range(rand_iter_num):
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
+ extra.update(outputs)
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
+
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
+ extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
+ 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
+ 'false_positive_mask': extra['false_positive_mask']}
+ # bipartite matching-based loss
+ self.criterion.losses = self.losses['seg'] # seg criterion losses
+ losses = self.criterion(outputs, targets, extra)
+
+ del outputs
+ return losses
+
+ def evaluate_demo(self, batched_inputs):
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+
+ extra = {}
+ if 'stroke' in batched_inputs[0]:
+ pos_masks = (batched_inputs[0]['stroke'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+ neg_masks = (batched_inputs[0]['stroke'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+
+ if 'visual' in batched_inputs[0]:
+ extra.update(batched_inputs[0]['visual'])
+
+ if 'text' in batched_inputs[0]:
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(batched_inputs[0]['text'], name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+ query_emb = token_emb[tokens['attention_mask'].bool()]
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
+ extra['grounding_tokens'] = query_emb[:,None]
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+ extra['grounding_class'] = gtext['class_emb']
+
+ if 'audio' in batched_inputs[0]:
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(batched_inputs[0]['audio'], name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+ query_emb = token_emb[tokens['attention_mask'].bool()]
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
+ extra['audio_tokens'] = query_emb[:,None]
+ extra['audio_nonzero_mask'] = non_zero_query_mask.t()
+ extra['audio_class'] = gtext['class_emb']
+
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='demo')
+ return outputs, images.tensor.shape, extra
+
+ assert self.task_switch['spatial']
+ assert 'spatial_query' in batched_inputs[0]
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ extra = {}
+
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ all_batch_shape_iou = []
+ pred_smask_pointer = None
+ prev_smask_pointer = None
+ pred_smask_all = None
+
+ query_index = self.sem_seg_head.predictor.query_index
+ assert self.interactive_mode == 'best'
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+
+ for i in range(self.interactive_iter):
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
+ extra.update(outputs)
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
+
+ s = image_sizes[0]
+ b = batched_inputs[0]
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
+ gt_smask = b['gt_masks_orisize']
+ all_batch_shape_iou += [get_iou(gt_smask, pred_smask_all)]
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
+
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
+ return processed_results
+
+ def evaluate(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
+
+ mask_cls_results = outputs["pred_logits"]
+ mask_pred_results = outputs["pred_masks"]
+ box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
+
+ # upsample masks
+ mask_pred_results = F.interpolate(
+ mask_pred_results,
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ input_size = mask_pred_results.shape[-2:]
+ del outputs
+
+ processed_results = []
+ for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
+ mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ if self.sem_seg_postprocess_before_inference:
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
+
+ # semantic segmentation inference
+ if self.semantic_on:
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
+ if not self.sem_seg_postprocess_before_inference:
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
+ processed_results[-1]["sem_seg"] = r
+
+ # panoptic segmentation inference
+ if self.panoptic_on:
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
+ processed_results[-1]["panoptic_seg"] = panoptic_r
+
+ # instance segmentation inference
+ if self.instance_on:
+ if self.task_switch['bbox']:
+ box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
+ processed_results[-1]["instances"] = instance_r
+
+ return processed_results
+
+ def evaluate_interactive(self, batched_inputs):
+ assert self.task_switch['spatial']
+ assert 'spatial_query' in batched_inputs[0]
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ extra = {}
+
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ all_batch_shape_iou = []
+ pred_smask_pointer = None
+ prev_smask_pointer = None
+ pred_smask_all = None
+
+ query_index = self.sem_seg_head.predictor.query_index
+ assert self.interactive_mode == 'best'
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+
+ for i in range(self.interactive_iter):
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
+ extra.update(outputs)
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
+
+ s = image_sizes[0]
+ b = batched_inputs[0]
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
+ gt_smask = b['gt_masks_orisize']
+ all_batch_shape_iou += [get_iou(gt_smask, pred_smask_all)]
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
+
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
+ return processed_results
+
+ def evaluate_referring_image(self, batched_inputs, extra={}):
+ assert self.task_switch['spatial']
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+ assert self.interactive_mode == 'best'
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ if 'spatial_query' in batched_inputs[0]:
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ query_index = self.sem_seg_head.predictor.query_index
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
+ return outputs, images.tensor.shape
+
+ def evaluate_grounding(self, batched_inputs, mode):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+
+ extra = {}
+ # mask_pred_results = []
+ # for idx, batch_per_image in enumerate(batched_inputs):
+ # grd_texts = batch_per_image['groundings']['texts']
+ # grd_masks = []
+ # for anno_text in grd_texts:
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
+ # token_emb = gtext['token_emb']
+ # tokens = gtext['tokens']
+
+ # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
+ # extra['grounding_tokens'] = grd_emb[:,None]
+
+ # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+ # features = self.backbone(images.tensor)
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
+ # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
+ # t_emb = grd_emb[-1:]
+
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ # matched_id = out_prob.max(0)[1]
+ # grd_masks += [pred_gmasks[matched_id,:,:]]
+ # mask_pred_results += [torch.cat(grd_masks)]
+
+ # comment for multi object inference.
+ mask_pred_results = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_texts = [x[0] for x in grd_texts]
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+ query_emb = token_emb[tokens['attention_mask'].bool()]
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
+
+ extra['grounding_tokens'] = query_emb[:,None]
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ pred_gmasks = outputs['pred_gmasks'][idx]
+ v_emb = outputs['pred_gtexts'][idx]
+ t_emb = gtext['class_emb']
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ matched_id = out_prob.max(0)[1]
+ mask_pred_results += [pred_gmasks[matched_id,:,:]]
+
+ for i in range(len(mask_pred_results)):
+ # upsample masks
+ mask_pred_results[i] = F.interpolate(
+ mask_pred_results[i][None,],
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bilinear",
+ align_corners=False,
+ )[0]
+
+ processed_results = []
+ for mask_pred_result, input_per_image, image_size in zip(
+ mask_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ processed_results[-1]['grounding_mask'] = mask_pred_result
+
+ # compute bbox
+ # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
+ # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+ # processed_results[-1]['grounding_box'] = bbox
+
+ return processed_results
+
+ def evaluate_grounding_sptial(self, batched_inputs, mode):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+
+ extra = {}
+ dilation = 3
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
+ pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+
+ mask_pred_results = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_masks = []
+ for idx2, anno_text in enumerate(grd_texts):
+ extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+
+ grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
+ non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
+ extra['grounding_tokens'] = grd_emb[:,None]
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ pred_gmasks = outputs['pred_gmasks'][idx]
+ v_emb = outputs['pred_gtexts'][idx]
+ t_emb = gtext['class_emb']
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ matched_id = out_prob.max(0)[1]
+ grd_masks += [pred_gmasks[matched_id,:,:]]
+ mask_pred_results += [torch.cat(grd_masks)]
+
+ # comment for multi object inference.
+ # mask_pred_results = []
+ # for idx, batch_per_image in enumerate(batched_inputs):
+ # grd_texts = batch_per_image['groundings']['texts']
+ # grd_texts = [x[0] for x in grd_texts]
+
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ # token_emb = gtext['token_emb']
+ # tokens = gtext['tokens']
+ # query_emb = token_emb[tokens['attention_mask'].bool()]
+ # non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
+
+ # extra['grounding_tokens'] = query_emb[:,None]
+ # extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ # features = self.backbone(images.tensor)
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ # pred_gmasks = outputs['pred_gmasks'][idx]
+ # v_emb = outputs['pred_gtexts'][idx]
+ # t_emb = gtext['class_emb']
+
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ # matched_id = out_prob.max(0)[1]
+ # mask_pred_results += [pred_gmasks[matched_id,:,:]]
+
+ for i in range(len(mask_pred_results)):
+ # upsample masks
+ mask_pred_results[i] = F.interpolate(
+ mask_pred_results[i][None,],
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bilinear",
+ align_corners=False,
+ )[0]
+
+ processed_results = []
+ for mask_pred_result, input_per_image, image_size in zip(
+ mask_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ processed_results[-1]['grounding_mask'] = mask_pred_result
+
+ return processed_results
+
+ def prepare_targets(self, batched_inputs, images):
+ h_pad, w_pad = images.tensor.shape[-2:]
+ new_targets = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ targets_per_image = batch_per_image['instances'].to(self.device)
+ # pad gt
+ gt_masks = targets_per_image.gt_masks.tensor
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+
+ gt_boxes = targets_per_image.gt_boxes.tensor
+ ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
+ gt_boxes = gt_boxes / ratio
+ xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
+ gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
+
+ target_dict = {
+ "labels": targets_per_image.gt_classes,
+ "is_things": targets_per_image.is_things,
+ "masks": padded_masks,
+ "boxes": gt_boxes,
+ }
+
+ if self.task_switch['spatial']:
+ # prepare targets for spatial query
+ target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
+
+ if self.task_switch['grounding']:
+ grd_masks = batch_per_image['groundings']['masks']
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_hash = batch_per_image['groundings']['hash']
+ grd_task = batch_per_image['groundings']['mode']
+
+ if len(grd_masks) == 0:
+ padded_masks = None
+ else:
+ padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
+ padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+
+ unique_hash_id = np.unique(grd_hash, return_index=True)[1]
+ selected_mask = np.zeros(len(grd_hash)).astype(bool)
+ selected_mask[unique_hash_id] = True
+
+ selected_token_emb = token_emb[selected_mask]
+ selected_attn_mask = tokens['attention_mask'][selected_mask]
+ query_emb = selected_token_emb[selected_attn_mask.bool()]
+
+ class_idx = tokens['attention_mask'].sum(dim=-1) - 1
+ class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
+ class_emb = token_emb[class_idx]
+
+ target_dict['grounding_masks'] = padded_masks
+ target_dict['grounding_query_embs'] = query_emb
+ target_dict['grounding_class_embs'] = class_emb
+ target_dict['grounding_hash'] = grd_hash
+ target_dict['grounding_task'] = grd_task
+
+ new_targets.append(target_dict)
+ return new_targets
+
+ def prepare_next_spaital_mask(self, outputs, batched_inputs):
+ gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
+ if self.training:
+ gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
+ else:
+ gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor.transpose(0,1)
+
+ pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
+ prev_masks = torch.stack(outputs['spatial_query_pos_mask']) | torch.stack(outputs['spatial_query_neg_mask'])
+
+ fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
+ fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
+
+ # compute iou between gt and pred
+ iou = (gt_masks & pred_masks).sum(list(range(1,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(1,len(fn.shape)))) + 1e-8)
+ fn_sum = fn.sum(dim=list(range(1,len(fn.shape))))
+ fp_sum = fp.sum(dim=list(range(1,len(fp.shape))))
+
+ is_postive = fn_sum > fp_sum
+ # is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
+ select_mask = torch.stack([fn[i] if is_postive[i] else fp[i] for i in range(len(fn))])
+
+ # conv implementation
+ n,_,h,w=select_mask.shape
+ mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
+ next_mask = next_mask.view(n,-1)
+ next_mask[max_xy_idx] = True
+ next_mask = next_mask.reshape((n,1,h,w)).float()
+ dilation = 3
+ next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2) > 0
+
+ # determine whether next mask is zero
+ keep = (iou < 0.925)
+ next_mask = next_mask & keep.view(-1,1,1,1)
+
+ pos_mask = []
+ neg_mask = []
+ for idx, ip in enumerate(is_postive):
+ if ip:
+ pos_mask += [outputs['spatial_query_pos_mask'][idx] | next_mask[idx]]
+ neg_mask += [outputs['spatial_query_neg_mask'][idx]]
+ else:
+ pos_mask += [outputs['spatial_query_pos_mask'][idx]]
+ neg_mask += [outputs['spatial_query_neg_mask'][idx] | next_mask[idx]]
+
+ if 'false_positive_mask' in outputs:
+ fp = outputs['false_positive_mask'] | fp
+ return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
+
+ def semantic_inference(self, mask_cls, mask_pred):
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+ mask_pred = mask_pred.sigmoid()
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
+ return semseg
+
+ def panoptic_inference(self, mask_cls, mask_pred):
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
+ mask_pred = mask_pred.sigmoid()
+
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
+ cur_scores = scores[keep]
+ cur_classes = labels[keep]
+ cur_masks = mask_pred[keep]
+ cur_mask_cls = mask_cls[keep]
+ cur_mask_cls = cur_mask_cls[:, :-1]
+
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
+
+ h, w = cur_masks.shape[-2:]
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
+ segments_info = []
+
+ current_segment_id = 0
+
+ if cur_masks.shape[0] == 0:
+ # We didn't detect any mask :(
+ return panoptic_seg, segments_info
+ else:
+ # take argmax
+ cur_mask_ids = cur_prob_masks.argmax(0)
+ stuff_memory_list = {}
+ for k in range(cur_classes.shape[0]):
+ pred_class = cur_classes[k].item()
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
+ mask_area = (cur_mask_ids == k).sum().item()
+ original_area = (cur_masks[k] >= 0.5).sum().item()
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
+
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
+ if mask_area / original_area < self.overlap_threshold:
+ continue
+
+ # merge stuff regions
+ if not isthing:
+ if int(pred_class) in stuff_memory_list.keys():
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
+ continue
+ else:
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
+
+ current_segment_id += 1
+ panoptic_seg[mask] = current_segment_id
+
+ segments_info.append(
+ {
+ "id": current_segment_id,
+ "isthing": bool(isthing),
+ "category_id": int(pred_class),
+ }
+ )
+
+ return panoptic_seg, segments_info
+
+ def instance_inference(self, mask_cls, mask_pred, box_pred):
+ # mask_pred is already processed to have the same shape as original input
+ image_size = mask_pred.shape[-2:]
+
+ # [Q, K]
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
+
+ labels_per_image = labels[topk_indices]
+ topk_indices = (topk_indices // self.sem_seg_head.num_classes)
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
+ mask_pred = mask_pred[topk_indices]
+ if box_pred is not None:
+ box_pred = box_pred[topk_indices]
+
+ # if this is panoptic segmentation, we only keep the "thing" classes
+ if self.panoptic_on:
+ keep = torch.zeros_like(scores_per_image).bool()
+ for i, lab in enumerate(labels_per_image):
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
+
+ scores_per_image = scores_per_image[keep]
+ labels_per_image = labels_per_image[keep]
+ mask_pred = mask_pred[keep]
+
+ if box_pred is not None:
+ box_pred = box_pred[keep]
+
+ result = Instances(image_size)
+ # mask (before sigmoid)
+ result.pred_masks = (mask_pred > 0).float()
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+ # Uncomment the following to get boxes from masks (this is slow)
+
+ if box_pred is not None:
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
+ else:
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+
+ # calculate average mask prob
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
+ result.scores = scores_per_image * mask_scores_per_image
+ result.pred_classes = labels_per_image
+
+ return result
+
+
+@register_model
+def get_seem_model(cfg, **kwargs):
+ return GeneralizedSEEM(cfg)
\ No newline at end of file
diff --git a/modeling/architectures/seem_model_v0.py b/modeling/architectures/seem_model_v0.py
new file mode 100644
index 0000000000000000000000000000000000000000..a10f47557e4a65dfcfeac4d7a4865b6ebee61c90
--- /dev/null
+++ b/modeling/architectures/seem_model_v0.py
@@ -0,0 +1,1160 @@
+# --------------------------------------------------------
+# SEEM -- Segment Everything Everywhere All at Once
+# Licensed under The Apache License 2.0 [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import random
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from kornia.contrib import distance_transform
+
+from detectron2.structures import Boxes, ImageList, Instances, BitMasks
+from detectron2.utils.memory import retry_if_cuda_oom
+from detectron2.data import MetadataCatalog
+
+from .build import register_model
+
+from ..utils import configurable, get_class_names, get_iou
+from ..vision.backbone import build_backbone, Backbone
+from ..body import build_xdecoder_head
+from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
+from ..language import build_language_encoder
+from ..language.loss import vl_similarity
+from utilities.prompt_engineering import prompt_engineering
+from utilities.constants import COCO_PANOPTIC_CLASSES
+
+
+class GeneralizedSEEM(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ *,
+ backbone: Backbone,
+ sem_seg_head: nn.Module,
+ criterion: nn.Module,
+ losses: dict,
+ num_queries: int,
+ object_mask_threshold: float,
+ overlap_threshold: float,
+ metadata,
+ task_switch: dict,
+ phrase_prob: float,
+ size_divisibility: int,
+ sem_seg_postprocess_before_inference: bool,
+ pixel_mean: Tuple[float],
+ pixel_std: Tuple[float],
+ # inference
+ semantic_on: bool,
+ panoptic_on: bool,
+ instance_on: bool,
+ test_topk_per_image: int,
+ train_dataset_name: str,
+ interactive_mode: str,
+ interactive_iter: str,
+ dilation_kernel: torch.Tensor,
+ train_max_iter: int,
+ ):
+ """
+ Args:
+ backbone: a backbone module, must follow detectron2's backbone interface
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
+ criterion: a module that defines the loss
+ num_queries: int, number of queries
+ object_mask_threshold: float, threshold to filter query based on classification score
+ for panoptic segmentation inference
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
+ segmentation inference
+ size_divisibility: Some backbones require the input height and width to be divisible by a
+ specific integer. We can use this to override such requirement.
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
+ to original input size before semantic segmentation inference or after.
+ For high-resolution dataset like Mapillary, resizing predictions before
+ inference will cause OOM error.
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
+ the per-channel mean and std to be used to normalize the input image
+ semantic_on: bool, whether to output semantic segmentation prediction
+ instance_on: bool, whether to output instance segmentation prediction
+ panoptic_on: bool, whether to output panoptic segmentation prediction
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
+ """
+ super().__init__()
+ self.backbone = backbone
+ self.sem_seg_head = sem_seg_head
+ self.criterion = criterion
+ self.losses = losses
+ self.num_queries = num_queries
+ self.overlap_threshold = overlap_threshold
+ self.object_mask_threshold = object_mask_threshold
+ self.metadata = metadata
+ if size_divisibility < 0:
+ # use backbone size_divisibility if not set
+ size_divisibility = self.backbone.size_divisibility
+ self.size_divisibility = size_divisibility
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ # additional args
+ self.semantic_on = semantic_on
+ self.instance_on = instance_on
+ self.panoptic_on = panoptic_on
+
+ # caption argument
+ self.task_switch = task_switch
+ self.phrase_prob = phrase_prob
+ self.train_max_iter = train_max_iter
+
+ self.test_topk_per_image = test_topk_per_image
+ self.train_class_names = get_class_names(train_dataset_name)
+ self.interactive_mode = interactive_mode
+ self.interactive_iter = interactive_iter
+
+ if not self.semantic_on:
+ assert self.sem_seg_postprocess_before_inference
+
+ self.register_buffer("dilation_kernel", dilation_kernel)
+
+ @classmethod
+ def from_config(cls, cfg):
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ # Loss parameters:
+ deep_supervision = dec_cfg['DEEP_SUPERVISION']
+ no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
+
+ # loss weights
+ loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
+ 'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
+ 'spatial': {'ce': dec_cfg['SCLASS_WEIGHT'], 'dice': dec_cfg['SDICE_WEIGHT'], 'bce': dec_cfg['SMASK_WEIGHT']},
+ 'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']},
+ 'openimage': {'ce': dec_cfg['OCLASS_WEIGHT'], 'dice': dec_cfg['ODICE_WEIGHT'], 'bce': dec_cfg['OMASK_WEIGHT']}}
+
+ openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
+ 'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
+
+ task_switch = {'bbox': dec_cfg.get('DETECTION', False),
+ 'mask': dec_cfg['MASK'].get('ENABLED', True),
+ 'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
+ 'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
+ 'openimage': openimage_switch}
+
+ top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
+ 'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),
+ 'openimage': dec_cfg.get('TOP_OPENIMAGE_LAYERS', 10),
+ 'spatial': dec_cfg.get('TOP_SPATIAL_LAYERS', 10)}
+
+ spatial_cost = {"class_weight": dec_cfg['COST_SPATIAL']['CLASS_WEIGHT'],
+ "mask_weight": dec_cfg['COST_SPATIAL']['MASK_WEIGHT'],
+ "dice_weight": dec_cfg['COST_SPATIAL']['DICE_WEIGHT']}
+
+ extra = {'task_switch': task_switch}
+ backbone = build_backbone(cfg)
+ lang_encoder = build_language_encoder(cfg)
+ sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
+
+ # building criterion
+ matcher = HungarianMatcher(
+ cost_class=loss_weights['mask']['ce'],
+ cost_mask=loss_weights['mask']['bce'],
+ cost_dice=loss_weights['mask']['dice'],
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
+ spatial_cost=spatial_cost,
+ )
+
+ # init weight dict and criterion loss functions.
+ losses = {'seg': [], 'openimage': []}
+ if task_switch['mask']:
+ losses['seg'] += ["labels", "masks"]
+ if task_switch['spatial']:
+ losses['seg'] += ["spatials"]
+ if task_switch['grounding']:
+ losses['seg'] += ["groundings"]
+ if task_switch['openimage']:
+ losses['openimage'] += ["labels_openimage", "masks"]
+ if task_switch['openimage']['grounding']:
+ losses['openimage'] += ["groundings"]
+
+ weight_dict = {}
+ for key, turn_on in task_switch.items():
+ if turn_on:
+ if isinstance(loss_weights[key], dict):
+ # HACK it should support bbox in the future
+ for key_, weight in loss_weights[key].items():
+ weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
+ else:
+ weight_dict["loss_{}_0".format(key)] = loss_weights[key]
+
+ # generate full weight dict and remove not computed layers.
+ if deep_supervision:
+ dec_layers = dec_cfg['DEC_LAYERS']
+ aux_weight_dict = {}
+ for i in range(dec_layers - 1):
+ for k, v in weight_dict.items():
+ if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
+ continue
+ aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
+ weight_dict.update(aux_weight_dict)
+
+ grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
+ # generate critenrion for loss function.
+ criterion = SetCriterion(
+ sem_seg_head.num_classes,
+ matcher=matcher,
+ weight_dict=weight_dict,
+ top_x_layers=top_x_layers,
+ eos_coef=no_object_weight,
+ losses=[],
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
+ oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
+ importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
+ grounding_weight=grd_weight,
+ )
+
+ # extra logistic
+ train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
+ train_max_iter = dec_cfg['SPATIAL'].get('MAX_ITER', 3)
+ phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
+ interactive_mode = cfg['STROKE_SAMPLER']['EVAL']['MODE']
+ interactive_iter = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
+
+ dilation = 3
+ dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
+
+ return {
+ "backbone": backbone,
+ "sem_seg_head": sem_seg_head,
+ "criterion": criterion,
+ "losses": losses,
+ "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
+ "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
+ "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
+ "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
+ "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
+ "sem_seg_postprocess_before_inference": (
+ dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
+ or dec_cfg['TEST']['PANOPTIC_ON']
+ or dec_cfg['TEST']['INSTANCE_ON']
+ ),
+ "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
+ "pixel_std": cfg['INPUT']['PIXEL_STD'],
+ "task_switch": task_switch,
+ "phrase_prob": phrase_prob,
+ # inference
+ "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
+ "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
+ "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
+ "test_topk_per_image": cfg['TEST']['DETECTIONS_PER_IMAGE'],
+ "train_dataset_name": train_dataset_name,
+ "interactive_mode": interactive_mode,
+ "interactive_iter": interactive_iter,
+ "dilation_kernel": dilation_kernel,
+ "train_max_iter": train_max_iter,
+ }
+
+ @property
+ def device(self):
+ return self.pixel_mean.device
+
+ def forward(self, batched_inputs, mode='default'):
+ """
+ Args:
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+ Each item in the list contains the inputs for one image.
+ For now, each item in the list is a dict that contains:
+ * "image": Tensor, image in (C, H, W) format.
+ * "instances": per-region ground truth
+ * Other information that's included in the original dicts, such as:
+ "height", "width" (int): the output resolution of the model (may be different
+ from input resolution), used in inference.
+ Returns:
+ list[dict]:
+ each dict has the results for one image. The dict contains the following keys:
+
+ * "sem_seg":
+ A Tensor that represents the
+ per-pixel segmentation prediced by the head.
+ The prediction has shape KxHxW that represents the logits of
+ each class for each pixel.
+ * "panoptic_seg":
+ A tuple that represent panoptic output
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
+ Each dict contains keys "id", "category_id", "isthing".
+ """
+ if self.training:
+ losses = {}
+ if self.task_switch['mask'] or self.task_switch['grounding'] or self.task_switch['spatial']:
+ losses_seg = self.forward_seg(batched_inputs)
+ losses.update(losses_seg)
+ if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
+ losses_openimage = self.forward_openimage(batched_inputs['openimage'])
+ losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
+ losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
+ losses.update(losses_openimage)
+ for k in list(losses.keys()):
+ if k in self.criterion.weight_dict:
+ losses[k] *= self.criterion.weight_dict[k]
+ else: # remove this loss if not specified in `weight_dict`
+ losses.pop(k)
+ return losses
+ else:
+ if mode == 'interactive':
+ return self.evaluate_interactive(batched_inputs)
+ elif mode == 'interactive_grounding':
+ return self.evaluate_interactive_grounding(batched_inputs)
+ elif mode == 'grounding_spatial':
+ return self.evaluate_grounding_sptial(batched_inputs, mode)
+ elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
+ return self.evaluate_grounding(batched_inputs, mode)
+ else:
+ return self.evaluate(batched_inputs)
+
+
+ def forward_seg(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
+
+ extra = {}
+ # mask classification target
+ if "instances" in batched_inputs[0]:
+ # input bounding box is checked to be correct.
+ targets = self.prepare_targets(batched_inputs, images)
+
+ if self.task_switch['grounding']:
+ grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
+ grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
+ non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
+ grounding_tokens[non_zero_query_mask] = 0
+
+ extra['grounding_tokens'] = grounding_tokens
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ if self.task_switch['spatial']:
+ pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
+ neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
+ fp_masks = torch.stack([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs])
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})
+
+ features = self.backbone(images.tensor)
+ mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ # forward spatial only without gradient
+ if self.task_switch['spatial']:
+ with torch.no_grad():
+ # generate random integeter between [0,3]
+ rand_iter_num = random.randint(0, self.train_max_iter)
+ for i in range(rand_iter_num):
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
+ extra.update(outputs)
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
+
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
+
+ extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
+ 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
+ 'false_positive_mask': extra['false_positive_mask']}
+ # bipartite matching-based loss
+ self.criterion.losses = self.losses['seg'] # seg criterion losses
+ losses = self.criterion(outputs, targets, extra)
+
+ del outputs
+ return losses
+
+ def evaluate(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
+
+ mask_cls_results = outputs["pred_logits"]
+ mask_pred_results = outputs["pred_masks"]
+ box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
+
+ # upsample masks
+ mask_pred_results = F.interpolate(
+ mask_pred_results,
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ input_size = mask_pred_results.shape[-2:]
+ del outputs
+
+ processed_results = []
+ for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
+ mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ if self.sem_seg_postprocess_before_inference:
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
+
+ # semantic segmentation inference
+ if self.semantic_on:
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
+ if not self.sem_seg_postprocess_before_inference:
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
+ processed_results[-1]["sem_seg"] = r
+
+ # panoptic segmentation inference
+ if self.panoptic_on:
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
+ processed_results[-1]["panoptic_seg"] = panoptic_r
+
+ # instance segmentation inference
+ if self.instance_on:
+ if self.task_switch['bbox']:
+ box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
+ processed_results[-1]["instances"] = instance_r
+
+ return processed_results
+
+ def evaluate_interactive(self, batched_inputs):
+ assert self.task_switch['spatial']
+ assert 'spatial_query' in batched_inputs[0]
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ extra = {}
+
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ all_batch_shape_iou = []
+ pred_smask_pointer = None
+ prev_smask_pointer = None
+ pred_smask_all = None
+
+ # visualization code
+ # v_pred_mask = []
+ # v_pos_mask = []
+ # v_neg_mask = []
+ # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
+ query_index = self.sem_seg_head.predictor.query_index
+ if self.interactive_mode in ['best', 'best_random']:
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+ elif self.interactive_mode == 'random':
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
+ extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
+ else:
+ assert False, "invalid interactive mode"
+
+ for i in range(self.interactive_iter):
+ # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
+ # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
+ extra.update(outputs)
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
+ # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
+
+ s = image_sizes[0]
+ b = batched_inputs[0]
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
+ gt_smask = b['gt_masks_orisize']
+ ious = get_iou(gt_smask, pred_smask_all)
+ all_batch_shape_iou += [ious]
+ if (ious > 0.9).sum() == len(ious):
+ all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
+ break
+ if self.interactive_mode in ['best', 'best_random']:
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
+ elif self.interactive_mode == 'random':
+ extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
+ else:
+ assert False, "invalid interactive mode"
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
+
+ return processed_results
+
+ def evaluate_interactive_single(self, batched_inputs, extra={}):
+ assert self.task_switch['spatial']
+ assert 'spatial_query' in batched_inputs[0]
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
+
+ s = image_sizes[0]
+ b = batched_inputs[0]
+ pred_smask_ori = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
+ pred_smask_batch = pred_smask[:,:,:s[0],:s[1]].sigmoid() > 0.5
+ ious = []
+ if 'gt_masks_orisize' in b:
+ gt_smask = b['gt_masks_orisize'].to(pred_smask_ori.device)
+ ious = get_iou(gt_smask, pred_smask_ori)
+ processed_results = [{"mask_iou": ious, 'pred_mask_ori': pred_smask_ori, 'pred_mask_batch': pred_smask_batch}]
+ return processed_results
+
+ def evaluate_interactive_grounding(self, batched_inputs):
+ assert self.task_switch['spatial']
+ assert 'spatial_query' in batched_inputs[0]
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ extra = {}
+
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ all_batch_shape_iou = []
+ pred_smask_pointer = None
+ prev_smask_pointer = None
+ pred_smask_all = None
+
+ # visualization code
+ # v_pred_mask = []
+ # v_pos_mask = []
+ # v_neg_mask = []
+ # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
+ query_index = self.sem_seg_head.predictor.query_index
+ if self.interactive_mode in ['best', 'best_random']:
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+ elif self.interactive_mode == 'random':
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
+ extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
+ else:
+ assert False, "invalid interactive mode"
+
+ grd_texts = batched_inputs[0]['classes']
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+ query_emb = nn.utils.rnn.pad_sequence([_token_emb[_tokens.bool()] for _token_emb, _tokens in zip(token_emb, tokens['attention_mask'])], padding_value=-1)
+ non_zero_query_mask = (query_emb.sum(dim=-1) < 0)
+
+ extra['grounding_tokens'] = query_emb
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ for i in range(self.interactive_iter):
+ # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
+ # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
+ extra.update(outputs)
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
+ # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
+
+ s = image_sizes[0]
+ b = batched_inputs[0]
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
+ gt_smask = b['gt_masks_orisize']
+ ious = get_iou(gt_smask, pred_smask_all)
+ all_batch_shape_iou += [ious]
+ if (ious > 0.9).sum() == len(ious):
+ all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
+ break
+ if self.interactive_mode in ['best', 'best_random']:
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
+ elif self.interactive_mode == 'random':
+ extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
+ else:
+ assert False, "invalid interactive mode"
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
+
+ # visualization
+ # VL.step()
+ # import cv2
+ # v_masks = []
+ # v_pos_masks = []
+ # v_neg_masks = []
+ # txt = []
+
+ # img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
+ # mask_img = VL.overlay_single_mask_to_image(img[:,:,::-1], v_gt_mask.cpu().float().numpy())
+ # acc_pos_mask = np.zeros(v_pos_mask[0].shape)
+ # acc_neg_mask = np.zeros(v_neg_mask[0].shape)
+ # for x,y,z,iou in zip(v_pos_mask, v_neg_mask, v_pred_mask, all_batch_shape_iou):
+ # # dilate x,y
+ # x = cv2.dilate(x, np.ones((5,5), np.uint8), iterations=3)
+ # y = cv2.dilate(y, np.ones((5,5), np.uint8), iterations=3)
+ # acc_pos_mask += x
+ # acc_neg_mask += y
+
+ # v_masks += [z]
+ # v_pos_masks += [acc_pos_mask.clip(0,1)]
+ # v_neg_masks += [acc_neg_mask.clip(0,1)]
+ # txt += ["pred_{}".format(str(iou[0].item())[0:5])]
+
+ # VL.add_image(img[:,:,::-1])
+ # VL.insert(mask_img, "gt_mask")
+ # VL.overlay_obj_mask_to_image_withposneg(img[:,:,::-1], v_masks, v_pos_masks, v_neg_masks, txt, max_len=20)
+ return processed_results
+
+ def evaluate_referring_image(self, batched_inputs, extra={}):
+ assert self.task_switch['spatial']
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+ assert self.interactive_mode == 'best'
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ if 'spatial_query' in batched_inputs[0]:
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ query_index = self.sem_seg_head.predictor.query_index
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
+ return outputs, images.tensor.shape
+
+ def evaluate_grounding(self, batched_inputs, mode):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+
+ extra = {}
+ # mask_pred_results = []
+ # for idx, batch_per_image in enumerate(batched_inputs):
+ # grd_texts = batch_per_image['groundings']['texts']
+ # grd_masks = []
+ # for anno_text in grd_texts:
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
+ # token_emb = gtext['token_emb']
+ # tokens = gtext['tokens']
+
+ # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
+ # extra['grounding_tokens'] = grd_emb[:,None]
+
+ # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+ # features = self.backbone(images.tensor)
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
+ # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
+ # t_emb = grd_emb[-1:]
+
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ # matched_id = out_prob.max(0)[1]
+ # grd_masks += [pred_gmasks[matched_id,:,:]]
+ # mask_pred_results += [torch.cat(grd_masks)]
+
+ # comment for multi object inference.
+ mask_pred_results = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_texts = [x[0] for x in grd_texts]
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+ query_emb = token_emb[tokens['attention_mask'].bool()]
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
+
+ extra['grounding_tokens'] = query_emb[:,None]
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ pred_gmasks = outputs['pred_gmasks'][idx]
+ v_emb = outputs['pred_gtexts'][idx]
+ t_emb = gtext['class_emb']
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ matched_id = out_prob.max(0)[1]
+ mask_pred_results += [pred_gmasks[matched_id,:,:]]
+
+ for i in range(len(mask_pred_results)):
+ # upsample masks
+ mask_pred_results[i] = F.interpolate(
+ mask_pred_results[i][None,],
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bilinear",
+ align_corners=False,
+ )[0]
+
+ processed_results = []
+ for mask_pred_result, input_per_image, image_size in zip(
+ mask_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ processed_results[-1]['grounding_mask'] = mask_pred_result
+
+ # compute bbox
+ # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
+ # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+ # processed_results[-1]['grounding_box'] = bbox
+
+ return processed_results
+
+ def evaluate_grounding_sptial(self, batched_inputs, mode):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+
+ extra = {}
+ dilation = 3
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
+ pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+
+ mask_pred_results = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_masks = []
+ for idx2, anno_text in enumerate(grd_texts):
+ extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+
+ grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
+ non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
+ extra['grounding_tokens'] = grd_emb[:,None]
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ pred_gmasks = outputs['pred_gmasks'][idx]
+ v_emb = outputs['pred_gtexts'][idx]
+ t_emb = gtext['class_emb']
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ matched_id = out_prob.max(0)[1]
+ grd_masks += [pred_gmasks[matched_id,:,:]]
+ # grd_masks += [outputs['prev_mask'][0]]
+
+ mask_pred_results += [torch.cat(grd_masks)]
+
+ # comment for multi object inference.
+ # mask_pred_results = []
+ # for idx, batch_per_image in enumerate(batched_inputs):
+ # grd_texts = batch_per_image['groundings']['texts']
+ # grd_texts = [x[0] for x in grd_texts]
+
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ # token_emb = gtext['token_emb']
+ # tokens = gtext['tokens']
+ # query_emb = token_emb[tokens['attention_mask'].bool()]
+ # non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
+
+ # extra['grounding_tokens'] = query_emb[:,None]
+ # extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ # features = self.backbone(images.tensor)
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ # pred_gmasks = outputs['pred_gmasks'][idx]
+ # v_emb = outputs['pred_gtexts'][idx]
+ # t_emb = gtext['class_emb']
+
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ # matched_id = out_prob.max(0)[1]
+ # mask_pred_results += [pred_gmasks[matched_id,:,:]]
+
+ for i in range(len(mask_pred_results)):
+ # upsample masks
+ mask_pred_results[i] = F.interpolate(
+ mask_pred_results[i][None,],
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bilinear",
+ align_corners=False,
+ )[0]
+
+ processed_results = []
+ for mask_pred_result, input_per_image, image_size in zip(
+ mask_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ processed_results[-1]['grounding_mask'] = mask_pred_result
+
+ return processed_results
+
+ def prepare_targets(self, batched_inputs, images):
+ h_pad, w_pad = images.tensor.shape[-2:]
+ new_targets = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ targets_per_image = batch_per_image['instances'].to(self.device)
+ # pad gt
+ gt_masks = targets_per_image.gt_masks.tensor
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+
+ gt_boxes = targets_per_image.gt_boxes.tensor
+ ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
+ gt_boxes = gt_boxes / ratio
+ xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
+ gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
+
+ target_dict = {
+ "labels": targets_per_image.gt_classes,
+ "is_things": targets_per_image.is_things,
+ "masks": padded_masks,
+ "boxes": gt_boxes,
+ }
+
+ if self.task_switch['spatial']:
+ # prepare targets for spatial query
+ target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
+
+ if self.task_switch['grounding']:
+ grd_masks = batch_per_image['groundings']['masks']
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_hash = batch_per_image['groundings']['hash']
+ grd_task = batch_per_image['groundings']['mode']
+
+ if len(grd_masks) == 0:
+ padded_masks = None
+ else:
+ padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
+ padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+
+ unique_hash_id = np.unique(grd_hash, return_index=True)[1]
+ selected_mask = np.zeros(len(grd_hash)).astype(np.bool)
+ selected_mask[unique_hash_id] = True
+
+ selected_token_emb = token_emb[selected_mask]
+ selected_attn_mask = tokens['attention_mask'][selected_mask]
+ query_emb = selected_token_emb[selected_attn_mask.bool()]
+
+ class_idx = tokens['attention_mask'].sum(dim=-1) - 1
+ class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
+ class_emb = token_emb[class_idx]
+
+ target_dict['grounding_masks'] = padded_masks
+ target_dict['grounding_query_embs'] = query_emb
+ target_dict['grounding_class_embs'] = class_emb
+ target_dict['grounding_hash'] = grd_hash
+ target_dict['grounding_task'] = grd_task
+
+ new_targets.append(target_dict)
+ return new_targets
+
+ def prepare_next_spaital_mask(self, outputs, batched_inputs, mode='best'):
+ gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
+ if self.training:
+ gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
+ else:
+ gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor.transpose(0,1)
+
+ pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
+ prev_masks = torch.stack(outputs['spatial_query_pos_mask']) | torch.stack(outputs['spatial_query_neg_mask'])
+
+ fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
+ fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
+
+ # compute iou between gt and pred
+ iou = (gt_masks & pred_masks).sum(list(range(1,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(1,len(fn.shape)))) + 1e-8)
+ fn_sum = fn.sum(dim=list(range(1,len(fn.shape))))
+ fp_sum = fp.sum(dim=list(range(1,len(fp.shape))))
+
+ is_postive = fn_sum > fp_sum
+ # is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
+ select_mask = torch.stack([fn[i] if is_postive[i] else fp[i] for i in range(len(fn))])
+
+ # conv implementation
+ n,_,h,w = select_mask.shape
+ mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
+ if mode == 'best':
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
+ elif mode == 'best_random':
+ max_xy_idx = torch.stack([torch.arange(n), torch.cat([(mask_dt[i] > 0).nonzero()[torch.randint(0, len((mask_dt[i] > 0).nonzero()), (1,))][0] for i in range(len(mask_dt))]).cpu()]).tolist()
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
+ next_mask = next_mask.view(n,-1)
+ next_mask[max_xy_idx] = True
+ next_mask = next_mask.reshape((n,1,h,w)).float()
+ dilation = 3
+ next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2) > 0
+
+ # determine whether next mask is zero
+ keep = (iou < 0.925)
+ next_mask = next_mask & keep.view(-1,1,1,1)
+
+ pos_mask = []
+ neg_mask = []
+ for idx, ip in enumerate(is_postive):
+ if ip:
+ pos_mask += [outputs['spatial_query_pos_mask'][idx] | next_mask[idx]]
+ neg_mask += [outputs['spatial_query_neg_mask'][idx]]
+ else:
+ pos_mask += [outputs['spatial_query_pos_mask'][idx]]
+ neg_mask += [outputs['spatial_query_neg_mask'][idx] | next_mask[idx]]
+
+ if 'false_positive_mask' in outputs:
+ fp = outputs['false_positive_mask'] | fp
+ return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
+
+ def semantic_inference(self, mask_cls, mask_pred):
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+ mask_pred = mask_pred.sigmoid()
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
+ return semseg
+
+ def panoptic_inference(self, mask_cls, mask_pred):
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
+ mask_pred = mask_pred.sigmoid()
+
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
+ cur_scores = scores[keep]
+ cur_classes = labels[keep]
+ cur_masks = mask_pred[keep]
+ cur_mask_cls = mask_cls[keep]
+ cur_mask_cls = cur_mask_cls[:, :-1]
+
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
+
+ h, w = cur_masks.shape[-2:]
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
+ segments_info = []
+
+ current_segment_id = 0
+
+ if cur_masks.shape[0] == 0:
+ # We didn't detect any mask :(
+ return panoptic_seg, segments_info
+ else:
+ # take argmax
+ cur_mask_ids = cur_prob_masks.argmax(0)
+ stuff_memory_list = {}
+ for k in range(cur_classes.shape[0]):
+ pred_class = cur_classes[k].item()
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
+ mask_area = (cur_mask_ids == k).sum().item()
+ original_area = (cur_masks[k] >= 0.5).sum().item()
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
+
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
+ if mask_area / original_area < self.overlap_threshold:
+ continue
+
+ # merge stuff regions
+ if not isthing:
+ if int(pred_class) in stuff_memory_list.keys():
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
+ continue
+ else:
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
+
+ current_segment_id += 1
+ panoptic_seg[mask] = current_segment_id
+
+ segments_info.append(
+ {
+ "id": current_segment_id,
+ "isthing": bool(isthing),
+ "category_id": int(pred_class),
+ }
+ )
+
+ return panoptic_seg, segments_info
+
+ def instance_inference(self, mask_cls, mask_pred, box_pred):
+ # mask_pred is already processed to have the same shape as original input
+ image_size = mask_pred.shape[-2:]
+
+ # [Q, K]
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
+
+ labels_per_image = labels[topk_indices]
+ topk_indices = (topk_indices // self.sem_seg_head.num_classes)
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
+ mask_pred = mask_pred[topk_indices]
+ if box_pred is not None:
+ box_pred = box_pred[topk_indices]
+
+ # if this is panoptic segmentation, we only keep the "thing" classes
+ if self.panoptic_on:
+ keep = torch.zeros_like(scores_per_image).bool()
+ for i, lab in enumerate(labels_per_image):
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
+
+ scores_per_image = scores_per_image[keep]
+ labels_per_image = labels_per_image[keep]
+ mask_pred = mask_pred[keep]
+
+ if box_pred is not None:
+ box_pred = box_pred[keep]
+
+ result = Instances(image_size)
+ # mask (before sigmoid)
+ result.pred_masks = (mask_pred > 0).float()
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+ # Uncomment the following to get boxes from masks (this is slow)
+
+ if box_pred is not None:
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
+ else:
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+
+ # calculate average mask prob
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
+ result.scores = scores_per_image * mask_scores_per_image
+ result.pred_classes = labels_per_image
+
+ return result
+
+ def prepare_targets4query(self, targets, images, topk=5):
+ h_pad, w_pad = images.tensor.shape[-2:]
+ new_targets = []
+ new_queries = []
+ for targets_per_image in targets:
+ # we randomly sample maximally topk concepts
+ unique_target_classes = [k for k in set(targets_per_image.gt_classes.tolist())]
+ selected_target_classes = random.sample(unique_target_classes, min(topk, len(unique_target_classes)))
+ new_targets_per_image = []
+ new_queries_per_image = []
+ for clss in selected_target_classes:
+ indices = (targets_per_image.gt_classes == clss).nonzero().view(-1)
+ # pad gt
+ gt_masks = targets_per_image.gt_masks[indices]
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+
+ # convert class into concept name and then token seq
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings([COCO_PANOPTIC_CLASSES[clss]], name='grounding')
+ query = getattr(self.sem_seg_head.predictor.lang_encoder, 'grounding_text_embeddings')
+
+ new_targets.append(
+ {
+ "labels": targets_per_image.gt_classes[indices],
+ "masks": padded_masks,
+ }
+ )
+ new_queries_per_image.append(query)
+ new_queries.append(new_queries_per_image)
+
+ return new_targets, new_queries
+
+
+
+@register_model
+def get_seem_model(cfg, **kwargs):
+ return GeneralizedSEEM(cfg)
\ No newline at end of file
diff --git a/modeling/architectures/seem_model_v1.py b/modeling/architectures/seem_model_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d762c3a3bcbef94d6d3ab3462eaa11664e10153
--- /dev/null
+++ b/modeling/architectures/seem_model_v1.py
@@ -0,0 +1,1179 @@
+# --------------------------------------------------------
+# SEEM -- Segment Everything Everywhere All at Once
+# Licensed under The Apache License 2.0 [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import random
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from kornia.contrib import distance_transform
+
+from detectron2.structures import Boxes, ImageList, Instances, BitMasks
+from detectron2.utils.memory import retry_if_cuda_oom
+from detectron2.data import MetadataCatalog
+
+from .build import register_model
+
+from ..utils import configurable, get_class_names, get_iou, Spatial_ImageList
+from ..vision.backbone import build_backbone, Backbone
+from ..body import build_xdecoder_head
+from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
+from ..language import build_language_encoder
+from ..language.loss import vl_similarity
+from utilities.prompt_engineering import prompt_engineering
+from utilities.constants import COCO_PANOPTIC_CLASSES, BIOMED_CLASSES
+
+
+class GeneralizedSEEM(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ *,
+ backbone: Backbone,
+ sem_seg_head: nn.Module,
+ criterion: nn.Module,
+ losses: dict,
+ num_queries: int,
+ object_mask_threshold: float,
+ overlap_threshold: float,
+ metadata,
+ task_switch: dict,
+ phrase_prob: float,
+ size_divisibility: int,
+ sem_seg_postprocess_before_inference: bool,
+ pixel_mean: Tuple[float],
+ pixel_std: Tuple[float],
+ # inference
+ semantic_on: bool,
+ panoptic_on: bool,
+ instance_on: bool,
+ test_topk_per_image: int,
+ train_dataset_name: str,
+ interactive_mode: str,
+ interactive_iter: str,
+ dilation_kernel: torch.Tensor,
+ train_max_iter: int,
+ binary_classes: bool,
+ standard_text_for_eval: bool,
+ ):
+ """
+ Args:
+ backbone: a backbone module, must follow detectron2's backbone interface
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
+ criterion: a module that defines the loss
+ num_queries: int, number of queries
+ object_mask_threshold: float, threshold to filter query based on classification score
+ for panoptic segmentation inference
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
+ segmentation inference
+ size_divisibility: Some backbones require the input height and width to be divisible by a
+ specific integer. We can use this to override such requirement.
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
+ to original input size before semantic segmentation inference or after.
+ For high-resolution dataset like Mapillary, resizing predictions before
+ inference will cause OOM error.
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
+ the per-channel mean and std to be used to normalize the input image
+ semantic_on: bool, whether to output semantic segmentation prediction
+ instance_on: bool, whether to output instance segmentation prediction
+ panoptic_on: bool, whether to output panoptic segmentation prediction
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
+ """
+ super().__init__()
+ self.backbone = backbone
+ self.sem_seg_head = sem_seg_head
+ self.criterion = criterion
+ self.losses = losses
+ self.num_queries = num_queries
+ self.overlap_threshold = overlap_threshold
+ self.object_mask_threshold = object_mask_threshold
+ self.metadata = metadata
+ if size_divisibility < 0:
+ # use backbone size_divisibility if not set
+ size_divisibility = self.backbone.size_divisibility
+ self.size_divisibility = size_divisibility
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ # additional args
+ self.semantic_on = semantic_on
+ self.instance_on = instance_on
+ self.panoptic_on = panoptic_on
+
+ # caption argument
+ self.task_switch = task_switch
+ self.phrase_prob = phrase_prob
+ self.train_max_iter = train_max_iter
+
+ self.test_topk_per_image = test_topk_per_image
+ self.train_class_names = get_class_names(train_dataset_name)
+ if binary_classes:
+ self.train_class_names = ['target', 'background']
+ self.interactive_mode = interactive_mode
+ self.interactive_iter = interactive_iter
+
+ if not self.semantic_on:
+ assert self.sem_seg_postprocess_before_inference
+
+ self.register_buffer("dilation_kernel", dilation_kernel)
+
+ self.standard_text_for_eval = standard_text_for_eval
+
+ @classmethod
+ def from_config(cls, cfg):
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ # Loss parameters:
+ deep_supervision = dec_cfg['DEEP_SUPERVISION']
+ no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
+
+ # loss weights
+ loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
+ 'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
+ 'spatial': {'ce': dec_cfg['SCLASS_WEIGHT'], 'dice': dec_cfg['SDICE_WEIGHT'], 'bce': dec_cfg['SMASK_WEIGHT']},
+ 'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']},
+ 'openimage': {'ce': dec_cfg['OCLASS_WEIGHT'], 'dice': dec_cfg['ODICE_WEIGHT'], 'bce': dec_cfg['OMASK_WEIGHT']}}
+
+ openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
+ 'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
+
+ task_switch = {'bbox': dec_cfg.get('DETECTION', False),
+ 'mask': dec_cfg['MASK'].get('ENABLED', True),
+ 'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
+ 'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
+ 'openimage': openimage_switch}
+
+ top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
+ 'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),
+ 'openimage': dec_cfg.get('TOP_OPENIMAGE_LAYERS', 10),
+ 'spatial': dec_cfg.get('TOP_SPATIAL_LAYERS', 10)}
+
+ spatial_cost = {"class_weight": dec_cfg['COST_SPATIAL']['CLASS_WEIGHT'],
+ "mask_weight": dec_cfg['COST_SPATIAL']['MASK_WEIGHT'],
+ "dice_weight": dec_cfg['COST_SPATIAL']['DICE_WEIGHT']}
+
+ extra = {'task_switch': task_switch}
+ backbone = build_backbone(cfg)
+ lang_encoder = build_language_encoder(cfg)
+ sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
+
+ # building criterion
+ matcher = HungarianMatcher(
+ cost_class=loss_weights['mask']['ce'],
+ cost_mask=loss_weights['mask']['bce'],
+ cost_dice=loss_weights['mask']['dice'],
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
+ spatial_cost=spatial_cost,
+ )
+
+ # init weight dict and criterion loss functions.
+ losses = {'seg': [], 'openimage': []}
+ if task_switch['mask']:
+ losses['seg'] += ["labels", "masks"]
+ if task_switch['spatial']:
+ losses['seg'] += ["spatials"]
+ if task_switch['grounding']:
+ losses['seg'] += ["groundings"]
+ if task_switch['openimage']:
+ losses['openimage'] += ["labels_openimage", "masks"]
+ if task_switch['openimage']['grounding']:
+ losses['openimage'] += ["groundings"]
+
+ weight_dict = {}
+ for key, turn_on in task_switch.items():
+ if turn_on:
+ if isinstance(loss_weights[key], dict):
+ # HACK it should support bbox in the future
+ for key_, weight in loss_weights[key].items():
+ weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
+ else:
+ weight_dict["loss_{}_0".format(key)] = loss_weights[key]
+
+ # generate full weight dict and remove not computed layers.
+ if deep_supervision:
+ dec_layers = dec_cfg['DEC_LAYERS']
+ aux_weight_dict = {}
+ for i in range(dec_layers - 1):
+ for k, v in weight_dict.items():
+ if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
+ continue
+ aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
+ weight_dict.update(aux_weight_dict)
+
+ grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
+ # generate critenrion for loss function.
+ criterion = SetCriterion(
+ sem_seg_head.num_classes,
+ matcher=matcher,
+ weight_dict=weight_dict,
+ top_x_layers=top_x_layers,
+ eos_coef=no_object_weight,
+ losses=[],
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
+ oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
+ importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
+ grounding_weight=grd_weight,
+ )
+
+ # extra logistic
+ train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
+ train_max_iter = dec_cfg['SPATIAL'].get('MAX_ITER', 3)
+ phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
+ interactive_mode = cfg['STROKE_SAMPLER']['EVAL']['MODE']
+ interactive_iter = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
+
+ dilation = 3
+ dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
+
+ return {
+ "backbone": backbone,
+ "sem_seg_head": sem_seg_head,
+ "criterion": criterion,
+ "losses": losses,
+ "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
+ "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
+ "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
+ "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
+ "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
+ "sem_seg_postprocess_before_inference": (
+ dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
+ or dec_cfg['TEST']['PANOPTIC_ON']
+ or dec_cfg['TEST']['INSTANCE_ON']
+ ),
+ "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
+ "pixel_std": cfg['INPUT']['PIXEL_STD'],
+ "task_switch": task_switch,
+ "phrase_prob": phrase_prob,
+ # inference
+ "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
+ "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
+ "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
+ "test_topk_per_image": cfg['TEST']['DETECTIONS_PER_IMAGE'],
+ "train_dataset_name": train_dataset_name,
+ "interactive_mode": interactive_mode,
+ "interactive_iter": interactive_iter,
+ "dilation_kernel": dilation_kernel,
+ "train_max_iter": train_max_iter,
+ "binary_classes": enc_cfg['BINARY_CLASSES'],
+ "standard_text_for_eval": cfg['STANDARD_TEXT_FOR_EVAL'],
+ }
+
+ @property
+ def device(self):
+ return self.pixel_mean.device
+
+ def forward(self, batched_inputs, mode='default'):
+ """
+ Args:
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+ Each item in the list contains the inputs for one image.
+ For now, each item in the list is a dict that contains:
+ * "image": Tensor, image in (C, H, W) format.
+ * "instances": per-region ground truth
+ * Other information that's included in the original dicts, such as:
+ "height", "width" (int): the output resolution of the model (may be different
+ from input resolution), used in inference.
+ Returns:
+ list[dict]:
+ each dict has the results for one image. The dict contains the following keys:
+
+ * "sem_seg":
+ A Tensor that represents the
+ per-pixel segmentation prediced by the head.
+ The prediction has shape KxHxW that represents the logits of
+ each class for each pixel.
+ * "panoptic_seg":
+ A tuple that represent panoptic output
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
+ Each dict contains keys "id", "category_id", "isthing".
+ """
+ if self.training:
+ losses = {}
+ if self.task_switch['mask'] or self.task_switch['grounding'] or self.task_switch['spatial']:
+ losses_seg = self.forward_seg(batched_inputs)
+ losses.update(losses_seg)
+ if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
+ losses_openimage = self.forward_openimage(batched_inputs['openimage'])
+ losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
+ losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
+ losses.update(losses_openimage)
+ for k in list(losses.keys()):
+ if k in self.criterion.weight_dict:
+ losses[k] *= self.criterion.weight_dict[k]
+ else: # remove this loss if not specified in `weight_dict`
+ losses.pop(k)
+ return losses
+ else:
+ if mode == 'interactive':
+ return self.evaluate_interactive(batched_inputs)
+ elif mode == 'interactive_grounding':
+ return self.evaluate_interactive_grounding(batched_inputs)
+ elif mode == 'grounding_spatial':
+ return self.evaluate_grounding_sptial(batched_inputs, mode)
+ elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
+ return self.evaluate_grounding(batched_inputs, mode)
+ else:
+ return self.evaluate(batched_inputs)
+
+
+ def forward_seg(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
+
+ extra = {}
+ # mask classification target
+ if "instances" in batched_inputs[0]:
+ # input bounding box is checked to be correct.
+ targets = self.prepare_targets(batched_inputs, images)
+
+ if self.task_switch['grounding']:
+ grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
+ grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
+ non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
+ grounding_tokens[non_zero_query_mask] = 0
+
+ extra['grounding_tokens'] = grounding_tokens
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ if self.task_switch['spatial']:
+ pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
+ neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
+ fp_masks = nn.utils.rnn.pad_sequence([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs], padding_value=False, batch_first=True)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})
+
+ features = self.backbone(images.tensor)
+ mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ # forward spatial only without gradient
+ if self.task_switch['spatial']:
+ with torch.no_grad():
+ # generate random integeter between [0,3]
+ rand_iter_num = random.randint(0, self.train_max_iter)
+ for i in range(rand_iter_num):
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
+ extra.update(outputs)
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
+
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
+
+ extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
+ 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
+ 'false_positive_mask': extra['false_positive_mask']}
+ # bipartite matching-based loss
+ self.criterion.losses = self.losses['seg'] # seg criterion losses
+
+ if self.task_switch['mask']:
+ losses = self.criterion(outputs, targets, extra)
+ else:
+ losses = self.criterion.forward_vlp(outputs, targets, extra)
+
+ del outputs
+ return losses
+
+ def evaluate(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
+
+ mask_cls_results = outputs["pred_logits"]
+ mask_pred_results = outputs["pred_masks"]
+ box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
+
+ # upsample masks
+ mask_pred_results = F.interpolate(
+ mask_pred_results,
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ input_size = mask_pred_results.shape[-2:]
+ del outputs
+
+ processed_results = []
+ for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
+ mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ if self.sem_seg_postprocess_before_inference:
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
+
+ # semantic segmentation inference
+ if self.semantic_on:
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
+ if not self.sem_seg_postprocess_before_inference:
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
+ processed_results[-1]["sem_seg"] = r
+
+ # panoptic segmentation inference
+ if self.panoptic_on:
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
+ processed_results[-1]["panoptic_seg"] = panoptic_r
+
+ # instance segmentation inference
+ if self.instance_on:
+ if self.task_switch['bbox']:
+ box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
+ processed_results[-1]["instances"] = instance_r
+
+ return processed_results
+
+ def evaluate_interactive(self, batched_inputs):
+ assert self.task_switch['spatial']
+ assert 'spatial_query' in batched_inputs[0]
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ extra = {}
+
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+
+ all_batch_shape_iou = []
+ pred_smask_pointer = None
+ prev_smask_pointer = None
+ pred_smask_all = None
+
+ # visualization code
+ # v_pred_mask = []
+ # v_pos_mask = []
+ # v_neg_mask = []
+ # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
+ query_index = self.sem_seg_head.predictor.query_index
+ if self.interactive_mode in ['best', 'best_random']:
+ pos_masks = [x['spatial_query']['rand_shape'].to(self.device)[:,0] for x in batched_inputs]
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+
+ neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False)[:,0] for x in batched_inputs]
+
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+ elif self.interactive_mode == 'random':
+ assert False, "interactive mode not correctly implemented"
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
+ extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
+ else:
+ assert False, "invalid interactive mode"
+
+ for i in range(self.interactive_iter):
+ # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
+ # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
+ extra.update(outputs)
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
+ # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
+
+ s = image_sizes[0]
+ b = batched_inputs[0]
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[0].sigmoid() > 0.5
+ gt_smask = b['gt_masks_orisize']
+ ious = get_iou(gt_smask, pred_smask_all)
+ all_batch_shape_iou += [ious]
+ if (ious > 0.9).sum() == len(ious):
+ all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
+ break
+ if self.interactive_mode in ['best', 'best_random']:
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
+ elif self.interactive_mode == 'random':
+ extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
+ else:
+ assert False, "invalid interactive mode"
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
+
+ return processed_results
+
+ def evaluate_interactive_single(self, batched_inputs, extra={}):
+ assert self.task_switch['spatial']
+ assert 'spatial_query' in batched_inputs[0]
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
+
+ s = image_sizes[0]
+ b = batched_inputs[0]
+ pred_smask_ori = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
+ pred_smask_batch = pred_smask[:,:,:s[0],:s[1]].sigmoid() > 0.5
+ ious = []
+ if 'gt_masks_orisize' in b:
+ gt_smask = b['gt_masks_orisize'].to(pred_smask_ori.device)
+ ious = get_iou(gt_smask, pred_smask_ori)
+ processed_results = [{"mask_iou": ious, 'pred_mask_ori': pred_smask_ori, 'pred_mask_batch': pred_smask_batch}]
+ return processed_results
+
+ def evaluate_interactive_grounding(self, batched_inputs):
+ assert self.task_switch['spatial']
+ assert 'spatial_query' in batched_inputs[0]
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ extra = {}
+
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ all_batch_shape_iou = []
+ pred_smask_pointer = None
+ prev_smask_pointer = None
+ pred_smask_all = None
+
+ # visualization code
+ # v_pred_mask = []
+ # v_pos_mask = []
+ # v_neg_mask = []
+ # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
+ query_index = self.sem_seg_head.predictor.query_index
+ if self.interactive_mode in ['best', 'best_random']:
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+ elif self.interactive_mode == 'random':
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
+ extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
+ else:
+ assert False, "invalid interactive mode"
+
+ grd_texts = batched_inputs[0]['classes']
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+ query_emb = nn.utils.rnn.pad_sequence([_token_emb[_tokens.bool()] for _token_emb, _tokens in zip(token_emb, tokens['attention_mask'])], padding_value=-1)
+ non_zero_query_mask = (query_emb.sum(dim=-1) < 0)
+
+ extra['grounding_tokens'] = query_emb
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ for i in range(self.interactive_iter):
+ # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
+ # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
+ extra.update(outputs)
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
+ # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
+
+ s = image_sizes[0]
+ b = batched_inputs[0]
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
+ gt_smask = b['gt_masks_orisize']
+ ious = get_iou(gt_smask, pred_smask_all)
+ all_batch_shape_iou += [ious]
+ if (ious > 0.9).sum() == len(ious):
+ all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
+ break
+ if self.interactive_mode in ['best', 'best_random']:
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
+ elif self.interactive_mode == 'random':
+ extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
+ else:
+ assert False, "invalid interactive mode"
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
+
+ # visualization
+ # VL.step()
+ # import cv2
+ # v_masks = []
+ # v_pos_masks = []
+ # v_neg_masks = []
+ # txt = []
+
+ # img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
+ # mask_img = VL.overlay_single_mask_to_image(img[:,:,::-1], v_gt_mask.cpu().float().numpy())
+ # acc_pos_mask = np.zeros(v_pos_mask[0].shape)
+ # acc_neg_mask = np.zeros(v_neg_mask[0].shape)
+ # for x,y,z,iou in zip(v_pos_mask, v_neg_mask, v_pred_mask, all_batch_shape_iou):
+ # # dilate x,y
+ # x = cv2.dilate(x, np.ones((5,5), np.uint8), iterations=3)
+ # y = cv2.dilate(y, np.ones((5,5), np.uint8), iterations=3)
+ # acc_pos_mask += x
+ # acc_neg_mask += y
+
+ # v_masks += [z]
+ # v_pos_masks += [acc_pos_mask.clip(0,1)]
+ # v_neg_masks += [acc_neg_mask.clip(0,1)]
+ # txt += ["pred_{}".format(str(iou[0].item())[0:5])]
+
+ # VL.add_image(img[:,:,::-1])
+ # VL.insert(mask_img, "gt_mask")
+ # VL.overlay_obj_mask_to_image_withposneg(img[:,:,::-1], v_masks, v_pos_masks, v_neg_masks, txt, max_len=20)
+ return processed_results
+
+ def evaluate_referring_image(self, batched_inputs, extra={}):
+ assert self.task_switch['spatial']
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
+ assert self.interactive_mode == 'best'
+
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
+
+ if 'spatial_query' in batched_inputs[0]:
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
+ mask_features = mask_features.repeat(nm,1,1,1)
+
+ query_index = self.sem_seg_head.predictor.query_index
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
+
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
+ return outputs, images.tensor.shape
+
+ def evaluate_grounding(self, batched_inputs, mode):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+
+ extra = {}
+ # mask_pred_results = []
+ # for idx, batch_per_image in enumerate(batched_inputs):
+ # grd_texts = batch_per_image['groundings']['texts']
+ # grd_masks = []
+ # for anno_text in grd_texts:
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
+ # token_emb = gtext['token_emb']
+ # tokens = gtext['tokens']
+
+ # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
+ # extra['grounding_tokens'] = grd_emb[:,None]
+
+ # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+ # features = self.backbone(images.tensor)
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
+ # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
+ # t_emb = grd_emb[-1:]
+
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ # matched_id = out_prob.max(0)[1]
+ # grd_masks += [pred_gmasks[matched_id,:,:]]
+ # mask_pred_results += [torch.cat(grd_masks)]
+
+ # comment for multi object inference.
+ mask_pred_results = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ grd_texts = batch_per_image['groundings']['texts']
+ if self.standard_text_for_eval:
+ standard_texts = []
+ for grd in batch_per_image['grounding_info']:
+ mask_file = grd['mask_file'].split('.')[0].split('/')[-1]
+ target = mask_file.split('_')[-1].replace('+', ' ')
+ site = mask_file.split('_')[-2].replace('+', ' ')
+ modality = mask_file.split('_')[-3].replace('+', ' ')
+ standard_texts.append(f'{target} in {site} {modality}')
+ grd_texts = standard_texts
+ batch_per_image['groundings']['texts'] = standard_texts
+
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+ query_emb = token_emb[tokens['attention_mask'].bool()]
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
+
+ extra['grounding_tokens'] = query_emb[:,None]
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ pred_gmasks = outputs['pred_gmasks'][idx]
+ v_emb = outputs['pred_gtexts'][idx]
+ t_emb = gtext['class_emb']
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ matched_id = out_prob.max(0)[1]
+ mask_pred_results += [pred_gmasks[matched_id,:,:]]
+
+ for i in range(len(mask_pred_results)):
+ # upsample masks
+ mask_pred_results[i] = F.interpolate(
+ mask_pred_results[i][None,],
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bilinear",
+ align_corners=False,
+ )[0]
+
+ processed_results = []
+ for mask_pred_result, input_per_image, image_size in zip(
+ mask_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ processed_results[-1]['grounding_mask'] = mask_pred_result
+
+ # compute bbox
+ # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
+ # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+ # processed_results[-1]['grounding_box'] = bbox
+
+ return processed_results
+
+ def evaluate_grounding_sptial(self, batched_inputs, mode):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+
+ extra = {}
+ dilation = 3
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
+ pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
+
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
+
+ mask_pred_results = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_masks = []
+ for idx2, anno_text in enumerate(grd_texts):
+ extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+
+ grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
+ non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
+ extra['grounding_tokens'] = grd_emb[:,None]
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ pred_gmasks = outputs['pred_gmasks'][idx]
+ v_emb = outputs['pred_gtexts'][idx]
+ t_emb = gtext['class_emb']
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ matched_id = out_prob.max(0)[1]
+ grd_masks += [pred_gmasks[matched_id,:,:]]
+ # grd_masks += [outputs['prev_mask'][0]]
+
+ mask_pred_results += [torch.cat(grd_masks)]
+
+ # comment for multi object inference.
+ # mask_pred_results = []
+ # for idx, batch_per_image in enumerate(batched_inputs):
+ # grd_texts = batch_per_image['groundings']['texts']
+ # grd_texts = [x[0] for x in grd_texts]
+
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ # token_emb = gtext['token_emb']
+ # tokens = gtext['tokens']
+ # query_emb = token_emb[tokens['attention_mask'].bool()]
+ # non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
+
+ # extra['grounding_tokens'] = query_emb[:,None]
+ # extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
+
+ # features = self.backbone(images.tensor)
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ # pred_gmasks = outputs['pred_gmasks'][idx]
+ # v_emb = outputs['pred_gtexts'][idx]
+ # t_emb = gtext['class_emb']
+
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ # matched_id = out_prob.max(0)[1]
+ # mask_pred_results += [pred_gmasks[matched_id,:,:]]
+
+ for i in range(len(mask_pred_results)):
+ # upsample masks
+ mask_pred_results[i] = F.interpolate(
+ mask_pred_results[i][None,],
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bilinear",
+ align_corners=False,
+ )[0]
+
+ processed_results = []
+ for mask_pred_result, input_per_image, image_size in zip(
+ mask_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ processed_results[-1]['grounding_mask'] = mask_pred_result
+
+ return processed_results
+
+ def prepare_targets(self, batched_inputs, images):
+ h_pad, w_pad = images.tensor.shape[-2:]
+ new_targets = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ target_dict = {}
+ if self.task_switch['mask']:
+ targets_per_image = batch_per_image['instances'].to(self.device)
+ # pad gt
+ gt_masks = targets_per_image.gt_masks.tensor
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+
+ gt_boxes = targets_per_image.gt_boxes.tensor
+ ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
+ gt_boxes = gt_boxes / ratio
+ xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
+ gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
+
+ target_dict.update({
+ "labels": targets_per_image.gt_classes,
+ "is_things": targets_per_image.is_things,
+ "masks": padded_masks,
+ "boxes": gt_boxes,
+ })
+
+ if self.task_switch['spatial']:
+ # prepare targets for spatial query
+ target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
+
+ if self.task_switch['grounding']:
+ grd_masks = batch_per_image['groundings']['masks']
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_hash = batch_per_image['groundings']['hash']
+ grd_task = batch_per_image['groundings']['mode']
+
+ if len(grd_masks) == 0:
+ padded_masks = None
+ else:
+ padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
+ padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+
+ unique_hash_id = np.unique(grd_hash, return_index=True)[1]
+ selected_mask = np.zeros(len(grd_hash)).astype(bool)
+ selected_mask[unique_hash_id] = True
+
+ selected_token_emb = token_emb[selected_mask]
+ selected_attn_mask = tokens['attention_mask'][selected_mask]
+ query_emb = selected_token_emb[selected_attn_mask.bool()]
+
+ class_idx = tokens['attention_mask'].sum(dim=-1) - 1
+ class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
+ class_emb = token_emb[class_idx]
+
+ target_dict['grounding_masks'] = padded_masks
+ target_dict['grounding_query_embs'] = query_emb
+ target_dict['grounding_class_embs'] = class_emb
+ target_dict['grounding_hash'] = grd_hash
+ target_dict['grounding_task'] = grd_task
+
+ new_targets.append(target_dict)
+ return new_targets
+
+ def prepare_next_spaital_mask(self, outputs, batched_inputs, mode='best'):
+ gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
+ gt_masks = Spatial_ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
+
+ pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
+ prev_masks = nn.utils.rnn.pad_sequence(outputs['spatial_query_pos_mask'], padding_value=False, batch_first=True) | \
+ nn.utils.rnn.pad_sequence(outputs['spatial_query_neg_mask'], padding_value=False, batch_first=True)
+
+ fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
+ fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
+
+ # compute iou between gt and pred
+ iou = (gt_masks & pred_masks).sum(list(range(2,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(2,len(fn.shape)))) + 1e-8)
+ fn_sum = fn.sum(dim=list(range(2,len(fn.shape))))
+ fp_sum = fp.sum(dim=list(range(2,len(fp.shape))))
+
+ is_postive = fn_sum > fp_sum
+ select_mask = torch.zeros_like(fn)
+ select_mask[is_postive] = fn[is_postive]
+ select_mask[~is_postive] = fp[~is_postive]
+ # is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
+
+ # conv implementation
+ bs,ns,h,w = select_mask.shape
+ mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(bs*ns,-1)
+ if mode == 'best':
+ max_xy_idx = torch.stack([torch.arange(bs*ns), mask_dt.max(dim=-1)[1].cpu()]).tolist()
+ elif mode == 'best_random':
+ max_xy_idx = torch.stack([torch.arange(bs*ns), torch.cat([(mask_dt[i] > 0).nonzero()[torch.randint(0, len((mask_dt[i] > 0).nonzero()), (1,))][0] for i in range(len(mask_dt))]).cpu()]).tolist()
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
+ next_mask = next_mask.view(bs*ns,-1)
+ next_mask[max_xy_idx] = True
+ next_mask = next_mask.reshape((bs*ns,1,h,w)).float()
+ dilation = 3
+ next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2).reshape(bs,ns,h,w) > 0
+
+ # determine whether next mask is zero
+ keep = (iou < 0.925)
+ next_mask = next_mask & keep.view(bs,ns,1,1)
+
+ pos_mask = []
+ neg_mask = []
+ for idx, ip in enumerate(is_postive):
+ mask_len = len(outputs['spatial_query_pos_mask'][idx])
+ pos_mask += [outputs['spatial_query_pos_mask'][idx] | (next_mask[idx][:mask_len] & ip[:mask_len,None,None])]
+ neg_mask += [outputs['spatial_query_neg_mask'][idx] | (next_mask[idx][:mask_len] & (~ip[:mask_len,None,None]))]
+
+ if 'false_positive_mask' in outputs:
+ fp = outputs['false_positive_mask'] | fp
+ return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
+
+ def semantic_inference(self, mask_cls, mask_pred):
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+ mask_pred = mask_pred.sigmoid()
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
+ return semseg
+
+ def panoptic_inference(self, mask_cls, mask_pred):
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
+ mask_pred = mask_pred.sigmoid()
+
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
+ cur_scores = scores[keep]
+ cur_classes = labels[keep]
+ cur_masks = mask_pred[keep]
+ cur_mask_cls = mask_cls[keep]
+ cur_mask_cls = cur_mask_cls[:, :-1]
+
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
+
+ h, w = cur_masks.shape[-2:]
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
+ segments_info = []
+
+ current_segment_id = 0
+
+ if cur_masks.shape[0] == 0:
+ # We didn't detect any mask :(
+ return panoptic_seg, segments_info
+ else:
+ # take argmax
+ cur_mask_ids = cur_prob_masks.argmax(0)
+ stuff_memory_list = {}
+ for k in range(cur_classes.shape[0]):
+ pred_class = cur_classes[k].item()
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
+ mask_area = (cur_mask_ids == k).sum().item()
+ original_area = (cur_masks[k] >= 0.5).sum().item()
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
+
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
+ if mask_area / original_area < self.overlap_threshold:
+ continue
+
+ # merge stuff regions
+ if not isthing:
+ if int(pred_class) in stuff_memory_list.keys():
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
+ continue
+ else:
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
+
+ current_segment_id += 1
+ panoptic_seg[mask] = current_segment_id
+
+ segments_info.append(
+ {
+ "id": current_segment_id,
+ "isthing": bool(isthing),
+ "category_id": int(pred_class),
+ }
+ )
+
+ return panoptic_seg, segments_info
+
+ def instance_inference(self, mask_cls, mask_pred, box_pred):
+ # mask_pred is already processed to have the same shape as original input
+ image_size = mask_pred.shape[-2:]
+
+ # [Q, K]
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
+
+ labels_per_image = labels[topk_indices]
+ topk_indices = (topk_indices // self.sem_seg_head.num_classes)
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
+ mask_pred = mask_pred[topk_indices]
+ if box_pred is not None:
+ box_pred = box_pred[topk_indices]
+
+ # if this is panoptic segmentation, we only keep the "thing" classes
+ if self.panoptic_on:
+ keep = torch.zeros_like(scores_per_image).bool()
+ for i, lab in enumerate(labels_per_image):
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
+
+ scores_per_image = scores_per_image[keep]
+ labels_per_image = labels_per_image[keep]
+ mask_pred = mask_pred[keep]
+
+ if box_pred is not None:
+ box_pred = box_pred[keep]
+
+ result = Instances(image_size)
+ # mask (before sigmoid)
+ result.pred_masks = (mask_pred > 0).float()
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+ # Uncomment the following to get boxes from masks (this is slow)
+
+ if box_pred is not None:
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
+ else:
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+
+ # calculate average mask prob
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
+ result.scores = scores_per_image * mask_scores_per_image
+ result.pred_classes = labels_per_image
+
+ return result
+
+ def prepare_targets4query(self, targets, images, topk=5):
+ h_pad, w_pad = images.tensor.shape[-2:]
+ new_targets = []
+ new_queries = []
+ for targets_per_image in targets:
+ # we randomly sample maximally topk concepts
+ unique_target_classes = [k for k in set(targets_per_image.gt_classes.tolist())]
+ selected_target_classes = random.sample(unique_target_classes, min(topk, len(unique_target_classes)))
+ new_targets_per_image = []
+ new_queries_per_image = []
+ for clss in selected_target_classes:
+ indices = (targets_per_image.gt_classes == clss).nonzero().view(-1)
+ # pad gt
+ gt_masks = targets_per_image.gt_masks[indices]
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+
+ # convert class into concept name and then token seq
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings([BIOMED_CLASSES[clss]], name='grounding')
+ query = getattr(self.sem_seg_head.predictor.lang_encoder, 'grounding_text_embeddings')
+
+ new_targets.append(
+ {
+ "labels": targets_per_image.gt_classes[indices],
+ "masks": padded_masks,
+ }
+ )
+ new_queries_per_image.append(query)
+ new_queries.append(new_queries_per_image)
+
+ return new_targets, new_queries
+
+
+
+@register_model
+def get_seem_model(cfg, **kwargs):
+ return GeneralizedSEEM(cfg)
\ No newline at end of file
diff --git a/modeling/architectures/xdecoder_model.py b/modeling/architectures/xdecoder_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..32cd76adfce667e8c3a14b4216888aa33a0ad8c8
--- /dev/null
+++ b/modeling/architectures/xdecoder_model.py
@@ -0,0 +1,937 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu), Ziyi Dou, Jianwei Yang
+# --------------------------------------------------------
+
+from typing import Tuple
+import random
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+import numpy as np
+
+from timm.models.layers import trunc_normal_
+from nltk.stem.lancaster import LancasterStemmer
+from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode
+from detectron2.utils.memory import retry_if_cuda_oom
+from detectron2.data import MetadataCatalog
+
+from .build import register_model
+from ..utils import configurable, get_class_names
+from ..vision.backbone import build_backbone, Backbone
+from ..body import build_xdecoder_head
+from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
+from ..language import build_language_encoder
+from ..language.loss import vl_similarity, image_text_contrastive_loss_queue
+from utilities.prompt_engineering import prompt_engineering
+from utilities.constants import COCO_PANOPTIC_CLASSES
+
+st = LancasterStemmer()
+
+
+class GeneralizedXdecoder(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ *,
+ backbone: Backbone,
+ sem_seg_head: nn.Module,
+ criterion: nn.Module,
+ losses: dict,
+ num_queries: int,
+ object_mask_threshold: float,
+ overlap_threshold: float,
+ metadata,
+ task_switch: dict,
+ phrase_prob: float,
+ size_divisibility: int,
+ sem_seg_postprocess_before_inference: bool,
+ pixel_mean: Tuple[float],
+ pixel_std: Tuple[float],
+ # inference
+ semantic_on: bool,
+ panoptic_on: bool,
+ instance_on: bool,
+ test_topk_per_image: int,
+ train_dataset_name: str,
+ retrieval_emsemble: bool,
+ backbone_dim: int,
+ dim_proj: int,
+ ):
+ """
+ Args:
+ backbone: a backbone module, must follow detectron2's backbone interface
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
+ criterion: a module that defines the loss
+ num_queries: int, number of queries
+ object_mask_threshold: float, threshold to filter query based on classification score
+ for panoptic segmentation inference
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
+ segmentation inference
+ size_divisibility: Some backbones require the input height and width to be divisible by a
+ specific integer. We can use this to override such requirement.
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
+ to original input size before semantic segmentation inference or after.
+ For high-resolution dataset like Mapillary, resizing predictions before
+ inference will cause OOM error.
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
+ the per-channel mean and std to be used to normalize the input image
+ semantic_on: bool, whether to output semantic segmentation prediction
+ instance_on: bool, whether to output instance segmentation prediction
+ panoptic_on: bool, whether to output panoptic segmentation prediction
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
+ """
+ super().__init__()
+ self.backbone = backbone
+ self.sem_seg_head = sem_seg_head
+ self.criterion = criterion
+ self.losses = losses
+ self.num_queries = num_queries
+ self.overlap_threshold = overlap_threshold
+ self.object_mask_threshold = object_mask_threshold
+ self.metadata = metadata
+ if size_divisibility < 0:
+ # use backbone size_divisibility if not set
+ size_divisibility = self.backbone.size_divisibility
+ self.size_divisibility = size_divisibility
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ # additional args
+ self.semantic_on = semantic_on
+ self.instance_on = instance_on
+ self.panoptic_on = panoptic_on
+
+ # caption argument
+ self.task_switch = task_switch
+ self.phrase_prob = phrase_prob
+
+ self.test_topk_per_image = test_topk_per_image
+ self.train_class_names = get_class_names(train_dataset_name)
+
+ self.retrieval_emsemble = retrieval_emsemble
+ # backbone itc loss
+ if task_switch['retrieval'] and retrieval_emsemble:
+ self.backbone_proj = nn.Parameter(torch.empty(backbone_dim, dim_proj))
+ trunc_normal_(self.backbone_proj, std=.02)
+
+ if not self.semantic_on:
+ assert self.sem_seg_postprocess_before_inference
+
+ @classmethod
+ def from_config(cls, cfg):
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ # Loss parameters:
+ deep_supervision = dec_cfg['DEEP_SUPERVISION']
+ no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
+
+ # loss weights, switcher for task, and top layers to compute loss
+ loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
+ 'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
+ 'caption': dec_cfg['CAPTION_WEIGHT'],
+ 'captioning': dec_cfg['CAPTIONING_WEIGHT'],
+ 'retrieval': {'decoder': dec_cfg['RETRIEVAL_WEIGHT'], 'backbone': dec_cfg['BACKBONER_WEIGHT']},
+ 'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']}}
+
+ task_switch = {'bbox': dec_cfg.get('DETECTION', False),
+ 'mask': dec_cfg.get('MASK', True),
+ 'caption': dec_cfg['CAPTION'].get('ENABLED', False),
+ 'captioning': dec_cfg['CAPTIONING'].get('ENABLED', False),
+ 'retrieval': dec_cfg['RETRIEVAL'].get('ENABLED', False),
+ 'grounding': dec_cfg['GROUNDING'].get('ENABLED', False)}
+
+ top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
+ 'caption': dec_cfg.get('TOP_CAPTION_LAYERS', 10),
+ 'captioning': dec_cfg.get('TOP_CAPTIONING_LAYERS', 10),
+ 'retrieval': dec_cfg.get('TOP_RETRIEVAL_LAYERS', 10),
+ 'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),}
+
+ # build model
+ extra = {'task_switch': task_switch}
+ backbone = build_backbone(cfg)
+ lang_encoder = build_language_encoder(cfg)
+ sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra)
+
+ # building criterion
+ matcher = HungarianMatcher(
+ cost_class=loss_weights['mask']['ce'],
+ cost_mask=loss_weights['mask']['bce'],
+ cost_dice=loss_weights['mask']['dice'],
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
+ )
+
+ # init weight dict and criterion loss functions.
+ losses = {'seg': [], 'vlp': []}
+ if task_switch['mask']:
+ losses['seg'] += ["labels", "masks"]
+ if task_switch['caption']:
+ losses['seg'] += ["captions"]
+ if task_switch['grounding']:
+ losses['seg'] += ["groundings"]
+ if task_switch['captioning']:
+ losses['vlp'] += ["captionings"]
+ if task_switch['retrieval']:
+ losses['vlp'] += ["retrievals"]
+
+ weight_dict = {}
+ for key, turn_on in task_switch.items():
+ if turn_on:
+ if isinstance(loss_weights[key], dict):
+ # HACK it should support bbox in the future
+ for key_, weight in loss_weights[key].items():
+ weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
+ else:
+ weight_dict["loss_{}_0".format(key)] = loss_weights[key]
+
+ # generate full weight dict and remove not computed layers.
+ if deep_supervision:
+ dec_layers = dec_cfg['DEC_LAYERS']
+ aux_weight_dict = {}
+ for i in range(dec_layers - 1):
+ for k, v in weight_dict.items():
+ if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
+ continue
+ aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
+ weight_dict.update(aux_weight_dict)
+
+ grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
+ # generate critenrion for loss function.
+ criterion = SetCriterion(
+ sem_seg_head.num_classes,
+ matcher=matcher,
+ weight_dict=weight_dict,
+ top_x_layers=top_x_layers,
+ eos_coef=no_object_weight,
+ losses=[],
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
+ oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
+ importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
+ grounding_weight=grd_weight,
+ )
+
+ # extra logistic
+ train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
+ phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
+
+ return {
+ "backbone": backbone,
+ "sem_seg_head": sem_seg_head,
+ "criterion": criterion,
+ "losses": losses,
+ "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
+ "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
+ "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
+ "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
+ "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
+ "sem_seg_postprocess_before_inference": (
+ dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
+ or dec_cfg['TEST']['PANOPTIC_ON']
+ or dec_cfg['TEST']['INSTANCE_ON']
+ ),
+ "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
+ "pixel_std": cfg['INPUT']['PIXEL_STD'],
+ "task_switch": task_switch,
+ "phrase_prob": phrase_prob,
+ # inference
+ "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
+ "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
+ "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
+ "test_topk_per_image": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'],
+ "train_dataset_name": train_dataset_name,
+ "retrieval_emsemble": dec_cfg['RETRIEVAL']['ENSEMBLE'],
+ "backbone_dim": cfg['MODEL']['BACKBONE_DIM'],
+ "dim_proj": cfg['MODEL']['DIM_PROJ'],
+ }
+
+ @property
+ def device(self):
+ return self.pixel_mean.device
+
+ def forward(self, batched_inputs, mode=None):
+ """
+ Args:
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+ Each item in the list contains the inputs for one image.
+ For now, each item in the list is a dict that contains:
+ * "image": Tensor, image in (C, H, W) format.
+ * "instances": per-region ground truth
+ * Other information that's included in the original dicts, such as:
+ "height", "width" (int): the output resolution of the model (may be different
+ from input resolution), used in inference.
+ Returns:
+ list[dict]:
+ each dict has the results for one image. The dict contains the following keys:
+
+ * "sem_seg":
+ A Tensor that represents the
+ per-pixel segmentation prediced by the head.
+ The prediction has shape KxHxW that represents the logits of
+ each class for each pixel.
+ * "panoptic_seg":
+ A tuple that represent panoptic output
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
+ Each dict contains keys "id", "category_id", "isthing".
+ """
+ if self.training:
+ losses = {}
+ if self.task_switch['mask']:
+ losses_seg = self.forward_seg(batched_inputs['coco'])
+ losses.update(losses_seg)
+ if self.task_switch['retrieval'] or self.task_switch['captioning']:
+ losses_vlp = self.forward_vlp(batched_inputs['vlp'])
+ losses.update(losses_vlp)
+ for k in list(losses.keys()):
+ if k in self.criterion.weight_dict:
+ losses[k] *= self.criterion.weight_dict[k]
+ else: # remove this loss if not specified in `weight_dict`
+ losses.pop(k)
+ return losses
+ else:
+ if mode == 'retrieval':
+ return self.evaluate_retrieval(batched_inputs)
+ elif mode == 'captioning':
+ return self.evaluate_captioning(batched_inputs)
+ elif mode == 'classification':
+ return self.evaluate_classification(batched_inputs)
+ elif mode == 'grounding_refcoco':
+ return self.evaluate_grounding(batched_inputs, mode)
+ else:
+ return self.evaluate(batched_inputs)
+
+
+ def forward_seg(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
+
+ extra = {}
+ # mask classification target
+ if "instances" in batched_inputs[0]:
+ # input bounding box is checked to be correct.
+ targets = self.prepare_targets(batched_inputs, images)
+
+ if self.task_switch['grounding']:
+ grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
+ grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens)
+ extra['grounding_tokens'] = grounding_tokens
+
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, extra=extra)
+
+ _outputs = {}
+ for key, value in outputs.items():
+ if key == 'pred_logits':
+ _outputs[key] = value[:,:self.num_queries-1]
+ elif key == 'pred_masks':
+ _outputs[key] = value[:,:self.num_queries-1]
+ if self.task_switch['grounding']:
+ _outputs['pred_gmasks'] = value[:,self.num_queries:2*self.num_queries-1]
+ elif key == 'pred_captions':
+ _outputs[key] = value[:,:self.num_queries-1]
+ if self.task_switch['grounding']:
+ _outputs['pred_gtexts'] = value[:,self.num_queries:2*self.num_queries-1]
+ elif key == 'aux_outputs':
+ _outputs[key] = []
+ for i in range(len(value)):
+ _outputs[key] += [{}]
+ for _key, _value in value[i].items():
+ if _key == 'pred_logits':
+ _outputs[key][i][_key] = _value[:,:self.num_queries-1]
+ elif _key == 'pred_masks':
+ _outputs[key][i][_key] = _value[:,:self.num_queries-1]
+ if self.task_switch['grounding']:
+ _outputs[key][i]['pred_gmasks'] = _value[:,self.num_queries:2*self.num_queries-1]
+ elif _key == 'pred_captions':
+ _outputs[key][i][_key] = _value[:,:self.num_queries-1]
+ if self.task_switch['grounding']:
+ _outputs[key][i]['pred_gtexts'] = _value[:,self.num_queries:2*self.num_queries-1]
+ outputs = _outputs
+
+ extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
+ 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default'))}
+
+ # bipartite matching-based loss
+ self.criterion.losses = self.losses['seg'] # seg criterion losses
+ losses = self.criterion(outputs, targets, extra)
+
+ del outputs
+ del _outputs
+ return losses
+
+ def forward_vlp(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ targets_vlp = self.prepare_vlp_targets(batched_inputs, images.tensor.device)
+
+ extra = {"token_embedding": self.sem_seg_head.predictor.lang_encoder.lang_encoder.token_embedding,
+ "lang_encoder": self.sem_seg_head.predictor.lang_encoder,
+ "training": self.training}
+
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, target_queries=None, target_vlp=targets_vlp, task='vlp', extra=extra)
+
+ for key, value in outputs.items():
+ if key == 'pred_captionings':
+ outputs[key] = value
+ elif key == 'pred_captions':
+ # outputs[key] = value[:,-1:]
+ outputs[key] = value
+ elif key == 'aux_outputs':
+ outputs[key] = []
+ for i in range(len(value)):
+ outputs[key] += [{}]
+ for _key, _value in value[i].items():
+ if _key == 'pred_captions':
+ # outputs[key][i][_key] = _value[:,-1:]
+ outputs[key][i][_key] = _value
+ elif _key == 'pred_captionings':
+ outputs[key][i][_key] = _value
+
+ self.criterion.losses = self.losses['vlp'] # seg criterion losses
+ losses = self.criterion.forward_vlp(outputs, targets_vlp, extra)
+ del outputs
+
+ if self.task_switch['retrieval'] and self.retrieval_emsemble:
+ # compute backbone vlp.
+ v_emb = features['res5']
+ bs,nc,_,_ = v_emb.shape
+ v_emb = v_emb.reshape(bs,nc,-1)
+ v_emb = F.adaptive_avg_pool1d(v_emb, 1).reshape(bs,nc) @ self.backbone_proj
+ t_emb = torch.cat([x['caption_proj'] for x in targets_vlp], dim=0)
+ loss_contrast = image_text_contrastive_loss_queue(v_emb, t_emb, self.sem_seg_head.predictor.lang_encoder, None)
+ losses['loss_retrieval_backbone_0'] = loss_contrast
+ return losses
+
+ def evaluate(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
+
+ mask_cls_results = outputs["pred_logits"]
+ mask_pred_results = outputs["pred_masks"]
+ box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
+ caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]
+
+ # upsample masks
+ mask_pred_results = F.interpolate(
+ mask_pred_results,
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bicubic",
+ align_corners=False,
+ antialias=True
+ )
+
+ input_size = mask_pred_results.shape[-2:]
+ keep_sem_bgd = self.metadata.keep_sem_bgd if hasattr(self.metadata, 'keep_sem_bgd') else False
+ del outputs
+
+ processed_results = []
+ for mask_cls_result, mask_pred_result, box_pred_result, caption_pred_result, input_per_image, image_size in zip(
+ mask_cls_results, mask_pred_results, box_pred_results, caption_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ if self.sem_seg_postprocess_before_inference:
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
+
+ # semantic segmentation inference
+ if self.semantic_on:
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result, keep_sem_bgd)
+ if not self.sem_seg_postprocess_before_inference:
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
+ processed_results[-1]["sem_seg"] = r
+
+ # panoptic segmentation inference
+ if self.panoptic_on:
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
+ processed_results[-1]["panoptic_seg"] = panoptic_r
+
+ # instance segmentation inference
+ if self.instance_on:
+ if self.task_switch['bbox']:
+ box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
+ processed_results[-1]["instances"] = instance_r
+ if self.task_switch['caption']:
+ processed_results[-1]["captions"] = caption_pred_result
+ processed_results[-1]["masks"] = mask_pred_result
+
+ return processed_results
+
+ def evaluate_retrieval(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
+ v_emb_it = outputs['pred_captions'][:,-1]
+
+ # compute backbone score
+ if self.task_switch['retrieval'] and self.retrieval_emsemble:
+ _v_emb_it = features['res5']
+ bs,nc,_,_ = _v_emb_it.shape
+ _v_emb_it = _v_emb_it.reshape(bs,nc,-1)
+ _v_emb_it = F.adaptive_avg_pool1d(_v_emb_it, 1).reshape(bs,nc) @ self.backbone_proj
+
+ processed_results = []
+ for idx, batch_data in enumerate(batched_inputs):
+ caption_ids = []
+ t_emb_its = []
+ processed_results.append({})
+ for caption in batch_data['captions']:
+ lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(caption)
+ t_emb_it = lang_results['class_emb']
+ caption_ids.append(batch_data['image_id'])
+ t_emb_its.append(t_emb_it)
+
+ t_emb_it = torch.cat(t_emb_its, dim=0)
+
+ image_embeds = [v_emb_it[idx].unsqueeze(0)]
+ if self.task_switch['retrieval'] and self.retrieval_emsemble:
+ image_embeds += [_v_emb_it[idx].unsqueeze(0)]
+ caption_results = {
+ 'image_embeds': image_embeds,
+ 'text_embeds': t_emb_it,
+ 'caption_ids': caption_ids,
+ 'image_ids': batch_data['image_id'],
+ }
+ processed_results[-1]["caption"] = caption_results
+
+ del features
+ return processed_results
+
+ def evaluate_captioning(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ if not hasattr(self, 'start_token'):
+ self.start_token = torch.tensor([[49406]*77], device=self.device)
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+
+ captioning_mask = None
+ if 'captioning_mask' in batched_inputs[-1]:
+ captioning_mask = torch.cat([x['captioning_mask'] for x in batched_inputs])
+
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding, task='captioning_infer', extra={'start_token': self.start_token, 'captioning_mask': captioning_mask})
+
+ processed_results = []
+ for idx, batch_data in enumerate(batched_inputs):
+ processed_results.append({})
+ processed_results[-1]["captioning_token"] = outputs['pred_captionings'][idx]
+ processed_results[-1]["captioning_text"] = outputs['pred_texts'][idx].split('.')[0]
+ processed_results[-1]["image_id"] = batched_inputs[idx]['image_id']
+
+ return processed_results
+
+ def evaluate_classification(self, batched_inputs):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
+
+ processed_results = []
+ for idx, batch_data in enumerate(batched_inputs):
+ processed_results.append({})
+ processed_results[-1]["pred_class"] = outputs['pred_logits'][idx,-1]
+ return processed_results
+
+ def evaluate_grounding_baseline(self, batched_inputs, mode):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+ img_bs = images.tensor.shape[0]
+
+ targets = targets_grounding = queries_grounding = None
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
+
+ mask_pred_results = outputs["pred_masks"]
+ caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]
+
+ # upsample masks
+ mask_pred_results = F.interpolate(
+ mask_pred_results,
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bicubic",
+ align_corners=False,
+ antialias=True
+ )
+
+ processed_results = []
+ for mask_pred_result, caption_pred_result, input_per_image, image_size in zip(
+ mask_pred_results, caption_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )[:-1]
+
+ texts_all = input_per_image['groundings']['texts']
+ grd_masks = []
+ for texts in texts_all:
+ if mode == 'grounding_refcoco':
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=False, is_eval=True)
+ elif mode == 'grounding_phrasecut':
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=True, is_eval=False)
+ t_emb = getattr(self.sem_seg_head.predictor.lang_encoder, "{}_text_embeddings".format('grounding')).t()
+ v_emb = caption_pred_result[:-1]
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ vt_sim = v_emb @ t_emb
+ max_id = vt_sim.max(0)[1][0]
+ grd_masks += [mask_pred_result[max_id]]
+ processed_results[-1]['grounding_mask'] = torch.stack(grd_masks)
+
+ return processed_results
+
+ def evaluate_grounding(self, batched_inputs, mode):
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+
+ extra = {}
+ # mask_pred_results = []
+ # for idx, batch_per_image in enumerate(batched_inputs):
+ # grd_texts = batch_per_image['groundings']['texts']
+ # grd_masks = []
+ # for anno_text in grd_texts:
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
+ # token_emb = gtext['token_emb']
+ # tokens = gtext['tokens']
+
+ # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
+ # extra['grounding_tokens'] = grd_emb[:,None]
+
+ # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
+ # features = self.backbone(images.tensor)
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
+ # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
+ # t_emb = grd_emb[-1:]
+
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ # matched_id = out_prob.max(0)[1]
+ # grd_masks += [pred_gmasks[matched_id,:,:]]
+ # mask_pred_results += [torch.cat(grd_masks)]
+
+ # comment for multi object inference.
+ mask_pred_results = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_texts = [x[0] for x in grd_texts]
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+ query_emb = token_emb[tokens['attention_mask'].bool()]
+ extra['grounding_tokens'] = query_emb[:,None]
+
+ features = self.backbone(images.tensor)
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
+
+ pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
+ v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
+ t_emb = gtext['class_emb']
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ matched_id = out_prob.max(0)[1]
+ mask_pred_results += [pred_gmasks[matched_id,:,:]]
+
+ for i in range(len(mask_pred_results)):
+ # upsample masks
+ mask_pred_results[i] = F.interpolate(
+ mask_pred_results[i][None,],
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+ mode="bicubic",
+ align_corners=False,
+ antialias=True
+ )[0]
+
+ processed_results = []
+ for mask_pred_result, input_per_image, image_size in zip(
+ mask_pred_results, batched_inputs, images.image_sizes
+ ):
+ height = input_per_image.get("height", image_size[0])
+ width = input_per_image.get("width", image_size[1])
+ processed_results.append({})
+
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+ mask_pred_result, image_size, height, width
+ )
+ processed_results[-1]['grounding_mask'] = mask_pred_result
+
+ # compute bbox
+ # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
+ # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+ # processed_results[-1]['grounding_box'] = bbox
+
+ return processed_results
+
+ def prepare_vlp_targets(self, batched_inputs, device):
+ input_ids = []
+ attention_mask = []
+ for cnt, x in enumerate(batched_inputs):
+ captions = x['captions']
+ randid = random.randint(0, len(captions)-1)
+ input_ids += x['tokens']['input_ids'][randid:randid+1]
+ attention_mask += x['tokens']['attention_mask'][randid:randid+1]
+
+ input_ids = torch.stack(input_ids)
+ attention_mask = torch.stack(attention_mask)
+ tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
+ lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(tokens, token=True)
+
+ target_vlp = []
+ for cnt, x in enumerate(batched_inputs):
+ target_dict = {}
+ target_dict["caption_tokens"] = lang_results['token_emb'][cnt:cnt+1]
+ target_dict["caption_proj"] = lang_results['class_emb'][cnt:cnt+1]
+ target_dict["caption_tokenids"] = lang_results['tokens']['input_ids'][cnt:cnt+1]
+ target_dict["caption_mask"] = lang_results['tokens']['attention_mask'][cnt:cnt+1]
+ target_vlp.append(target_dict)
+ return target_vlp
+
+ def prepare_targets(self, batched_inputs, images):
+ h_pad, w_pad = images.tensor.shape[-2:]
+ new_targets = []
+ for idx, batch_per_image in enumerate(batched_inputs):
+ targets_per_image = batch_per_image["instances"].to(self.device)
+
+ # pad gt
+ gt_masks = targets_per_image.gt_masks
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+
+ gt_boxes = targets_per_image.gt_boxes.tensor
+ ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
+ gt_boxes = gt_boxes / ratio
+ xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
+ gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
+
+ target_dict = {
+ "labels": targets_per_image.gt_classes,
+ "is_things": targets_per_image.is_things,
+ "masks": padded_masks,
+ "boxes": gt_boxes
+ }
+
+ if self.task_switch['caption']:
+ caption = batch_per_image["captions"]
+ caption_noun = batch_per_image["captions_noun"]
+ rand_index = random.randint(0, len(caption)-1)
+
+ text = caption[rand_index]
+ nouns = caption_noun[rand_index]
+ noun_captions = [prompt_engineering(noun, topk=10000, suffix='.') for noun in nouns] + [text]
+
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(noun_captions, is_eval=False, name='caption_noun', prompt=False)
+ ctext = getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('caption_noun'))
+ target_dict["captions"] = ctext
+
+ target_dict["captions_hash"] = [(hash(st.stem(txt)) % 10**16) for txt in (nouns + [text])]
+ target_dict["labels_hash"] = [(hash(st.stem(COCO_PANOPTIC_CLASSES[label_id].replace('-other','').replace('-merged','').replace('-stuff',''))) % 10**16) for label_id in target_dict['labels']]
+
+ if self.task_switch['grounding']:
+ grd_masks = batch_per_image['groundings']['masks']
+ grd_texts = batch_per_image['groundings']['texts']
+ grd_hash = batch_per_image['groundings']['hash']
+ grd_task = batch_per_image['groundings']['mode']
+
+ if len(grd_masks) == 0:
+ padded_masks = None
+ else:
+ padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
+ padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
+
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
+ token_emb = gtext['token_emb']
+ tokens = gtext['tokens']
+
+ unique_hash_id = np.unique(grd_hash, return_index=True)[1]
+ selected_mask = np.zeros(len(grd_hash)).astype(np.bool)
+ selected_mask[unique_hash_id] = True
+
+ selected_token_emb = token_emb[selected_mask]
+ selected_attn_mask = tokens['attention_mask'][selected_mask]
+ query_emb = selected_token_emb[selected_attn_mask.bool()]
+
+ class_idx = tokens['attention_mask'].sum(dim=-1) - 1
+ class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
+ class_emb = token_emb[class_idx]
+
+ target_dict['grounding_masks'] = padded_masks
+ target_dict['grounding_query_embs'] = query_emb
+ target_dict['grounding_class_embs'] = class_emb
+ target_dict['grounding_hash'] = grd_hash
+ target_dict['grounding_task'] = grd_task
+
+ new_targets.append(target_dict)
+ return new_targets
+
+ def semantic_inference(self, mask_cls, mask_pred, keep_sem_bgd=False):
+ if keep_sem_bgd:
+ mask_cls = F.softmax(mask_cls, dim=-1)
+ else:
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+ mask_pred = mask_pred.sigmoid()
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
+ return semseg
+
+ def panoptic_inference(self, mask_cls, mask_pred):
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
+ mask_pred = mask_pred.sigmoid()
+
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
+ cur_scores = scores[keep]
+ cur_classes = labels[keep]
+ cur_masks = mask_pred[keep]
+ cur_mask_cls = mask_cls[keep]
+ cur_mask_cls = cur_mask_cls[:, :-1]
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
+
+ h, w = cur_masks.shape[-2:]
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
+ segments_info = []
+
+ current_segment_id = 0
+
+ if cur_masks.shape[0] == 0:
+ # We didn't detect any mask :(
+ return panoptic_seg, segments_info
+ else:
+ # take argmax
+ cur_mask_ids = cur_prob_masks.argmax(0)
+ stuff_memory_list = {}
+ thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
+ for k in range(cur_classes.shape[0]):
+ pred_class = cur_classes[k].item()
+ isthing = pred_class in thing_dataset_id_to_contiguous_id.values()
+ mask_area = (cur_mask_ids == k).sum().item()
+ original_area = (cur_masks[k] >= 0.5).sum().item()
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
+
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
+ if mask_area / original_area < self.overlap_threshold:
+ continue
+
+ # merge stuff regions
+ if not isthing:
+ if int(pred_class) in stuff_memory_list.keys():
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
+ continue
+ else:
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
+
+ current_segment_id += 1
+ panoptic_seg[mask] = current_segment_id
+
+ segments_info.append(
+ {
+ "id": current_segment_id,
+ "isthing": bool(isthing),
+ "category_id": int(pred_class),
+ }
+ )
+ return panoptic_seg, segments_info
+
+ def instance_inference(self, mask_cls, mask_pred, box_pred):
+ # mask_pred is already processed to have the same shape as original input
+ image_size = mask_pred.shape[-2:]
+
+ # [Q, K]
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
+
+ labels_per_image = labels[topk_indices]
+ topk_indices = (topk_indices // self.sem_seg_head.num_classes)
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
+ mask_pred = mask_pred[topk_indices]
+ if box_pred is not None:
+ box_pred = box_pred[topk_indices]
+
+ # if this is panoptic segmentation, we only keep the "thing" classes
+ if self.panoptic_on:
+ thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
+ keep = torch.zeros_like(scores_per_image).bool()
+ for i, lab in enumerate(labels_per_image):
+ keep[i] = lab in thing_dataset_id_to_contiguous_id.values()
+
+ scores_per_image = scores_per_image[keep]
+ labels_per_image = labels_per_image[keep]
+ mask_pred = mask_pred[keep]
+
+ if box_pred is not None:
+ box_pred = box_pred[keep]
+
+ result = Instances(image_size)
+ # mask (before sigmoid)
+ result.pred_masks = (mask_pred > 0).float()
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+ # Uncomment the following to get boxes from masks (this is slow)
+
+ if box_pred is not None:
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
+ else:
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+
+ # calculate average mask prob
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
+ result.scores = scores_per_image * mask_scores_per_image
+ result.pred_classes = labels_per_image
+
+ return result
+
+
+
+@register_model
+def get_xdecoder_model(cfg, **kwargs):
+ return GeneralizedXdecoder(cfg)
\ No newline at end of file
diff --git a/modeling/body/__init__.py b/modeling/body/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..044a59613be514852cf550c1550842bbe9335846
--- /dev/null
+++ b/modeling/body/__init__.py
@@ -0,0 +1,10 @@
+from .xdecoder_head import *
+from .build import *
+
+def build_xdecoder_head(config, *args, **kwargs):
+ model_name = config['MODEL']['HEAD']
+ if not is_model(model_name):
+ raise ValueError(f'Unkown model: {model_name}')
+
+ body = model_entrypoints(model_name)(config, *args, **kwargs)
+ return body
\ No newline at end of file
diff --git a/modeling/body/build.py b/modeling/body/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e77853edc69bd7636cfbc262483b29de07d6adb
--- /dev/null
+++ b/modeling/body/build.py
@@ -0,0 +1,13 @@
+_model_entrypoints = {}
+
+def register_body(fn):
+ module_name_split = fn.__module__.split('.')
+ model_name = module_name_split[-1]
+ _model_entrypoints[model_name] = fn
+ return fn
+
+def model_entrypoints(model_name):
+ return _model_entrypoints[model_name]
+
+def is_model(model_name):
+ return model_name in _model_entrypoints
\ No newline at end of file
diff --git a/modeling/body/xdecoder_head.py b/modeling/body/xdecoder_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc782b21d14e76e554c4d8d84452e521c52375ce
--- /dev/null
+++ b/modeling/body/xdecoder_head.py
@@ -0,0 +1,126 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+# Copyright (c) Facebook, Inc. and its affiliates.
+from typing import Dict
+
+from torch import nn
+
+from detectron2.layers import ShapeSpec
+
+from .build import register_body
+from ..vision.encoder import build_encoder
+from ..interface import build_decoder
+from ..utils import configurable
+
+
+class XdecoderHead(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ input_shape: Dict[str, ShapeSpec],
+ *,
+ num_classes: int,
+ pixel_decoder: nn.Module,
+ loss_weight: float = 1.0,
+ ignore_value: int = -1,
+ # extra parameters
+ transformer_predictor: nn.Module,
+ transformer_in_feature: str,
+ binary_classes: bool,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ input_shape: shapes (channels and stride) of the input features
+ num_classes: number of classes to predict
+ pixel_decoder: the pixel decoder module
+ loss_weight: loss weight
+ ignore_value: category id to be ignored during training.
+ transformer_predictor: the transformer decoder that makes prediction
+ transformer_in_feature: input feature name to the transformer_predictor
+ """
+ super().__init__()
+
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+ self.in_features = [k for k, v in input_shape]
+ feature_strides = [v.stride for k, v in input_shape]
+ feature_channels = [v.channels for k, v in input_shape]
+
+ self.ignore_value = ignore_value
+ self.common_stride = 4
+ self.loss_weight = loss_weight
+
+ self.pixel_decoder = pixel_decoder
+ self.predictor = transformer_predictor
+ self.transformer_in_feature = transformer_in_feature
+
+ self.num_classes = num_classes
+
+ if binary_classes:
+ self.num_classes = 1
+
+ @classmethod
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):
+
+ in_features_type = cfg['MODEL']['DECODER']['TRANSFORMER_IN_FEATURE']
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ # figure out in_channels to transformer predictor
+ if in_features_type == "transformer_encoder":
+ transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
+ elif in_features_type == "pixel_embedding":
+ transformer_predictor_in_channels = enc_cfg['MASK_DIM']
+ elif in_features_type == "multi_scale_pixel_decoder":
+ transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
+ else:
+ transformer_predictor_in_channels = input_shape[dec_cfg['TRANSFORMER_IN_FEATURE']].channels
+
+ return {
+ "input_shape": {
+ k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
+ },
+ "ignore_value": enc_cfg['IGNORE_VALUE'],
+ "num_classes": enc_cfg.get('NUM_CLASSES', None),
+ "pixel_decoder": build_encoder(cfg, input_shape),
+ "loss_weight": enc_cfg['LOSS_WEIGHT'],
+ "transformer_in_feature": dec_cfg['TRANSFORMER_IN_FEATURE'],
+ "transformer_predictor": build_decoder(
+ cfg,
+ transformer_predictor_in_channels,
+ lang_encoder,
+ mask_classification=True,
+ extra=extra,
+ ),
+ "binary_classes": enc_cfg['BINARY_CLASSES']
+ }
+
+ def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
+ return self.layers(features, mask, target_queries, target_vlp, task, extra)
+
+ def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
+ mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)
+
+ if self.transformer_in_feature == "multi_scale_pixel_decoder":
+ predictions = self.predictor(multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra)
+ else:
+ if self.transformer_in_feature == "transformer_encoder":
+ assert (
+ transformer_encoder_features is not None
+ ), "Please use the TransformerEncoderPixelDecoder."
+ predictions = self.predictor(transformer_encoder_features, mask_features, mask)
+ elif self.transformer_in_feature == "pixel_embedding":
+ predictions = self.predictor(mask_features, mask_features, mask)
+ else:
+ predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
+ return predictions
+
+
+@register_body
+def get_xdecoder_head(cfg, input_shape, lang_encoder, extra):
+ return XdecoderHead(cfg, input_shape, lang_encoder, extra)
\ No newline at end of file
diff --git a/modeling/interface/__init__.py b/modeling/interface/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3b995e69c354638a3c118a1f4fbcc0ca1818bfc
--- /dev/null
+++ b/modeling/interface/__init__.py
@@ -0,0 +1,13 @@
+from .xdecoder import *
+from .seem_v0 import *
+from .seem_v1 import *
+from .seem_demo import *
+from .build import *
+
+def build_decoder(config, *args, **kwargs):
+ model_name = config['MODEL']['DECODER']['NAME']
+
+ if not is_model(model_name):
+ raise ValueError(f'Unkown model: {model_name}')
+
+ return model_entrypoints(model_name)(config, *args, **kwargs)
\ No newline at end of file
diff --git a/modeling/interface/build.py b/modeling/interface/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..1565142701e6f9f4063bbf3475703a7b3e620946
--- /dev/null
+++ b/modeling/interface/build.py
@@ -0,0 +1,14 @@
+_model_entrypoints = {}
+
+
+def register_decoder(fn):
+ module_name_split = fn.__module__.split('.')
+ model_name = module_name_split[-1]
+ _model_entrypoints[model_name] = fn
+ return fn
+
+def model_entrypoints(model_name):
+ return _model_entrypoints[model_name]
+
+def is_model(model_name):
+ return model_name in _model_entrypoints
\ No newline at end of file
diff --git a/modeling/interface/modules.py b/modeling/interface/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbfe449ec78319a2dbf4e691a220ec8967e53763
--- /dev/null
+++ b/modeling/interface/modules.py
@@ -0,0 +1,200 @@
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from timm.models.layers import trunc_normal_
+from detectron2.layers import Conv2d
+import fvcore.nn.weight_init as weight_init
+
+from ..utils import MultiheadAttention
+
+
+class SelfAttentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.0,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
+
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt
+
+ def forward_pre(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ tgt2 = self.norm(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+
+ return tgt
+
+ def forward(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, tgt_mask,
+ tgt_key_padding_mask, query_pos)
+ return self.forward_post(tgt, tgt_mask,
+ tgt_key_padding_mask, query_pos)
+
+
+class CrossAttentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.0,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt, memory,
+ memory_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt, avg_attn
+
+ def forward_pre(self, tgt, memory,
+ memory_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ tgt2 = self.norm(tgt)
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)
+ tgt = tgt + self.dropout(tgt2)
+
+ return tgt, avg_attn
+
+ def forward(self, tgt, memory,
+ memory_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, memory, memory_mask,
+ memory_key_padding_mask, pos, query_pos)
+ return self.forward_post(tgt, memory, memory_mask,
+ memory_key_padding_mask, pos, query_pos)
+
+
+class FFNLayer(nn.Module):
+
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm = nn.LayerNorm(d_model)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt):
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt
+
+ def forward_pre(self, tgt):
+ tgt2 = self.norm(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward(self, tgt):
+ if self.normalize_before:
+ return self.forward_pre(tgt)
+ return self.forward_post(tgt)
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
diff --git a/modeling/interface/prototype/__init__.py b/modeling/interface/prototype/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modeling/interface/prototype/attention_data_struct_seemdemo.py b/modeling/interface/prototype/attention_data_struct_seemdemo.py
new file mode 100644
index 0000000000000000000000000000000000000000..f885cfd1f8615fb631bf2913ead49feeff3ca9b6
--- /dev/null
+++ b/modeling/interface/prototype/attention_data_struct_seemdemo.py
@@ -0,0 +1,265 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+predict_name_matcher = {"predictions_class": ["pred_logits"],
+ "predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
+ "predictions_caption":["pred_captions", "pred_gtexts"],
+ "predictions_maskemb":["pred_maskembs", "pred_smaskembs"],
+ "predictions_pos_spatial":["pred_pspatials"],
+ "predictions_neg_spatial":["pred_nspatials"],
+ "predictions_pos_visual":["pred_pvisuals"],
+ "predictions_neg_visual":["pred_nvisuals"]}
+
+predict_index_matcher = {"predictions_class": ["queries_object"],
+ "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
+ "predictions_caption": ["queries_object", "queries_grounding"],
+ "predictions_maskemb":["queries_object", "queries_spatial"],
+ "predictions_pos_spatial":["all"],
+ "predictions_neg_spatial":["all"],
+ "predictions_pos_visual":["all"],
+ "predictions_neg_visual":["all"]}
+
+class Variable(object):
+ '''
+ Store dataset variable for attention
+ output: embedding that accumuates during cross/self attention
+ pos: positional embedding that is fixed during cross/self attention
+ name: name of the variable
+ type: type of the variable, e.g. queries, tokens
+ attn_mask: attention mask for corss attention
+ masking: masking for padding
+ '''
+ def __init__(self, output, name, _type, pos=None):
+ self.output = output
+ self.pos = pos
+ self.name = name
+ self.type = _type
+ self.attn_mask = None
+ self.masking = None
+
+ def copy(self,):
+ output = self.output.clone() if self.output is not None else None
+ pos = self.pos.clone() if self.pos is not None else None
+ return Variable(output, self.name, self.type, pos)
+
+class AttentionDataStruct(nn.Module):
+ '''
+ Store dataset structure for cross/self attention
+ task_switch: switch for different tasks
+
+ p_attn_variables: prototype of variables that is used in cross/self attention
+ p_self_attn: prototype of variables that is used in self attention
+ p_cross_attn: prototype of variables that is used in cross attention
+ p_iter: prototype of iteration for different queries
+ p_masking: prototype of masking for different tokens
+ p_duplication: prototype of duplication for different quries
+ '''
+ def __init__(self, attn_arch, task_switch):
+ super(AttentionDataStruct, self).__init__()
+ self.task_switch = task_switch
+
+ # p stands for prototype
+ self.p_attn_variables = attn_arch['VARIABLE']
+ self.p_self_attn = attn_arch['SELF_ATTENTION']
+ self.p_cross_attn = attn_arch['CROSS_ATTENTION']
+ self.p_masking = attn_arch['MASKING']
+ self.p_duplication = attn_arch['DUPLICATION']
+
+ self.num_layers = attn_arch['NUM_LAYERS']
+
+ def reset(self, flags, task, extra):
+ # reset variables
+ self.attn_variables = {}
+ self.cross_attn_dict = {}
+ self.self_attn_dict = {}
+ self.duplication_dict = {}
+ self.query_index = {}
+ self.output = {}
+ self.flags = {}
+ self.spatial_memory = {}
+
+ # initialize duplication
+ for key, values in self.p_duplication.items():
+ for name in values:
+ self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
+
+ # initialize flag
+ self.flags = {"object": True}
+ self.flags.update(flags)
+
+ # initialize task
+ self.task = task
+
+ # initialize output
+ if self.task_switch['mask']:
+ self.output['predictions_class'] = []
+ self.output['predictions_mask'] = []
+ self.output['predictions_maskemb'] = []
+
+ if self.task_switch['bbox']:
+ self.output['predictions_bbox'] = []
+
+ if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
+ self.output['predictions_pos_spatial'] = []
+ self.output['predictions_neg_spatial'] = []
+
+ if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
+ self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
+
+ if (self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True)) \
+ or (self.task_switch['audio'] and ('audio' in self.flags and self.flags['audio']==True)):
+ self.output['predictions_caption'] = []
+
+ if self.task_switch['visual']:
+ self.output['predictions_pos_visual'] = []
+ self.output['predictions_neg_visual'] = []
+
+ # initialize cross_attn, whether the variable is used in cross attention
+ for key, values in self.p_cross_attn.items():
+ for name in values:
+ self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
+
+ # initialize self_attn, whether the variable is used in self attention, and the interactions between queries
+ for key, values in self.p_self_attn.items():
+ for name in values:
+ self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
+
+ # initialize masking
+ self.masking = self.p_masking
+
+ # initialize query_index
+ self.query_index = {"all":[0, None]}
+
+
+ def set(self, name, _type, output=None, pos=None, var=None):
+ if var is not None:
+ self.attn_variables[name] = var
+ elif name in self.duplication_dict:
+ assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
+ self.attn_variables[name] = self.attn_variables[self.duplication_dict[name]].copy()
+ else:
+ var = Variable(output, name, _type, pos)
+ self.attn_variables[name] = var
+
+ def set_results(self, results):
+ for name in self.cross_attn_name:
+ self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
+ for key in self.output:
+ self.output[key].append(results[key])
+
+ def set_maskings(self, name, masking):
+ self.attn_variables[name].masking = masking
+
+ def cross_attn_variables(self, ):
+ cross_attn_name = [key for key, value in self.cross_attn_dict.items()
+ if (value==True) and (key in self.attn_variables)
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
+ self.cross_attn_name = cross_attn_name
+
+ output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
+
+ index = 0
+ for name in cross_attn_name:
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
+ index += self.attn_variables[name].output.shape[0]
+ return output, pos_emb
+
+ def cross_attn_mask(self, size, num_heads):
+ attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
+
+ # hard code memories_spatial to previous selected mask
+ if 'memories_spatial' in self.cross_attn_name:
+ memory_attn_mask = self.spatial_memory['prev_batch_mask']
+ bs,c,_,_ = memory_attn_mask.shape
+ memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
+ memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
+ attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask
+
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+ return attn_mask
+
+ def self_attn(self, bs, num_heads):
+ self_attn_name = [key for key, value in self.self_attn_dict.items()
+ if len(value)>0 and key in self.attn_variables
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
+ self.self_attn_name = self_attn_name
+
+ output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
+
+ index = 0
+ for name in self_attn_name:
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
+ index += self.attn_variables[name].output.shape[0]
+
+ self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
+ self_attn_pair = []
+ # build self_attention mask by query interaction
+ for key1, value in self.self_attn_dict.items():
+ for key2 in value:
+ if key1 not in self_attn_name or key2 not in self_attn_name:
+ # exclude the variables that are not used in the current layer
+ continue
+ if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
+ self_attn_pair += [[key1, key2]]
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
+
+ # build self_attention mask by masking, for birectional
+ for key in self.masking:
+ if key in self_attn_name:
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
+
+ # build self_attention mask by masking, for uni-directional
+ for key1, key2 in self_attn_pair:
+ if key1 not in self_attn_name or key2 not in self_attn_name:
+ # exclude the variables that are not used in the current layer
+ continue
+ if key1 in self.masking:
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
+ if key2 in self.masking:
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
+
+ self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
+ return output, pos_emb, self_attn_mask
+
+ def update_variables(self, output, mode):
+ name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
+ for key in name_set:
+ self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
+
+ def update_spatial_results(self, results):
+ v_emb = results['pred_smaskembs']
+ pred_smasks = results['pred_smasks']
+
+ s_emb = results['pred_pspatials']
+ pred_logits = v_emb @ s_emb.transpose(1,2)
+ logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
+ logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
+ logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
+ pred_masks_pos = pred_smasks[logits_idx][:,None,]
+
+ extra = {"prev_mask": pred_masks_pos}
+ return extra
+
+ def organize_output(self, ):
+ outputs = {}
+ outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
+
+ for key, values in self.output.items():
+ for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
+ if idx_name not in self.query_index:
+ continue
+ outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
+ for idx, aux_values in enumerate(self.output[key][:-1]):
+ outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
+ return outputs
\ No newline at end of file
diff --git a/modeling/interface/prototype/attention_data_struct_seemv0.py b/modeling/interface/prototype/attention_data_struct_seemv0.py
new file mode 100644
index 0000000000000000000000000000000000000000..f568e9c5ea01a6f7a8036496f33ecde8f82ed18b
--- /dev/null
+++ b/modeling/interface/prototype/attention_data_struct_seemv0.py
@@ -0,0 +1,264 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+predict_name_matcher = {"predictions_class": ["pred_logits"],
+ "predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
+ "predictions_caption":["pred_captions", "pred_gtexts"],
+ "predictions_maskemb":["pred_smaskembs"],
+ "predictions_pos_spatial":["pred_pspatials"],
+ "predictions_neg_spatial":["pred_nspatials"],}
+
+predict_index_matcher = {"predictions_class": ["queries_object"],
+ "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
+ "predictions_caption": ["queries_object", "queries_grounding"],
+ "predictions_maskemb":["queries_spatial"],
+ "predictions_pos_spatial":["all"],
+ "predictions_neg_spatial":["all"],}
+
+class Variable(object):
+ '''
+ Store dataset variable for attention
+ output: embedding that accumuates during cross/self attention
+ pos: positional embedding that is fixed during cross/self attention
+ name: name of the variable
+ type: type of the variable, e.g. queries, tokens
+ attn_mask: attention mask for corss attention
+ masking: masking for padding
+ '''
+ def __init__(self, output, name, _type, pos=None):
+ self.output = output
+ self.pos = pos
+ self.name = name
+ self.type = _type
+ self.attn_mask = None
+ self.masking = None
+
+ def copy(self,):
+ output = self.output.clone() if self.output is not None else None
+ pos = self.pos.clone() if self.pos is not None else None
+ return Variable(output, self.name, self.type, pos)
+
+class AttentionDataStruct(nn.Module):
+ '''
+ Store dataset structure for cross/self attention
+ task_switch: switch for different tasks
+
+ p_attn_variables: prototype of variables that is used in cross/self attention
+ p_self_attn: prototype of variables that is used in self attention
+ p_cross_attn: prototype of variables that is used in cross attention
+ p_iter: prototype of iteration for different queries
+ p_masking: prototype of masking for different tokens
+ p_duplication: prototype of duplication for different quries
+ '''
+ def __init__(self, attn_arch, task_switch):
+ super(AttentionDataStruct, self).__init__()
+ self.task_switch = task_switch
+
+ # p stands for prototype
+ self.p_attn_variables = attn_arch['VARIABLE']
+ self.p_self_attn = attn_arch['SELF_ATTENTION']
+ self.p_cross_attn = attn_arch['CROSS_ATTENTION']
+ self.p_masking = attn_arch['MASKING']
+ self.p_duplication = attn_arch['DUPLICATION']
+
+ self.num_layers = attn_arch['NUM_LAYERS']
+
+ def reset(self, flags, task, extra):
+ # reset variables
+ self.attn_variables = {}
+ self.cross_attn_dict = {}
+ self.self_attn_dict = {}
+ self.duplication_dict = {}
+ self.query_index = {}
+ self.output = {}
+ self.flags = {}
+ self.spatial_memory = {}
+
+ # initialize duplication
+ for key, values in self.p_duplication.items():
+ for name in values:
+ self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
+
+ # initialize flag
+ self.flags = {"object": True}
+ self.flags.update(flags)
+
+ # initialize task
+ self.task = task
+
+ # initialize output
+ if self.task_switch['mask']:
+ self.output['predictions_class'] = []
+ self.output['predictions_mask'] = []
+
+ if self.task_switch['bbox']:
+ self.output['predictions_bbox'] = []
+
+ if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
+ self.output['predictions_maskemb'] = []
+ self.output['predictions_pos_spatial'] = []
+ self.output['predictions_neg_spatial'] = []
+ # self.spatial_memory['spatial_query_mode'] = extra['spatial_query_mode']
+
+ if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
+ self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
+
+ if self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True):
+ self.output['predictions_caption'] = []
+
+ # initialize cross_attn, whether the variable is used in cross attention
+ for key, values in self.p_cross_attn.items():
+ for name in values:
+ self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
+
+ # initialize self_attn, whether the variable is used in self attention, and the interactions between queries
+ for key, values in self.p_self_attn.items():
+ for name in values:
+ self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
+
+ # initialize masking
+ self.masking = self.p_masking
+
+ # initialize query_index
+ self.query_index = {"all":[0, None]}
+
+
+ def set(self, name, _type, output=None, pos=None, var=None):
+ if var is not None:
+ self.attn_variables[name] = var
+ elif name in self.duplication_dict:
+ assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
+ self.attn_variables[name] = self.attn_variables[self.duplication_dict[name]].copy()
+ else:
+ var = Variable(output, name, _type, pos)
+ self.attn_variables[name] = var
+
+ def set_results(self, results):
+ for name in self.cross_attn_name:
+ self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
+ for key in self.output:
+ self.output[key].append(results[key])
+
+ def set_maskings(self, name, masking):
+ self.attn_variables[name].masking = masking
+
+ def cross_attn_variables(self, ):
+ cross_attn_name = [key for key, value in self.cross_attn_dict.items()
+ if (value==True) and (key in self.attn_variables)
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
+ self.cross_attn_name = cross_attn_name
+
+ output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
+
+ index = 0
+ for name in cross_attn_name:
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
+ index += self.attn_variables[name].output.shape[0]
+ return output, pos_emb
+
+ def cross_attn_mask(self, size, num_heads):
+ attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
+
+ # hard code memories_spatial to previous selected mask
+ if 'memories_spatial' in self.cross_attn_name:
+ memory_attn_mask = self.spatial_memory['prev_batch_mask']
+ bs,c,_,_ = memory_attn_mask.shape
+ memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
+ memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
+ attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask
+
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+ return attn_mask
+
+ def self_attn(self, bs, num_heads):
+ self_attn_name = [key for key, value in self.self_attn_dict.items()
+ if len(value)>0 and key in self.attn_variables
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
+ self.self_attn_name = self_attn_name
+
+ output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
+
+ index = 0
+ for name in self_attn_name:
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
+ index += self.attn_variables[name].output.shape[0]
+
+ self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
+ self_attn_pair = []
+ # build self_attention mask by query interaction
+ for key1, value in self.self_attn_dict.items():
+ for key2 in value:
+ if key1 not in self_attn_name or key2 not in self_attn_name:
+ # exclude the variables that are not used in the current layer
+ continue
+ if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
+ self_attn_pair += [[key1, key2]]
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
+
+ # build self_attention mask by masking, for birectional
+ for key in self.masking:
+ if key in self_attn_name:
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
+
+ # build self_attention mask by masking, for uni-directional
+ for key1, key2 in self_attn_pair:
+ if key1 not in self_attn_name or key2 not in self_attn_name:
+ # exclude the variables that are not used in the current layer
+ continue
+ if key1 in self.masking:
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
+ if key2 in self.masking:
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
+
+ self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
+ return output, pos_emb, self_attn_mask
+
+ def update_variables(self, output, mode):
+ name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
+ for key in name_set:
+ self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
+
+ def update_spatial_results(self, results):
+ v_emb = results['pred_smaskembs']
+ pred_smasks = results['pred_smasks']
+
+ s_emb = results['pred_pspatials']
+ pred_logits = v_emb @ s_emb.transpose(1,2)
+ logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
+ logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
+ logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
+ pred_masks_pos = pred_smasks[logits_idx][:,None,]
+
+ # s_emb = results['pred_nspatials']
+ # pred_logits = v_emb @ s_emb.transpose(1,2)
+ # logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
+ # logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
+ # logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
+ # pred_masks_neg = pred_smasks[logits_idx][:,None,]
+ # # clip the negative mask to 0, and then multiply by -1
+ # pred_masks_neg = (pred_masks_neg.clip(0) * -1)
+ # keep_neg = (s_emb.sum(dim=list(range(1, s_emb.dim()))) != 0).float()
+ # pred_masks_neg = pred_masks_neg * keep_neg[:,None,None,None]
+ # extra = {"prev_mask": pred_masks_pos + pred_masks_neg}
+
+ extra = {"prev_mask": pred_masks_pos}
+ return extra
+
+ def organize_output(self, ):
+ outputs = {}
+ outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
+ for key, values in self.output.items():
+ for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
+ if idx_name not in self.query_index:
+ continue
+ outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
+ for idx, aux_values in enumerate(self.output[key][:-1]):
+ outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
+ if self.task == 'spatial' or self.task == 'refimg':
+ outputs = self.update_spatial_results(outputs)
+ # outputs = self.update_spatial_results(outputs)
+ return outputs
\ No newline at end of file
diff --git a/modeling/interface/prototype/attention_data_struct_seemv1.py b/modeling/interface/prototype/attention_data_struct_seemv1.py
new file mode 100644
index 0000000000000000000000000000000000000000..29761c1f16ac32e92ae7cbe64c10549e84d6481f
--- /dev/null
+++ b/modeling/interface/prototype/attention_data_struct_seemv1.py
@@ -0,0 +1,302 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+predict_name_matcher = {"predictions_class": ["pred_logits"],
+ "predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
+ "predictions_caption":["pred_captions", "pred_gtexts", "pred_stexts"],
+ "predictions_maskemb":["pred_smaskembs"],
+ "predictions_pos_spatial":["pred_pspatials"],
+ "predictions_neg_spatial":["pred_nspatials"],}
+
+predict_index_matcher = {"predictions_class": ["queries_object"],
+ "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
+ "predictions_caption": ["queries_object", "queries_grounding", "queries_spatial"],
+ "predictions_maskemb":["queries_spatial"],
+ "predictions_pos_spatial":["all"],
+ "predictions_neg_spatial":["all"],}
+
+class Variable(object):
+ '''
+ Store dataset variable for attention
+ output: embedding that accumuates during cross/self attention
+ pos: positional embedding that is fixed during cross/self attention
+ name: name of the variable
+ type: type of the variable, e.g. queries, tokens
+ attn_mask: attention mask for corss attention
+ masking: masking for padding
+ '''
+ def __init__(self, output, name, _type, pos=None):
+ self.output = output
+ self.pos = pos
+ self.name = name
+ self.type = _type
+ self.attn_mask = None
+ self.masking = None
+
+ def copy(self,):
+ output = self.output.clone() if self.output is not None else None
+ pos = self.pos.clone() if self.pos is not None else None
+ return Variable(output, self.name, self.type, pos)
+
+ def rand_sample(self, max_len):
+ rand_idx = torch.randint(0, len(self.pos), (max_len,))
+ self.output = self.output[rand_idx]
+ self.pos = self.pos[rand_idx]
+ return self
+
+class AttentionDataStruct(nn.Module):
+ '''
+ Store dataset structure for cross/self attention
+ task_switch: switch for different tasks
+
+ p_attn_variables: prototype of variables that is used in cross/self attention
+ p_self_attn: prototype of variables that is used in self attention
+ p_cross_attn: prototype of variables that is used in cross attention
+ p_iter: prototype of iteration for different queries
+ p_masking: prototype of masking for different tokens
+ p_duplication: prototype of duplication for different quries
+ '''
+ def __init__(self, attn_arch, task_switch):
+ super(AttentionDataStruct, self).__init__()
+ self.task_switch = task_switch
+
+ # p stands for prototype
+ self.p_attn_variables = attn_arch['VARIABLE']
+ self.p_self_attn = attn_arch['SELF_ATTENTION']
+ self.p_cross_attn = attn_arch['CROSS_ATTENTION']
+ self.p_masking = attn_arch['MASKING']
+ self.p_duplication = attn_arch['DUPLICATION']
+
+ self.num_layers = attn_arch['NUM_LAYERS']
+
+ def reset(self, flags, task, extra):
+ # reset variables
+ self.attn_variables = {}
+ self.cross_attn_dict = {}
+ self.self_attn_dict = {}
+ self.duplication_dict = {}
+ self.query_index = {}
+ self.output = {}
+ self.flags = {}
+ self.spatial_memory = {}
+ self.extra = {}
+
+ # initialize duplication
+ for key, values in self.p_duplication.items():
+ for name in values:
+ self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
+
+ # initialize flag
+ self.flags = {"object": True}
+ self.flags.update(flags)
+
+ # initialize task
+ self.task = task
+
+ # initialize output
+ if self.task_switch['mask']:
+ self.output['predictions_class'] = []
+ self.output['predictions_mask'] = []
+
+ if self.task_switch['bbox']:
+ self.output['predictions_bbox'] = []
+
+ if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
+ self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
+
+ if self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True):
+ self.output['predictions_caption'] = []
+
+ if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
+ self.output['predictions_maskemb'] = []
+ self.output['predictions_pos_spatial'] = []
+ self.output['predictions_neg_spatial'] = []
+ self.output['predictions_mask'] = [] if 'predictions_mask' not in self.output else self.output['predictions_mask']
+ self.output['predictions_class'] = [] if 'predictions_class' not in self.output else self.output['predictions_class']
+ self.output['predictions_caption'] = [] if 'predictions_caption' not in self.output else self.output['predictions_caption']
+
+ # initialize cross_attn, whether the variable is used in cross attention
+ for key, values in self.p_cross_attn.items():
+ for name in values:
+ self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
+
+ # initialize self_attn, whether the variable is used in self attention, and the interactions between queries
+ for key, values in self.p_self_attn.items():
+ for name in values:
+ self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
+
+ # initialize masking
+ self.masking = self.p_masking
+
+ # initialize query_index
+ self.query_index = {"all":[0, None]}
+
+
+ def set(self, name, _type, output=None, pos=None, var=None, sample_size=None):
+ if var is not None:
+ self.attn_variables[name] = var
+ elif name in self.duplication_dict:
+ assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
+ var = self.attn_variables[self.duplication_dict[name]].copy()
+ if sample_size is not None:
+ var = var.rand_sample(sample_size)
+ self.attn_variables[name] = var
+ else:
+ var = Variable(output, name, _type, pos)
+ self.attn_variables[name] = var
+
+ def set_results(self, results):
+ for name in self.cross_attn_name:
+ self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
+ for key in self.output:
+ self.output[key].append(results[key])
+
+ def set_maskings(self, name, masking):
+ self.attn_variables[name].masking = masking
+
+ def set_extra(self, extra):
+ self.extra.update(extra)
+
+ def cross_attn_variables(self, ):
+ cross_attn_name = [key for key, value in self.cross_attn_dict.items()
+ if (value==True) and (key in self.attn_variables)
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
+ self.cross_attn_name = cross_attn_name
+
+ output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
+
+ index = 0
+ for name in cross_attn_name:
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
+ index += self.attn_variables[name].output.shape[0]
+ return output, pos_emb
+
+ def cross_attn_mask(self, size, num_heads):
+ attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
+
+ # hard code memories_spatial to previous selected mask
+ if 'memories_spatial' in self.cross_attn_name:
+ memory_attn_mask = self.spatial_memory['prev_batch_mask']
+ bs,c,_,_ = memory_attn_mask.shape
+ memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
+ memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
+ repeat = (self.query_index['memories_spatial'][1] - self.query_index['memories_spatial'][0]) // c
+ mem_len = self.query_index['memories_spatial'][1] - self.query_index['memories_spatial'][0]
+ probs = torch.tensor([1./repeat for i in range(c)])
+ indices = torch.multinomial(probs, num_samples=mem_len, replacement=True).sort()[0]
+ attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask[:,indices]
+ self.extra['memory_indices'] = indices
+
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+ return attn_mask
+
+ def self_attn(self, bs, num_heads):
+ self_attn_name = [key for key, value in self.self_attn_dict.items()
+ if len(value)>0 and key in self.attn_variables
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
+ self.self_attn_name = self_attn_name
+
+ output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
+
+ index = 0
+ for name in self_attn_name:
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
+ index += self.attn_variables[name].output.shape[0]
+
+ self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
+ self_attn_pair = []
+ # build self_attention mask by query interaction
+ for key1, value in self.self_attn_dict.items():
+ for key2 in value:
+ if key1 not in self_attn_name or key2 not in self_attn_name:
+ # exclude the variables that are not used in the current layer
+ continue
+ if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
+ self_attn_pair += [[key1, key2]]
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
+
+ # build self_attention mask by masking, for birectional
+ for key in self.masking:
+ if key in self_attn_name:
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
+
+ # build self_attention mask by masking, for uni-directional
+ for key1, key2 in self_attn_pair:
+ if key1 not in self_attn_name or key2 not in self_attn_name:
+ # exclude the variables that are not used in the current layer
+ continue
+ if key1 in self.masking:
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
+ if key2 in self.masking:
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
+
+ # build self_attention mask masking for spatial query
+ # spatial query attend with itself
+ if 'queries_spatial' in self_attn_name and 'tokens_spatial' in self_attn_name:
+ diag_mask = ~(torch.eye(self.extra['spatial_query_number']).repeat_interleave(self.extra['sample_size'],dim=0).repeat_interleave(self.extra['sample_size'],dim=1)).bool()
+ self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1]] = diag_mask[None,]
+ # spatial query attend with spatial token
+ indices = self.extra['spatial_indices'].permute(0,2,1)
+ diag_index = torch.arange(self.extra['spatial_query_number'], device=indices.device).repeat_interleave(self.extra['sample_size'],dim=0)[None,:,None]
+ diag_mask = ~(indices == diag_index)
+ self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1]] = diag_mask
+ # spatial token attend with itself
+ diag_mask = ~(indices == indices.transpose(1,2))
+ self_attn_mask[:,self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1],self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1]] = diag_mask
+
+ if 'memory_indices' in self.extra:
+ # spatial query attend with memory
+ memory_indices = self.extra['memory_indices'][None,None,:]
+ diag_index = torch.arange(self.extra['spatial_query_number'], device=memory_indices.device).repeat_interleave(self.extra['sample_size'],dim=0)[None,:,None]
+ diag_mask = ~(diag_index == memory_indices)
+ self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = diag_mask
+ # memory attend with itself
+ diag_mask = ~(memory_indices == memory_indices.transpose(1,2))
+ self_attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1],self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = diag_mask
+
+ self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
+ return output, pos_emb, self_attn_mask
+
+ def update_variables(self, output, mode):
+ name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
+ for key in name_set:
+ self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
+
+ def update_spatial_results(self, results):
+ v_emb = results['pred_smaskembs']
+ pred_smasks = results['pred_smasks']
+
+ s_emb = results['pred_pspatials']
+ diag_mask = ~(torch.eye(self.extra['spatial_query_number'], device=s_emb.device).repeat_interleave(self.extra['sample_size'],dim=0)).bool()
+ offset = torch.zeros_like(diag_mask, device=s_emb.device).float()
+ offset.masked_fill_(diag_mask, float("-inf"))
+
+ pred_logits = v_emb @ s_emb.transpose(1,2) + offset[None,]
+ bs,_,ns=pred_logits.shape
+ _,_,h,w=pred_smasks.shape
+
+ logits_idx_y = pred_logits.max(dim=1)[1]
+ logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)[:,None].repeat(1, logits_idx_y.shape[1])
+ logits_idx = torch.stack([logits_idx_x, logits_idx_y]).view(2,-1).tolist()
+ pred_masks_pos = pred_smasks[logits_idx].reshape(bs,ns,h,w)
+ extra = {"prev_mask": pred_masks_pos}
+ return extra
+
+ def organize_output(self, ):
+ outputs = {}
+ outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
+ for key, values in self.output.items():
+ for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
+ if idx_name not in self.query_index:
+ continue
+ outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
+ for idx, aux_values in enumerate(self.output[key][:-1]):
+ outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
+ if self.task == 'spatial' or self.task == 'refimg':
+ outputs = self.update_spatial_results(outputs)
+ # outputs = self.update_spatial_results(outputs)
+ return outputs
\ No newline at end of file
diff --git a/modeling/interface/seem_demo.py b/modeling/interface/seem_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde92033d9e7a96cc498f5196c346c324fa22ee9
--- /dev/null
+++ b/modeling/interface/seem_demo.py
@@ -0,0 +1,397 @@
+# --------------------------------------------------------
+# SEEM -- Segment Everything Everywhere All At Once
+# Licensed under The Apache License 2.0 [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu), Jianwei Yang (jianwyan@microsoft.com)
+# --------------------------------------------------------
+
+import logging
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from timm.models.layers import trunc_normal_
+from detectron2.layers import Conv2d
+import fvcore.nn.weight_init as weight_init
+
+from .build import register_decoder
+from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
+from .prototype.attention_data_struct_seemdemo import AttentionDataStruct
+from ..utils import rand_sample_plain as rand_sample
+from ..utils import prepare_features, configurable
+from ..modules import PositionEmbeddingSine
+from ..modules.point_features import point_sample
+
+
+class SEEMDecoder(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ lang_encoder: nn.Module,
+ in_channels,
+ mask_classification=True,
+ *,
+ hidden_dim: int,
+ dim_proj: int,
+ num_queries: int,
+ contxt_len: int,
+ nheads: int,
+ dim_feedforward: int,
+ dec_layers: int,
+ pre_norm: bool,
+ mask_dim: int,
+ task_switch: dict,
+ enforce_input_project: bool,
+ max_spatial_len: int,
+ attn_arch: dict,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ in_channels: channels of the input features
+ mask_classification: whether to add mask classifier or not
+ num_classes: number of classes
+ hidden_dim: Transformer feature dimension
+ num_queries: number of queries
+ nheads: number of heads
+ dim_feedforward: feature dimension in feedforward network
+ enc_layers: number of Transformer encoder layers
+ dec_layers: number of Transformer decoder layers
+ pre_norm: whether to use pre-LayerNorm or not
+ mask_dim: mask feature dimension
+ enforce_input_project: add input project 1x1 conv even if input
+ channels and hidden dim is identical
+ """
+ super().__init__()
+ assert mask_classification, "Only support mask classification model"
+ self.mask_classification = mask_classification
+
+ # positional encoding
+ N_steps = hidden_dim // 2
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+ # define Transformer decoder here
+ self.num_heads = nheads
+ self.num_layers = dec_layers
+ self.contxt_len = contxt_len
+ self.transformer_self_attention_layers = nn.ModuleList()
+ self.transformer_cross_attention_layers = nn.ModuleList()
+ self.transformer_ffn_layers = nn.ModuleList()
+
+ for _ in range(self.num_layers):
+ self.transformer_self_attention_layers.append(
+ SelfAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_cross_attention_layers.append(
+ CrossAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_ffn_layers.append(
+ FFNLayer(
+ d_model=hidden_dim,
+ dim_feedforward=dim_feedforward,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
+
+ self.num_queries = num_queries
+ # learnable query features
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
+ # learnable query p.e.
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+ # learnable positive negative indicator
+ self.pn_indicator = nn.Embedding(2, hidden_dim)
+
+ # level embedding (we always use 3 scales)
+ self.num_feature_levels = 3
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
+ self.input_proj = nn.ModuleList()
+
+ for _ in range(self.num_feature_levels):
+ if in_channels != hidden_dim or enforce_input_project:
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
+ weight_init.c2_xavier_fill(self.input_proj[-1])
+ else:
+ self.input_proj.append(nn.Sequential())
+
+ self.task_switch = task_switch
+ self.query_index = {}
+
+ # output FFNs
+ self.lang_encoder = lang_encoder
+ if self.task_switch['mask']:
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
+ trunc_normal_(self.class_embed, std=.02)
+
+ if task_switch['bbox']:
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+
+ if task_switch['spatial']:
+ # spatial query
+ self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
+ trunc_normal_(self.mask_sptial_embed[0], std=.02)
+ trunc_normal_(self.mask_sptial_embed[1], std=.02)
+ trunc_normal_(self.mask_sptial_embed[2], std=.02)
+
+ self.max_spatial_len = max_spatial_len
+ # spatial memory
+ num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
+ self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
+ self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)
+
+ # build AttentionDataStruct
+ attn_arch['NUM_LAYERS'] = self.num_layers
+ self.attention_data = AttentionDataStruct(attn_arch, task_switch)
+
+ @classmethod
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
+ ret = {}
+
+ ret["lang_encoder"] = lang_encoder
+ ret["in_channels"] = in_channels
+ ret["mask_classification"] = mask_classification
+
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
+
+ # Transformer parameters:
+ ret["nheads"] = dec_cfg['NHEADS']
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
+
+ # NOTE: because we add learnable query features which requires supervision,
+ # we add minus 1 to decoder layers to be consistent with our loss
+ # implementation: that is, number of auxiliary losses is always
+ # equal to number of decoder layers. With learnable query features, the number of
+ # auxiliary losses equals number of decoders plus 1.
+ assert dec_cfg['DEC_LAYERS'] >= 1
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
+ ret["task_switch"] = extra['task_switch']
+ ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']
+
+ # attn data struct
+ ret["attn_arch"] = cfg['ATTENTION_ARCH']
+
+ return ret
+
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
+ # x is a list of multi-scale feature
+ assert len(x) == self.num_feature_levels; del mask
+ spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg'
+ grounding_extra_flag = 'grounding_tokens' in extra.keys()
+ visual_extra_flag = 'visual_query_pos' in extra.keys()
+ audio_extra_flag = 'audio_tokens' in extra.keys()
+ spatial_memory_flag = 'prev_mask' in extra.keys()
+ flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag, "visual": visual_extra_flag, "audio": audio_extra_flag}
+ self.attention_data.reset(flags, task, extra)
+
+ src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
+ _, bs, _ = src[0].shape
+
+ # QxNxC
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
+ self.attention_data.set('queries_object', 'queries', output, query_embed)
+
+ if self.task_switch['spatial'] and spatial_extra_flag:
+ # get divisor
+ _,h,w = extra['spatial_query_pos_mask'][0].shape
+ divisor = torch.tensor([h,w], device=output.device)[None,]
+
+ # Get mean pos spatial query
+ non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
+ non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
+ non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
+ spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
+ spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num()
+
+ # Get mean neg spatial query
+ non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
+ non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
+ non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
+ spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
+ spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()
+
+ # merge positive and negative sample points for self attention
+
+ # Get layerwise spatial query
+ src_spatial_queries = []
+ src_spatial_maskings = []
+ for i in range(len(src)):
+ hw,_,dc = src[i].shape
+ src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
+ src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
+
+ non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
+ non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
+ non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
+
+ pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
+ pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
+
+ non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
+ non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
+ non_zero_query_point[non_zero_query_mask] = 0
+
+ spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
+ spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
+ spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
+
+ src_spatial_queries += [spatial_tokens]
+ src_spatial_maskings += [non_zero_query_mask]
+
+ if 'refimg' in task:
+ output_refimg = {}
+ output_refimg['visual_query_pos'] = spatial_query_pos
+ output_refimg['visual_query_neg'] = spatial_query_neg
+ output_refimg['src_visual_queries'] = src_spatial_queries
+ output_refimg['src_visual_maskings'] = src_spatial_maskings
+ return output_refimg
+
+ if task != 'demo':
+ # Get object query for spatial index
+ self.attention_data.set('queries_spatial', 'queries')
+
+ if self.task_switch['visual'] and visual_extra_flag:
+ visual_query_pos = extra['visual_query_pos']
+ visual_query_neg = extra['visual_query_neg']
+ src_visual_queries = extra['src_visual_queries']
+ src_visual_maskings = extra['src_visual_maskings']
+
+ if self.task_switch['grounding'] and grounding_extra_flag:
+ # Get grounding tokens
+ grounding_tokens = extra['grounding_tokens']
+ _grounding_tokens = grounding_tokens.detach().clone()
+
+ self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
+ self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])
+
+ if self.task_switch['audio'] and audio_extra_flag:
+ # Get grounding tokens
+ grounding_tokens = extra['audio_tokens']
+ _grounding_tokens = grounding_tokens.detach().clone()
+
+ self.attention_data.set('tokens_audio', 'tokens', grounding_tokens, _grounding_tokens)
+ self.attention_data.set_maskings('tokens_audio', extra['audio_nonzero_mask'])
+
+ output, query_embed = self.attention_data.cross_attn_variables()
+ # prediction heads on learnable query features
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
+ results["predictions_pos_visual"] = visual_query_pos.transpose(0,1) if visual_extra_flag else None
+ results["predictions_neg_visual"] = visual_query_neg.transpose(0,1) if visual_extra_flag else None
+ self.attention_data.set_results(results)
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ # CROSS ATTENTION
+ output, avg_attn = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+ self.attention_data.update_variables(output, 'cross_attn')
+
+ # SELF ATTENTION
+ self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
+ if self.task_switch['spatial'] and spatial_extra_flag:
+ # get spatial tokens
+ spatial_tokens = src_spatial_queries[level_index]
+ _spatial_tokens = spatial_tokens.detach().clone()
+
+ self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
+ self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
+
+ if self.task_switch['visual'] and visual_extra_flag:
+ # get spatial tokens
+ visual_tokens = src_visual_queries[level_index]
+ _visual_tokens = visual_tokens.detach().clone()
+
+ self.attention_data.set('tokens_visual', 'tokens', visual_tokens, _visual_tokens)
+ self.attention_data.set_maskings('tokens_visual', src_visual_maskings[level_index])
+
+ output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=self_attn_mask,
+ tgt_key_padding_mask=None,
+ query_pos=query_embed)
+
+ # FFN
+ output = self.transformer_ffn_layers[i](
+ output
+ )
+
+ self.attention_data.update_variables(output, 'self_attn')
+ output, query_embed = self.attention_data.cross_attn_variables()
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
+ results["predictions_pos_visual"] = visual_query_pos.transpose(0,1) if visual_extra_flag else None
+ results["predictions_neg_visual"] = visual_query_neg.transpose(0,1) if visual_extra_flag else None
+ self.attention_data.set_results(results)
+
+ return self.attention_data.organize_output()
+
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
+ decoder_output = self.decoder_norm(output)
+ decoder_output = decoder_output.transpose(0, 1)
+ class_embed = decoder_output @ self.class_embed
+ outputs_class = self.lang_encoder.compute_similarity(class_embed)
+ mask_embed = self.mask_embed(decoder_output)
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+
+ outputs_bbox = [None for i in range(len(outputs_mask))]
+ if self.task_switch['bbox']:
+ outputs_bbox = self.bbox_embed(decoder_output)
+
+ # NOTE: prediction is of higher-resolution
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
+
+ # must use bool type
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+ attn_mask = attn_mask.detach()
+
+ outputs_caption = class_embed
+
+ results = {
+ "attn_mask": attn_mask,
+ "predictions_class": outputs_class,
+ "predictions_mask": outputs_mask,
+ "predictions_bbox": outputs_bbox,
+ "predictions_caption": outputs_caption,
+ "predictions_maskemb": mask_embed,
+ }
+ return results
+
+@register_decoder
+def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
+ return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
\ No newline at end of file
diff --git a/modeling/interface/seem_v0.py b/modeling/interface/seem_v0.py
new file mode 100644
index 0000000000000000000000000000000000000000..285563eb0569403876429e08f15c42a5f23ce9f6
--- /dev/null
+++ b/modeling/interface/seem_v0.py
@@ -0,0 +1,392 @@
+# --------------------------------------------------------
+# SEEM -- Segment Everything Everywhere All at Once
+# Licensed under The Apache License 2.0 [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import logging
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from timm.models.layers import trunc_normal_
+from detectron2.layers import Conv2d
+import fvcore.nn.weight_init as weight_init
+
+from .build import register_decoder
+from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
+from .prototype.attention_data_struct_seemv0 import AttentionDataStruct
+from ..utils import rand_sample_plain as rand_sample
+from ..utils import prepare_features, configurable
+from ..modules import PositionEmbeddingSine
+from ..modules.point_features import point_sample
+
+
+class SEEMDecoder(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ lang_encoder: nn.Module,
+ in_channels,
+ mask_classification=True,
+ *,
+ hidden_dim: int,
+ dim_proj: int,
+ num_queries: int,
+ contxt_len: int,
+ nheads: int,
+ dim_feedforward: int,
+ dec_layers: int,
+ pre_norm: bool,
+ mask_dim: int,
+ task_switch: dict,
+ enforce_input_project: bool,
+ max_spatial_len: int,
+ attn_arch: dict,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ in_channels: channels of the input features
+ mask_classification: whether to add mask classifier or not
+ num_classes: number of classes
+ hidden_dim: Transformer feature dimension
+ num_queries: number of queries
+ nheads: number of heads
+ dim_feedforward: feature dimension in feedforward network
+ enc_layers: number of Transformer encoder layers
+ dec_layers: number of Transformer decoder layers
+ pre_norm: whether to use pre-LayerNorm or not
+ mask_dim: mask feature dimension
+ enforce_input_project: add input project 1x1 conv even if input
+ channels and hidden dim is identical
+ """
+ super().__init__()
+ assert mask_classification, "Only support mask classification model"
+ self.mask_classification = mask_classification
+
+ # positional encoding
+ N_steps = hidden_dim // 2
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+ # define Transformer decoder here
+ self.num_heads = nheads
+ self.num_layers = dec_layers
+ self.contxt_len = contxt_len
+ self.transformer_self_attention_layers = nn.ModuleList()
+ self.transformer_cross_attention_layers = nn.ModuleList()
+ self.transformer_ffn_layers = nn.ModuleList()
+
+ for _ in range(self.num_layers):
+ self.transformer_self_attention_layers.append(
+ SelfAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_cross_attention_layers.append(
+ CrossAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_ffn_layers.append(
+ FFNLayer(
+ d_model=hidden_dim,
+ dim_feedforward=dim_feedforward,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
+
+ self.num_queries = num_queries
+ # learnable query features
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
+ # learnable query p.e.
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+ # level embedding (we always use 3 scales)
+ self.num_feature_levels = 3
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
+ self.input_proj = nn.ModuleList()
+
+ for _ in range(self.num_feature_levels):
+ if in_channels != hidden_dim or enforce_input_project:
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
+ weight_init.c2_xavier_fill(self.input_proj[-1])
+ else:
+ self.input_proj.append(nn.Sequential())
+
+ self.task_switch = task_switch
+ self.query_index = {}
+
+ # output FFNs
+ self.lang_encoder = lang_encoder
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
+ trunc_normal_(self.class_embed, std=.02)
+
+ if task_switch['bbox']:
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+
+ if task_switch['spatial']:
+ # spatial query
+ self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
+ trunc_normal_(self.mask_sptial_embed[0], std=.02)
+ trunc_normal_(self.mask_sptial_embed[1], std=.02)
+ trunc_normal_(self.mask_sptial_embed[2], std=.02)
+
+ self.max_spatial_len = max_spatial_len
+ # spatial memory
+ num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
+ self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
+ self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)
+
+ # learnable positive negative indicator
+ self.pn_indicator = nn.Embedding(2, hidden_dim)
+
+ # build AttentionDataStruct
+ attn_arch['NUM_LAYERS'] = self.num_layers
+ self.attention_data = AttentionDataStruct(attn_arch, task_switch)
+
+ @classmethod
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
+ ret = {}
+
+ ret["lang_encoder"] = lang_encoder
+ ret["in_channels"] = in_channels
+ ret["mask_classification"] = mask_classification
+
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
+
+ # Transformer parameters:
+ ret["nheads"] = dec_cfg['NHEADS']
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
+
+ # NOTE: because we add learnable query features which requires supervision,
+ # we add minus 1 to decoder layers to be consistent with our loss
+ # implementation: that is, number of auxiliary losses is always
+ # equal to number of decoder layers. With learnable query features, the number of
+ # auxiliary losses equals number of decoders plus 1.
+ assert dec_cfg['DEC_LAYERS'] >= 1
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
+ ret["task_switch"] = extra['task_switch']
+ ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']
+
+ # attn data struct
+ ret["attn_arch"] = cfg['ATTENTION_ARCH']
+
+ return ret
+
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
+ # x is a list of multi-scale feature
+ assert len(x) == self.num_feature_levels; del mask
+ spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg' or 'refimg_tokens' in extra
+ grounding_extra_flag = 'grounding_tokens' in extra.keys()
+ spatial_memory_flag = 'prev_mask' in extra.keys()
+ flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag}
+ self.attention_data.reset(flags, task, extra)
+
+ src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
+ _, bs, _ = src[0].shape
+
+ # QxNxC
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
+ self.attention_data.set('queries_object', 'queries', output, query_embed)
+
+ if self.task_switch['spatial'] and spatial_extra_flag:
+ if 'refimg_tokens' not in extra:
+ # get divisor
+ _,h,w = extra['spatial_query_pos_mask'][0].shape
+ divisor = torch.tensor([h,w], device=output.device)[None,]
+
+ # Get mean pos spatial query
+ non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
+ non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
+ non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
+ spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
+ spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num()
+
+ # Get mean neg spatial query
+ non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
+ non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
+ non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
+ spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
+ spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()
+
+ # merge positive and negative sample points for self attention
+ # pos_neg_points = [x|y for x,y in zip(extra['spatial_query_pos_mask'], extra['spatial_query_neg_mask'])]
+
+ # Get layerwise spatial query
+ src_spatial_queries = []
+ src_spatial_maskings = []
+ for i in range(len(src)):
+ hw,_,dc = src[i].shape
+ src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
+ src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
+
+ non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
+ non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
+ non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
+
+ pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
+ pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
+
+ non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
+ non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
+ non_zero_query_point[non_zero_query_mask] = 0
+
+ spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
+ spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
+ spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
+
+ src_spatial_queries += [spatial_tokens]
+ src_spatial_maskings += [non_zero_query_mask]
+
+ if 'refimg' in task:
+ output_refimg = {}
+ output_refimg['spatial_query_pos'] = spatial_query_pos
+ output_refimg['spatial_query_neg'] = spatial_query_neg
+ output_refimg['src_spatial_queries'] = src_spatial_queries
+ output_refimg['src_spatial_maskings'] = src_spatial_maskings
+ return output_refimg
+ else:
+ spatial_query_pos = extra['refimg_tokens']['spatial_query_pos']
+ spatial_query_neg = extra['refimg_tokens']['spatial_query_neg']
+ src_spatial_queries = extra['refimg_tokens']['src_spatial_queries']
+ src_spatial_maskings = extra['refimg_tokens']['src_spatial_maskings']
+
+ # Get object query for spatial index
+ self.attention_data.set('queries_spatial', 'queries')
+
+ # set spatial memory
+ spatial_output = self.spatial_featured.weight.unsqueeze(1).repeat(1, bs, 1)
+ spatial_embed = self.spatial_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+ self.attention_data.set('memories_spatial', 'memories', spatial_output, spatial_embed)
+
+ # if 'queries_spatial' in extra:
+ # self.attention_data.set('queries_spatial', 'queries', var=extra['queries_spatial'])
+
+ # if spatial_memory_flag:
+ # prev_mask = (extra['prev_mask'].sigmoid() > 0.5).detach()
+ # non_zero_query_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in prev_mask]
+ # non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
+ # non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
+ # spatial_memory = point_sample(mask_features, non_zero_query_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
+ # spatial_memory = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_memory.transpose(1,2), ~non_zero_query_mask)]).transpose(0,1).nan_to_num()
+
+ if self.task_switch['grounding'] and grounding_extra_flag:
+ # Get grounding tokens
+ grounding_tokens = extra['grounding_tokens']
+ _grounding_tokens = grounding_tokens.detach().clone()
+
+ self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
+ self.attention_data.set('queries_grounding', 'queries')
+ self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])
+
+ output, query_embed = self.attention_data.cross_attn_variables()
+ # prediction heads on learnable query features
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
+ self.attention_data.set_results(results)
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ # CROSS ATTENTION
+ output, avg_attn = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+ self.attention_data.update_variables(output, 'cross_attn')
+
+ # SELF ATTENTION
+ self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
+ if self.task_switch['spatial'] and spatial_extra_flag:
+ # get spatial tokens
+ spatial_tokens = src_spatial_queries[level_index]
+ _spatial_tokens = spatial_tokens.detach().clone()
+
+ self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
+ self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
+
+ output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)
+
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=self_attn_mask,
+ tgt_key_padding_mask=None,
+ query_pos=query_embed)
+
+ # FFN
+ output = self.transformer_ffn_layers[i](
+ output
+ )
+
+ self.attention_data.update_variables(output, 'self_attn')
+ output, query_embed = self.attention_data.cross_attn_variables()
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
+ self.attention_data.set_results(results)
+
+ return self.attention_data.organize_output()
+
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
+ decoder_output = self.decoder_norm(output)
+ decoder_output = decoder_output.transpose(0, 1)
+ class_embed = decoder_output @ self.class_embed
+ outputs_class = self.lang_encoder.compute_similarity(class_embed)
+ mask_embed = self.mask_embed(decoder_output)
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+
+ outputs_bbox = [None for i in range(len(outputs_mask))]
+ if self.task_switch['bbox']:
+ outputs_bbox = self.bbox_embed(decoder_output)
+
+ # NOTE: prediction is of higher-resolution
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
+
+ # must use bool type
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+ attn_mask = attn_mask.detach()
+
+ outputs_caption = class_embed
+
+ results = {
+ "attn_mask": attn_mask,
+ "predictions_class": outputs_class,
+ "predictions_mask": outputs_mask,
+ "predictions_bbox": outputs_bbox,
+ "predictions_caption": outputs_caption,
+ "predictions_maskemb": mask_embed,
+ }
+ return results
+
+@register_decoder
+def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
+ return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
diff --git a/modeling/interface/seem_v1.py b/modeling/interface/seem_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8266d7495b8b3a1f9327c110d195bacbff7d0c80
--- /dev/null
+++ b/modeling/interface/seem_v1.py
@@ -0,0 +1,389 @@
+# --------------------------------------------------------
+# SEEM -- Segment Everything Everywhere All at Once
+# Licensed under The Apache License 2.0 [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import logging
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from timm.models.layers import trunc_normal_
+from detectron2.layers import Conv2d
+import fvcore.nn.weight_init as weight_init
+
+from .build import register_decoder
+from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
+from .prototype.attention_data_struct_seemv1 import AttentionDataStruct
+from ..utils import rand_sample, prepare_features, configurable
+from ..modules import PositionEmbeddingSine
+from ..modules.point_features import point_sample
+
+
+class SEEMDecoder(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ lang_encoder: nn.Module,
+ in_channels,
+ mask_classification=True,
+ *,
+ hidden_dim: int,
+ dim_proj: int,
+ num_queries: int,
+ contxt_len: int,
+ nheads: int,
+ dim_feedforward: int,
+ dec_layers: int,
+ pre_norm: bool,
+ mask_dim: int,
+ task_switch: dict,
+ enforce_input_project: bool,
+ max_spatial_len: int,
+ attn_arch: dict,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ in_channels: channels of the input features
+ mask_classification: whether to add mask classifier or not
+ num_classes: number of classes
+ hidden_dim: Transformer feature dimension
+ num_queries: number of queries
+ nheads: number of heads
+ dim_feedforward: feature dimension in feedforward network
+ enc_layers: number of Transformer encoder layers
+ dec_layers: number of Transformer decoder layers
+ pre_norm: whether to use pre-LayerNorm or not
+ mask_dim: mask feature dimension
+ enforce_input_project: add input project 1x1 conv even if input
+ channels and hidden dim is identical
+ """
+ super().__init__()
+ assert mask_classification, "Only support mask classification model"
+ self.mask_classification = mask_classification
+
+ # positional encoding
+ N_steps = hidden_dim // 2
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+ # define Transformer decoder here
+ self.num_heads = nheads
+ self.num_layers = dec_layers
+ self.contxt_len = contxt_len
+ self.transformer_self_attention_layers = nn.ModuleList()
+ self.transformer_cross_attention_layers = nn.ModuleList()
+ self.transformer_ffn_layers = nn.ModuleList()
+
+ for _ in range(self.num_layers):
+ self.transformer_self_attention_layers.append(
+ SelfAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_cross_attention_layers.append(
+ CrossAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_ffn_layers.append(
+ FFNLayer(
+ d_model=hidden_dim,
+ dim_feedforward=dim_feedforward,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
+
+ self.num_queries = num_queries
+ # learnable query features
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
+ # learnable query p.e.
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+ # level embedding (we always use 3 scales)
+ self.num_feature_levels = 3
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
+ self.input_proj = nn.ModuleList()
+
+ for _ in range(self.num_feature_levels):
+ if in_channels != hidden_dim or enforce_input_project:
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
+ weight_init.c2_xavier_fill(self.input_proj[-1])
+ else:
+ self.input_proj.append(nn.Sequential())
+
+ self.task_switch = task_switch
+ self.query_index = {}
+
+ # output FFNs
+ self.lang_encoder = lang_encoder
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
+ trunc_normal_(self.class_embed, std=.02)
+
+ if task_switch['bbox']:
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+
+ if task_switch['spatial']:
+ # spatial query
+ self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
+ trunc_normal_(self.mask_sptial_embed[0], std=.02)
+ trunc_normal_(self.mask_sptial_embed[1], std=.02)
+ trunc_normal_(self.mask_sptial_embed[2], std=.02)
+
+ self.max_spatial_len = max_spatial_len
+ # spatial memory
+ num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
+ self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
+ self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)
+
+ # learnable positive negative indicator
+ self.pn_indicator = nn.Embedding(2, hidden_dim)
+
+ # build AttentionDataStruct
+ attn_arch['NUM_LAYERS'] = self.num_layers
+ self.attention_data = AttentionDataStruct(attn_arch, task_switch)
+ self.sample_size = attn_arch['QUERY_NUMBER']
+
+ @classmethod
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
+ ret = {}
+
+ ret["lang_encoder"] = lang_encoder
+ ret["in_channels"] = in_channels
+ ret["mask_classification"] = mask_classification
+
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
+
+ # Transformer parameters:
+ ret["nheads"] = dec_cfg['NHEADS']
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
+
+ # NOTE: because we add learnable query features which requires supervision,
+ # we add minus 1 to decoder layers to be consistent with our loss
+ # implementation: that is, number of auxiliary losses is always
+ # equal to number of decoder layers. With learnable query features, the number of
+ # auxiliary losses equals number of decoders plus 1.
+ assert dec_cfg['DEC_LAYERS'] >= 1
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
+ ret["task_switch"] = extra['task_switch']
+ ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']
+
+ # attn data struct
+ ret["attn_arch"] = cfg['ATTENTION_ARCH']
+
+ return ret
+
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
+ # x is a list of multi-scale feature
+ assert len(x) == self.num_feature_levels; del mask
+ spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg' or 'refimg_tokens' in extra
+ grounding_extra_flag = 'grounding_tokens' in extra.keys()
+ spatial_memory_flag = 'prev_mask' in extra.keys()
+ flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag}
+ self.attention_data.reset(flags, task, extra)
+
+ src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
+ _,bs,_ = src[0].shape
+
+ # QxNxC
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
+ self.attention_data.set('queries_object', 'queries', output, query_embed)
+
+ if self.task_switch['spatial'] and spatial_extra_flag:
+ if 'refimg_tokens' not in extra:
+ # get divisor
+ c,h,w = extra['spatial_query_pos_mask'][0].shape
+ divisor = torch.tensor([1,h,w], device=output.device)[None,]
+
+ # Get mean pos spatial query
+ non_zero_pos_point = [rand_sample(m, divisor, self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
+ non_zero_pos_index = [m[:,0:1].long() for m in non_zero_pos_point]
+ non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
+ non_zero_pos_index = nn.utils.rnn.pad_sequence(non_zero_pos_index, padding_value=-1).permute(1,0,2)[:,:,0]
+ non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
+ spatial_query_pos = point_sample(mask_features, non_zero_pos_point[:,:,1:].flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
+ num_mask_per_batch = [len(m) for m in extra['spatial_query_pos_mask']]
+ spatial_query_pos = nn.utils.rnn.pad_sequence([torch.stack([x[ns==n].mean(dim=0, keepdim=False) if (ns==n).sum() > 0 else -torch.ones((x.shape[1]), device=spatial_query_pos.device) for n in range(mb)]) for x, m, ns, mb in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask, non_zero_pos_index, num_mask_per_batch)], padding_value=-1).nan_to_num()
+
+ # Get mean neg spatial query
+ non_zero_neg_point = [rand_sample(m, divisor, self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
+ non_zero_neg_index = [m[:,0:1].long() for m in non_zero_neg_point]
+ non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
+ non_zero_neg_index = nn.utils.rnn.pad_sequence(non_zero_neg_index, padding_value=-1).permute(1,0,2)[:,:,0]
+ non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
+ spatial_query_neg = point_sample(mask_features, non_zero_neg_point[:,:,1:].flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
+ num_mask_per_batch = [len(m) for m in extra['spatial_query_neg_mask']]
+ spatial_query_neg = nn.utils.rnn.pad_sequence([torch.stack([x[ns==n].mean(dim=0, keepdim=False) if (ns==n).sum() > 0 else -torch.ones((x.shape[1]), device=spatial_query_neg.device) for n in range(mb)]) for x, m, ns, mb in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask, non_zero_neg_index, num_mask_per_batch)], padding_value=-1).nan_to_num()
+ # Get layerwise spatial query
+ src_spatial_queries = []
+ src_spatial_maskings = []
+ src_spatial_indices = []
+ for i in range(len(src)):
+ hw,_,dc = src[i].shape
+ src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
+ src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
+
+ non_zero_query_point_pos = [rand_sample(m, divisor, self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
+ non_zero_query_point_neg = [rand_sample(m, divisor, self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
+ non_zero_query_point = [torch.cat([x[:,1:],y[:,1:]], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
+ non_zero_query_index = [torch.cat([x[:,0:1],y[:,0:1]], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
+
+ pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
+ pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
+
+ non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
+ non_zero_query_index = nn.utils.rnn.pad_sequence(non_zero_query_index, padding_value=-1).permute(1,0,2)
+ non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
+ non_zero_query_point[non_zero_query_mask] = 0
+
+ spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
+ spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
+ spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
+
+ src_spatial_queries += [spatial_tokens]
+ src_spatial_maskings += [non_zero_query_mask]
+ src_spatial_indices += [non_zero_query_index]
+
+ if 'refimg' in task:
+ output_refimg = {}
+ output_refimg['spatial_query_pos'] = spatial_query_pos
+ output_refimg['spatial_query_neg'] = spatial_query_neg
+ output_refimg['src_spatial_queries'] = src_spatial_queries
+ output_refimg['src_spatial_maskings'] = src_spatial_maskings
+ return output_refimg
+ else:
+ spatial_query_pos = extra['refimg_tokens']['spatial_query_pos']
+ spatial_query_neg = extra['refimg_tokens']['spatial_query_neg']
+ src_spatial_queries = extra['refimg_tokens']['src_spatial_queries']
+ src_spatial_maskings = extra['refimg_tokens']['src_spatial_maskings']
+
+ # Get object query for spatial index
+ self.attention_data.set_extra({"spatial_query_number": len(spatial_query_pos), "sample_size": self.sample_size})
+ self.attention_data.set('queries_spatial', 'queries', sample_size=self.sample_size*len(spatial_query_pos))
+
+ # set spatial memory
+ spatial_output = self.spatial_featured.weight.unsqueeze(1).repeat(1, bs, 1)
+ spatial_embed = self.spatial_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+ self.attention_data.set('memories_spatial', 'memories', spatial_output, spatial_embed)
+
+ if self.task_switch['grounding'] and grounding_extra_flag:
+ # Get grounding tokens
+ grounding_tokens = extra['grounding_tokens']
+ _grounding_tokens = grounding_tokens.detach().clone()
+
+ self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
+ self.attention_data.set('queries_grounding', 'queries')
+ self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])
+
+ output, query_embed = self.attention_data.cross_attn_variables()
+ # prediction heads on learnable query features
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
+ self.attention_data.set_results(results)
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ # CROSS ATTENTION
+ output, avg_attn = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+ self.attention_data.update_variables(output, 'cross_attn')
+
+ # SELF ATTENTION
+ self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
+ if self.task_switch['spatial'] and spatial_extra_flag:
+ # get spatial tokens
+ spatial_tokens = src_spatial_queries[level_index]
+ _spatial_tokens = spatial_tokens.detach().clone()
+
+ self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
+ self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
+ self.attention_data.set_extra({"spatial_indices": src_spatial_indices[level_index]})
+
+ output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)
+
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=self_attn_mask,
+ tgt_key_padding_mask=None,
+ query_pos=query_embed)
+
+ # FFN
+ output = self.transformer_ffn_layers[i](
+ output
+ )
+
+ self.attention_data.update_variables(output, 'self_attn')
+ output, query_embed = self.attention_data.cross_attn_variables()
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
+ self.attention_data.set_results(results)
+
+ return self.attention_data.organize_output()
+
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
+ decoder_output = self.decoder_norm(output)
+ decoder_output = decoder_output.transpose(0, 1)
+ class_embed = decoder_output @ self.class_embed
+ outputs_class = self.lang_encoder.compute_similarity(class_embed)
+ mask_embed = self.mask_embed(decoder_output)
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+
+ outputs_bbox = [None for i in range(len(outputs_mask))]
+ if self.task_switch['bbox']:
+ outputs_bbox = self.bbox_embed(decoder_output)
+
+ # NOTE: prediction is of higher-resolution
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
+
+ # must use bool type
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+ attn_mask = attn_mask.detach()
+
+ outputs_caption = class_embed
+
+ results = {
+ "attn_mask": attn_mask,
+ "predictions_class": outputs_class,
+ "predictions_mask": outputs_mask,
+ "predictions_bbox": outputs_bbox,
+ "predictions_caption": outputs_caption,
+ "predictions_maskemb": mask_embed,
+ }
+ return results
+
+@register_decoder
+def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
+ return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
diff --git a/modeling/interface/xdecoder.py b/modeling/interface/xdecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f13a28ed6971e79e7297ed2b6102ad1c19e2dc6
--- /dev/null
+++ b/modeling/interface/xdecoder.py
@@ -0,0 +1,497 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import logging
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from timm.models.layers import trunc_normal_
+from detectron2.layers import Conv2d
+import fvcore.nn.weight_init as weight_init
+
+from .build import register_decoder
+from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
+from ..utils import configurable
+from ..modules import PositionEmbeddingSine
+
+
+class XDecoder(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ lang_encoder: nn.Module,
+ in_channels,
+ mask_classification=True,
+ *,
+ hidden_dim: int,
+ dim_proj: int,
+ num_queries: int,
+ contxt_len: int,
+ nheads: int,
+ dim_feedforward: int,
+ dec_layers: int,
+ pre_norm: bool,
+ mask_dim: int,
+ task_switch: dict,
+ captioning_step: int,
+ enforce_input_project: bool,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ in_channels: channels of the input features
+ mask_classification: whether to add mask classifier or not
+ num_classes: number of classes
+ hidden_dim: Transformer feature dimension
+ num_queries: number of queries
+ nheads: number of heads
+ dim_feedforward: feature dimension in feedforward network
+ enc_layers: number of Transformer encoder layers
+ dec_layers: number of Transformer decoder layers
+ pre_norm: whether to use pre-LayerNorm or not
+ mask_dim: mask feature dimension
+ enforce_input_project: add input project 1x1 conv even if input
+ channels and hidden dim is identical
+ """
+ super().__init__()
+ assert mask_classification, "Only support mask classification model"
+ self.mask_classification = mask_classification
+
+ # positional encoding
+ N_steps = hidden_dim // 2
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+ # define Transformer decoder here
+ self.num_heads = nheads
+ self.num_layers = dec_layers
+ self.contxt_len = contxt_len
+ self.transformer_self_attention_layers = nn.ModuleList()
+ self.transformer_cross_attention_layers = nn.ModuleList()
+ self.transformer_ffn_layers = nn.ModuleList()
+
+ for _ in range(self.num_layers):
+ self.transformer_self_attention_layers.append(
+ SelfAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_cross_attention_layers.append(
+ CrossAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_ffn_layers.append(
+ FFNLayer(
+ d_model=hidden_dim,
+ dim_feedforward=dim_feedforward,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
+
+ self.num_queries = num_queries
+ # learnable query features
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
+ # learnable query p.e.
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+ # level embedding (we always use 3 scales)
+ self.num_feature_levels = 3
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
+ self.input_proj = nn.ModuleList()
+
+ for _ in range(self.num_feature_levels):
+ if in_channels != hidden_dim or enforce_input_project:
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
+ weight_init.c2_xavier_fill(self.input_proj[-1])
+ else:
+ self.input_proj.append(nn.Sequential())
+
+ self.task_switch = task_switch
+
+ # output FFNs
+ self.lang_encoder = lang_encoder
+ if self.task_switch['mask']:
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
+ trunc_normal_(self.class_embed, std=.02)
+
+ if task_switch['bbox']:
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+
+ # Caption Project and query
+ if task_switch['captioning']:
+ self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
+ trunc_normal_(self.caping_embed, std=.02)
+ self.pos_embed_caping = nn.Embedding(contxt_len, hidden_dim)
+ self.captioning_step = captioning_step
+
+ # register self_attn_mask to avoid information leakage, it includes interaction between object query, class query and caping query
+ self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool()
+ self_attn_mask[:, :num_queries, num_queries:] = True # object+class query does not attend with caption query.
+ self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption query only attend with previous token.
+ self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object query does not attend with class query.
+ self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # class query does not attend with object query.
+ self.register_buffer("self_attn_mask", self_attn_mask)
+
+
+ @classmethod
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
+ ret = {}
+
+ ret["lang_encoder"] = lang_encoder
+ ret["in_channels"] = in_channels
+ ret["mask_classification"] = mask_classification
+
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
+
+ # Transformer parameters:
+ ret["nheads"] = dec_cfg['NHEADS']
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
+
+ # NOTE: because we add learnable query features which requires supervision,
+ # we add minus 1 to decoder layers to be consistent with our loss
+ # implementation: that is, number of auxiliary losses is always
+ # equal to number of decoder layers. With learnable query features, the number of
+ # auxiliary losses equals number of decoders plus 1.
+ assert dec_cfg['DEC_LAYERS'] >= 1
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
+
+ ret["task_switch"] = extra['task_switch']
+ ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50)
+
+ return ret
+
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
+ if task == 'captioning_infer':
+ return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)
+ # x is a list of multi-scale feature
+ assert len(x) == self.num_feature_levels
+ src = []
+ pos = []
+ size_list = []
+
+ # disable mask, it does not affect performance
+ del mask
+ for i in range(self.num_feature_levels):
+ size_list.append(x[i].shape[-2:])
+ pos.append(self.pe_layer(x[i], None).flatten(2))
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
+
+ # flatten NxCxHxW to HWxNxC
+ pos[-1] = pos[-1].permute(2, 0, 1)
+ src[-1] = src[-1].permute(2, 0, 1)
+
+ _, bs, _ = src[0].shape
+
+ # QxNxC
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
+
+ predictions_class = []
+ predictions_mask = []
+ predictions_bbox = []
+ predictions_caption = []
+ predictions_captioning = []
+
+ self_tgt_mask = None
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
+ # output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token.
+ caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output
+ _caping_lang_embed = caping_lang_embed.detach().clone()
+ output = torch.cat((output, _caping_lang_embed), dim=0) # concat object query, class token and caption token.
+ caping_lang_embed += self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
+ query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning.
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
+ elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
+ grounding_tokens = extra['grounding_tokens']
+ _grounding_tokens = grounding_tokens.detach().clone()
+ # initialize with negative attention at the beginning.
+ pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
+ pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
+ pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens could attend with eatch other
+ self_tgt_mask = pad_tgt_mask
+ output = torch.cat((output, output[:-1]), dim=0)
+ query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad language embdding to fix embedding
+ else:
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
+
+ # prediction heads on learnable query features
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
+ attn_mask = results["attn_mask"]
+ predictions_class.append(results["outputs_class"])
+ predictions_mask.append(results["outputs_mask"])
+ predictions_bbox.append(results["outputs_bbox"])
+ predictions_caption.append(results["outputs_caption"])
+ predictions_captioning.append(results["outputs_captionting"])
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
+ # attention: cross-attention first
+ output, avg_attn = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=attn_mask,
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
+ output = torch.cat((output, _grounding_tokens), dim=0)
+ query_embed = torch.cat((query_embed, grounding_tokens), dim=0)
+
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=self_tgt_mask,
+ tgt_key_padding_mask=None,
+ query_pos=query_embed
+ )
+
+ # FFN
+ output = self.transformer_ffn_layers[i](
+ output
+ )
+
+ if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']:
+ _grounding_tokens = output[-len(_grounding_tokens):]
+ output = output[:-len(_grounding_tokens)]
+ query_embed = query_embed[:-len(_grounding_tokens)]
+
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
+ attn_mask = results["attn_mask"]
+ predictions_class.append(results["outputs_class"])
+ predictions_mask.append(results["outputs_mask"])
+ predictions_bbox.append(results["outputs_bbox"])
+ predictions_caption.append(results["outputs_caption"])
+ predictions_captioning.append(results["outputs_captionting"])
+
+ assert len(predictions_class) == self.num_layers + 1
+ if task == 'vlp':
+ out = {'pred_captionings': predictions_captioning[-1],
+ 'pred_captions': predictions_caption[-1],
+ 'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]}
+ return out
+ else:
+ out = {
+ 'pred_logits': predictions_class[-1],
+ 'pred_masks': predictions_mask[-1],
+ 'pred_boxes': predictions_bbox[-1],
+ 'pred_captions': predictions_caption[-1],
+ 'aux_outputs': self._set_aux_loss(
+ predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption
+ )
+ }
+ return out
+
+ def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}):
+ # x is a list of multi-scale feature
+ assert len(x) == self.num_feature_levels
+ src = []
+ pos = []
+ size_list = []
+
+ # disable mask, it does not affect performance
+ del mask
+ for i in range(self.num_feature_levels):
+ size_list.append(x[i].shape[-2:])
+ pos.append(self.pe_layer(x[i], None).flatten(2))
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
+
+ # flatten NxCxHxW to HWxNxC
+ pos[-1] = pos[-1].permute(2, 0, 1)
+ src[-1] = src[-1].permute(2, 0, 1)
+
+ _, bs, _ = src[0].shape
+
+ # QxNxC
+ query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+ query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
+ caping_lang_token = extra['start_token'].repeat(bs, 1)
+ pos_embed_caping = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
+
+ # prepare token embedding for evaluation
+ token_embs = self.lang_encoder.lang_encoder.token_embedding.weight
+ # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
+
+ for cap_idx in range(0, self.captioning_step):
+ caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1)
+ output = torch.cat((query_feat, caping_lang_embed), dim=0) # concat object query, class token and caption token.
+ caping_lang_embed += pos_embed_caping
+ query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning.
+ # output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token.
+
+ # prediction heads on learnable query features
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
+ attn_mask = results["attn_mask"]
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
+
+ if extra['captioning_mask'] is not None:
+ bs,nq,wh = attn_mask.shape
+ assert bs==self.num_heads, "Only support single image referring captioning."
+ cap_mask = extra['captioning_mask']
+ attn_mask = attn_mask.reshape(bs,nq,size_list[i%3][0],size_list[i%3][1])
+ cap_mask = F.interpolate(cap_mask[None,].float(), size_list[i%3], mode='nearest').bool()[0,0]
+ attn_mask[:,self.num_queries:, cap_mask] = True
+ attn_mask = attn_mask.reshape(bs,nq,wh)
+
+ # attention: cross-attention first
+ output, avg_attn = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=attn_mask,
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=self_tgt_mask,
+ tgt_key_padding_mask=None,
+ query_pos=query_embed
+ )
+
+ # FFN
+ output = self.transformer_ffn_layers[i](
+ output
+ )
+
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
+ attn_mask = results["attn_mask"]
+
+ pred_captions_gen = results['outputs_captionting']
+ # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
+ pred_captions_gen = pred_captions_gen @ token_embs.t()
+ caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1]
+
+ texts = self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=False)
+ texts_new = []
+
+ for x in texts:
+ x = x.split('<|endoftext|>')[0]
+ x = x.replace('<|endoftext|>','')
+ x = x.replace('<|startoftext|>','')
+ x = x.strip()
+ texts_new.append(x)
+
+ out = {'pred_captionings': caping_lang_token,
+ 'pred_texts': texts_new}
+ return out
+
+
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'):
+ decoder_output = self.decoder_norm(output)
+ decoder_output = decoder_output.transpose(0, 1)
+
+ # extract image captioning token from decoder output.
+ if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'):
+ outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed
+ else:
+ outputs_captionting = None
+
+ # recompute class token output.
+ norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
+ obj_token = norm_decoder_output[:,:self.num_queries-1]
+ cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries]
+
+ sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token.
+ cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True)
+
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1)
+ else:
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1)
+
+ # compute class, mask and bbox.
+ class_embed = decoder_output @ self.class_embed
+ # HACK do not compute similarity if mask is not on
+ outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training)))
+
+ if self.task_switch['mask']:
+ mask_embed = self.mask_embed(decoder_output)
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+
+ # NOTE: prediction is of higher-resolution
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bicubic", align_corners=False, antialias=True)
+
+ # must use bool type
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+ attn_mask = attn_mask.detach()
+
+ # NOTE: fill False for cls token (JY)
+ attn_mask[:, self.num_queries:self.num_queries+1].fill_(False)
+ else:
+ outputs_mask = None
+ attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool()
+
+ outputs_bbox = [None for i in range(len(decoder_output))]
+ if self.task_switch['bbox']:
+ outputs_bbox = self.bbox_embed(decoder_output)
+
+ outputs_caption = None
+ if self.task_switch['caption']:
+ outputs_caption = class_embed
+
+
+ results = {
+ "outputs_class": outputs_class,
+ "outputs_mask": outputs_mask,
+ "outputs_bbox": outputs_bbox,
+ "attn_mask": attn_mask,
+ "outputs_caption": outputs_caption,
+ "outputs_captionting": outputs_captionting,
+ }
+ return results
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ if self.mask_classification:
+ return [
+ {"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d}
+ for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1])
+ ]
+ else:
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
+
+
+@register_decoder
+def get_xdecoder_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
+ return XDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
\ No newline at end of file
diff --git a/modeling/language/LangEncoder/__init__.py b/modeling/language/LangEncoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e0291ebac7a30a9f523f6170f8b9a248aca7f1
--- /dev/null
+++ b/modeling/language/LangEncoder/__init__.py
@@ -0,0 +1,35 @@
+from transformers import CLIPTokenizer, CLIPTokenizerFast
+from transformers import AutoTokenizer
+
+from .transformer import *
+from .build import *
+
+
+def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
+ model_name = config_encoder['NAME']
+
+ if not is_lang_encoder(model_name):
+ raise ValueError(f'Unkown model: {model_name}')
+
+ return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
+
+def build_tokenizer(config_encoder):
+ tokenizer = None
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ if config_encoder['TOKENIZER'] == 'clip':
+ pretrained_tokenizer = config_encoder.get(
+ 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
+ )
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
+ tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
+ elif config_encoder['TOKENIZER'] == 'clip-fast':
+ pretrained_tokenizer = config_encoder.get(
+ 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
+ )
+ tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
+ elif config_encoder['TOKENIZER'] == 'biomed-clip':
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])
+
+ return tokenizer
\ No newline at end of file
diff --git a/modeling/language/LangEncoder/build.py b/modeling/language/LangEncoder/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f3d7a1d188dac2902cd35a5d9ee7ace5ce49af0
--- /dev/null
+++ b/modeling/language/LangEncoder/build.py
@@ -0,0 +1,16 @@
+_lang_encoders = {}
+
+
+def register_lang_encoder(fn):
+ module_name_split = fn.__module__.split('.')
+ model_name = module_name_split[-1]
+
+ _lang_encoders[model_name] = fn
+
+ return fn
+
+def lang_encoders(model_name):
+ return _lang_encoders[model_name]
+
+def is_lang_encoder(model_name):
+ return model_name in _lang_encoders
diff --git a/modeling/language/LangEncoder/transformer.py b/modeling/language/LangEncoder/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f921721996bf6d346dda2e6214362087e81ae82e
--- /dev/null
+++ b/modeling/language/LangEncoder/transformer.py
@@ -0,0 +1,222 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+import logging
+import os
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from timm.models.layers import DropPath, trunc_normal_
+
+from .build import register_lang_encoder
+from utilities.distributed import is_main_process
+from utilities.model import register_norm_module
+
+logger = logging.getLogger(__name__)
+
+
+@register_norm_module
+class LayerNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-12):
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
+ """
+ super(LayerNorm, self).__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, x):
+ pdtype = x.dtype
+ x = x.float()
+ u = x.mean(-1, keepdim=True)
+ s = (x - u).pow(2).mean(-1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+ return self.weight * x.to(pdtype) + self.bias
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self,
+ d_model: int,
+ n_head: int,
+ attn_mask: torch.Tensor = None,
+ drop_path: float = 0.0):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
+ if self.attn_mask is not None else None
+
+
+ return self.attn(
+ x, x, x,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ attn_mask=self.attn_mask
+ )[0]
+
+ def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
+ x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
+ x = x + self.drop_path(self.mlp(self.ln_2(x)))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self,
+ context_length: int,
+ vocab_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ drop_path: float = 0.0,
+ autogressive: bool =True):
+ super().__init__()
+
+ self.token_embedding = nn.Embedding(vocab_size, width)
+
+ self.context_length = context_length
+ self.positional_embedding = nn.Parameter(
+ torch.empty(self.context_length, width)
+ )
+
+ self.width = width
+ self.layers = layers
+ self.autogressive = autogressive
+ attn_mask = self.build_attention_mask() if autogressive else None
+ dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule
+ self.resblocks = nn.ModuleList(
+ [
+ ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
+ for i in range(layers)
+ ]
+ )
+
+ self.ln_final = LayerNorm(width)
+
+ trunc_normal_(self.positional_embedding, std=.02)
+ # nn.init.normal_(self.token_embedding, std=.02)
+ trunc_normal_(self.token_embedding.weight, std=.02)
+ self.apply(self._init_weights)
+
+ @property
+ def dim_out(self):
+ return self.width
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
+ if is_main_process():
+ logger.info('=> init weight of Linear/Conv2d from trunc norm')
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ if is_main_process():
+ logger.info('=> init bias of Linear/Conv2d to zeros')
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
+ nn.init.constant_(m.bias, 0)
+
+ def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
+ if os.path.isfile(pretrained):
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
+ logging.info(f'=> loading pretrained model {pretrained}')
+ model_dict = self.state_dict()
+ stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x
+ pretrained_dict = {
+ stripped_key(k): v for k, v in pretrained_dict.items()
+ if stripped_key(k) in model_dict.keys()
+ }
+ need_init_state_dict = {}
+ for k, v in pretrained_dict.items():
+ need_init = (
+ k.split('.')[0] in pretrained_layers
+ or pretrained_layers[0] == '*'
+ )
+ if need_init:
+ if verbose:
+ logger.info(f'=> init {k} from {pretrained}')
+
+ if 'positional_embedding' in k and v.size() != model_dict[k].size():
+ positional_embedding_pretrained = v
+ positional_embedding_current = model_dict[k]
+ L1, nH1 = positional_embedding_pretrained.size()
+ L2, nH2 = positional_embedding_current.size()
+ if nH1 != nH2:
+ logger.info(f"Error in loading {k}, passing")
+ else:
+ if L1 != L2:
+ logger.info(
+ '=> load_pretrained: resized variant: {} to {}'
+ .format((L1, nH1), (L2, nH2))
+ )
+
+ posemb = positional_embedding_pretrained.float()
+ posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
+ posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
+ posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
+ v = posemb_grid
+
+ need_init_state_dict[k] = v
+
+ self.load_state_dict(need_init_state_dict, strict=False)
+
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {
+ 'positional_embedding',
+ 'token_embedding',
+ }
+
+ def forward(self, input_ids, attention_mask=None):
+ key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None
+ # key_padding_mask = (input_ids == 0) if not self.autogressive else None
+ x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model]
+ x = x + self.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ for block in self.resblocks:
+ x = block(x, key_padding_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_final(x)
+
+ return {'last_hidden_state': x}
+
+
+@register_lang_encoder
+def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
+ transformer = Transformer(
+ context_length=config_encoder['CONTEXT_LENGTH'],
+ vocab_size=tokenizer.vocab_size,
+ width=config_encoder['WIDTH'],
+ layers=config_encoder['LAYERS'],
+ heads=config_encoder['HEADS'],
+ autogressive=config_encoder.get('AUTOGRESSIVE', True)
+ )
+
+ if config_encoder.get('LOAD_PRETRAINED', False):
+ transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*']))
+ return transformer
diff --git a/modeling/language/__init__.py b/modeling/language/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33112edbe72acf6fa1c8543e05d89428c41dfdf4
--- /dev/null
+++ b/modeling/language/__init__.py
@@ -0,0 +1,10 @@
+from .vlpencoder import *
+from .build import *
+
+def build_language_encoder(config, **kwargs):
+ model_name = config['MODEL']['TEXT']['ARCH']
+
+ if not is_model(model_name):
+ raise ValueError(f'Unkown model: {model_name}')
+
+ return model_entrypoints(model_name)(config, **kwargs)
\ No newline at end of file
diff --git a/modeling/language/build.py b/modeling/language/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..22c4961df7cc16007596144ac4ec9d0ef4f02e47
--- /dev/null
+++ b/modeling/language/build.py
@@ -0,0 +1,14 @@
+_model_entrypoints = {}
+
+
+def register_model(fn):
+ module_name_split = fn.__module__.split('.')
+ model_name = module_name_split[-1]
+ _model_entrypoints[model_name] = fn
+ return fn
+
+def model_entrypoints(model_name):
+ return _model_entrypoints[model_name]
+
+def is_model(model_name):
+ return model_name in _model_entrypoints
\ No newline at end of file
diff --git a/modeling/language/loss.py b/modeling/language/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..db10fa6c37a623dc3e67d8fcb4ae154a618d0950
--- /dev/null
+++ b/modeling/language/loss.py
@@ -0,0 +1,232 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import pickle
+from distutils import log
+
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from einops import rearrange, repeat
+from timm.loss import SoftTargetCrossEntropy
+
+soft_cross_entropy = SoftTargetCrossEntropy()
+
+def is_dist_initialized():
+ return torch.distributed.is_initialized()
+
+def get_world_size():
+ if is_dist_initialized():
+ return torch.distributed.get_world_size()
+ return 1
+
+def get_rank():
+ if is_dist_initialized():
+ return dist.get_rank()
+ return 0
+
+def all_gather_grad(x):
+ if get_world_size() > 1:
+ all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
+ torch.distributed.all_gather(all_x, x)
+ all_x[torch.distributed.get_rank()] = x
+ x = torch.cat(all_x, dim=0)
+ return x
+
+def vl_multilabel_contrastive_loss(image_feat, text_feat, temperature=1):
+ """
+ Args:
+ image_feat (torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256
+ text_feat (torch.Tensor): shape [B, L2, C] # B:batch_size, L2: number of selected nouns, C: 256
+
+ Returns:
+ """
+ # [B, L1, C], L1 = 1
+ # image_feat = F.normalize(image_feat, dim=-1)
+ # [B, L2, C]
+ # text_feat = F.normalize(text_feat, dim=-1)
+ # HACK: normalize outside
+
+ # [B, L1, L2]
+ dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
+ # [B, L2, L1]
+ dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
+
+ batch = image_feat.shape[0]
+ img_len = image_feat.shape[1]
+ text_len = text_feat.shape[1]
+ # [B, L1, L2]
+ pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
+ # [B, L2, L1]
+ pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
+
+ image_x = rearrange(image_feat, 'b l c -> (b l) c')
+ text_x = rearrange(text_feat, 'b l c -> (b l) c')
+
+ logits_per_img = image_x @ all_gather_grad(text_x).t()
+ logits_per_text = text_x @ all_gather_grad(image_x).t()
+
+ # get label globally
+ # [B, L1, B, L2, W]
+ labels_per_img = F.one_hot(
+ torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(),
+ num_classes=get_world_size()).to(image_x.dtype)
+ labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
+ torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
+ # [BxL1, WxBxL2]
+ labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
+ # [B, L2, B, L1, W]
+ labels_per_text = F.one_hot(
+ torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(),
+ num_classes=get_world_size()).to(text_x.dtype)
+ labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
+ torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
+ # [BxL2, WxBxL1]
+ labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
+
+ logit_scale = temperature.exp().clamp(max=100)
+
+ loss_img = soft_cross_entropy(logit_scale * logits_per_img, labels_per_img)
+ loss_text = soft_cross_entropy(logit_scale * logits_per_text, labels_per_text)
+
+ loss = 0.5 * (loss_img + loss_text)
+ return loss
+
+def vl_contrastive_loss(image_feat, text_feat, temperature=1):
+ # if image_id or text_id is None, it should be None across all GPUs
+ # image_feat = F.normalize(image_feat, dim=1)
+ # text_feat = F.normalize(text_feat, dim=1)
+ # handle normalization outside
+
+ # add the following 4 lines
+ image_feat = all_gather_grad(image_feat)
+ text_feat = all_gather_grad(text_feat)
+
+ logits = torch.matmul(image_feat, text_feat.t())
+ logit_scale = temperature.exp().clamp(max=100)
+
+ gt = torch.arange(logits.shape[0], device=logits.device)
+ loss1 = F.cross_entropy(logit_scale * logits, gt)
+ loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
+ return (loss1 + loss2) / 2 # scale it up by the number of GPUs
+
+
+def all_gather_pickle(data, device):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device)
+
+ # obtain Tensor size of each rank
+ local_size = torch.LongTensor([tensor.numel()])
+ size_list = [torch.LongTensor([0]) for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.ByteTensor(size=(max_size,)) )
+ if local_size != max_size:
+ padding = torch.ByteTensor(size=(max_size - local_size,))
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+def all_gather_arbitary_tensor(tensor):
+ if get_world_size() > 1:
+ device = tensor.device
+ tensor_batch = all_gather_pickle(tensor.cpu(), device)
+ tensor_batch = [x.to(device) for x in tensor_batch]
+ tensor_batch[torch.distributed.get_rank()] = tensor
+ tensor_batch = torch.cat(tensor_batch, dim=0)
+ else:
+ tensor_batch = tensor
+ return tensor_batch
+
+def ql_contrastive_loss(image_feat, text_feat, temperature=1):
+ # add the following 4 lines
+ image_feat = all_gather_arbitary_tensor(image_feat)
+ text_feat = all_gather_arbitary_tensor(text_feat)
+
+ logits = torch.matmul(image_feat, text_feat.t())
+ logit_scale = temperature.exp().clamp(max=100)
+
+ gt = torch.arange(logits.shape[0], device=logits.device)
+ loss1 = F.cross_entropy(logit_scale * logits, gt)
+ loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
+ return (loss1 + loss2) / 2 # scale it up by the number of GPUs
+
+def vl_similarity(image_feat, text_feat, temperature=1):
+ # Only support single GPU for now.
+ logits = torch.matmul(image_feat, text_feat.t())
+ logits = temperature.exp().clamp(max=100) * logits
+ return logits
+
+def ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1):
+ # add the following 4 lines
+ image_feat = all_gather_arbitary_tensor(image_feat)
+ text_feat = all_gather_arbitary_tensor(text_feat)
+
+ text_hash_batch = all_gather_pickle(text_hash, text_feat.device)
+ text_hash_all = torch.cat(text_hash_batch)
+
+ text_hash_all_unique = torch.unique(text_hash_all).tolist()
+ gt = torch.zeros((image_feat.shape[0], len(text_hash_all_unique)), device=text_feat.device)
+ text_hash_all = text_hash_all.tolist()
+ text_feat_unique = torch.stack([text_feat[text_hash_all.index(txt)] for txt in text_hash_all_unique])
+
+ for idx, txt in enumerate(text_hash_all):
+ gt[idx][text_hash_all_unique.index(txt)] = 1
+
+ logits = torch.matmul(image_feat, text_feat_unique.t())
+ logits = logits*temperature.exp().clamp(max=100)
+
+ loss_img = soft_cross_entropy(logits, gt)
+ loss_text = soft_cross_entropy(logits.t(), gt.t() / gt.t().sum(-1, keepdim=True))
+
+ loss = 0.7 * loss_img + 0.3 * loss_text
+ return loss
+
+def image_text_contrastive_loss_queue(image_feat_inp, text_feat_inp, lang_enc, training):
+ # add the following 4 lines
+ image_feat = all_gather_grad(image_feat_inp.contiguous())
+ text_feat = all_gather_grad(text_feat_inp.contiguous())
+
+ image_feat = image_feat / (image_feat.norm(dim=-1, keepdim=True) + 1e-7)
+ text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-7)
+
+ temperature = lang_enc.logit_scale
+ logits = torch.matmul(image_feat, text_feat.t())
+ logit_scale = temperature.exp().clamp(max=100)
+
+ gt = torch.arange(logits.shape[0], device=logits.device)
+ loss1 = F.cross_entropy(logit_scale * logits, gt)
+ loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
+
+ return (loss1 + loss2) / 2 # scale it up by the number of GPUs
\ No newline at end of file
diff --git a/modeling/language/misc.py b/modeling/language/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f1aa8c5cc1392f4fc4098845a43fdac3081d531
--- /dev/null
+++ b/modeling/language/misc.py
@@ -0,0 +1,66 @@
+import random
+
+import torch
+import nltk
+import numpy as np
+
+from utilities.constants import IMAGENET_DEFAULT_TEMPLATES
+
+nltk.download('punkt', quiet=True)
+nltk.download('averaged_perceptron_tagger', quiet=True)
+
+def get_tag(tokenized, tags):
+ if not isinstance(tags, (list, tuple)):
+ tags = [tags]
+ ret = []
+ for (word, pos) in nltk.pos_tag(tokenized):
+ for tag in tags:
+ if pos == tag:
+ ret.append(word)
+ return ret
+
+def get_noun_phrase(tokenized):
+ # Taken from Su Nam Kim Paper...
+ grammar = r"""
+ NBAR:
+ {*} # Nouns and Adjectives, terminated with Nouns
+
+ NP:
+ {}
+ {} # Above, connected with in/of/etc...
+ """
+ chunker = nltk.RegexpParser(grammar)
+
+ chunked = chunker.parse(nltk.pos_tag(tokenized))
+ continuous_chunk = []
+ current_chunk = []
+
+ for subtree in chunked:
+ if isinstance(subtree, nltk.Tree):
+ current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
+ elif current_chunk:
+ named_entity = ' '.join(current_chunk)
+ if named_entity not in continuous_chunk:
+ continuous_chunk.append(named_entity)
+ current_chunk = []
+ else:
+ continue
+
+ return continuous_chunk
+
+def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):
+ tokenized = nltk.word_tokenize(text)
+
+ if random.random() >= phrase_prob:
+ nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP'])
+ else:
+ nouns = get_noun_phrase(tokenized)
+
+
+ prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns]
+
+ if append_text:
+ prompt_texts += [text]
+ nouns += [text]
+
+ return prompt_texts, nouns
\ No newline at end of file
diff --git a/modeling/language/vlpencoder.py b/modeling/language/vlpencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d116f2f900f66b8baf886762597bc1c9395dbe3f
--- /dev/null
+++ b/modeling/language/vlpencoder.py
@@ -0,0 +1,214 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from timm.models.layers import trunc_normal_
+
+from .build import register_model
+from ..utils import configurable
+from .LangEncoder import build_tokenizer, build_lang_encoder
+from utilities.prompt_engineering import prompt_engineering, get_prompt_templates
+
+from transformers import AutoTokenizer, AutoModel
+
+class LanguageEncoder(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ tokenizer,
+ tokenizer_type,
+ lang_encoder,
+ lang_projection,
+ max_token_num,
+ queue_operator,
+ ):
+ super().__init__()
+ # seg
+ self.tokenizer = tokenizer
+ self.tokenizer_type = tokenizer_type
+ self.lang_encoder = lang_encoder
+ self.lang_proj = lang_projection
+ self.max_token_num = max_token_num
+ self.logit_scale = nn.Parameter(torch.ones([]))
+
+ self.device = lang_projection.device
+ # captioning & retrieval
+ for key, value in queue_operator.items():
+ self.register_buffer(key, value)
+
+ self.biomed_encoder = AutoModel.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
+ self.biomed_encoder.to(self.device)
+ @classmethod
+ def from_config(cls, cfg):
+ # build up text encoder for seg
+ tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
+ tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
+ lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
+ max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
+
+ dim_lang = cfg['MODEL']['TEXT']['WIDTH']
+ dim_projection = cfg['MODEL']['DIM_PROJ']
+ lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
+ trunc_normal_(lang_projection, std=.02)
+
+ # tested not working better
+ queue_operator = {}
+
+ return {
+ "tokenizer": tokenizer,
+ "tokenizer_type": tokenizer_type,
+ "lang_encoder": lang_encoder,
+ "lang_projection": lang_projection,
+ "max_token_num": max_token_num,
+ "queue_operator": queue_operator,
+ }
+
+ def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True, store_buffer=None):
+ if not is_eval:
+ if prompt:
+ # randomly sample one template
+ arbitary_concepts = [
+ prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
+ for label in range(len(class_names))
+ ]
+ if add_bgd:
+ arbitary_concepts.append("A background in coco.")
+ else:
+ arbitary_concepts = class_names
+
+ input_ids = []
+ attention_masks = []
+ for txt in arbitary_concepts:
+ tokens = self.tokenizer(
+ txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
+ )
+ tokens['input_ids'].squeeze_()
+ tokens['attention_mask'].squeeze_()
+
+ input_ids.append(tokens['input_ids'])
+ attention_masks.append(tokens['attention_mask'])
+
+ arbitary_tokens = torch.stack(input_ids)
+ arbitary_attention_masks = torch.stack(attention_masks)
+
+ text_emb = self.forward_language((arbitary_tokens , arbitary_attention_masks ), norm=norm)
+ setattr(self, '{}_text_embeddings'.format(name), text_emb)
+ else:
+ with torch.no_grad():
+ def extract_mean_emb(txts):
+ tokens = self.tokenizer(
+ txts, padding='max_length', truncation=True,
+ max_length=self.max_token_num, return_tensors='pt'
+ )
+ # Move tokens to correct device
+ tokens = {k: v.to(self.device) for k, v in tokens.items()}
+ clss_embedding = self.forward_language(
+ (tokens['input_ids'], tokens['attention_mask']),
+ norm=norm
+ )
+ clss_embedding = clss_embedding.mean(dim=0)
+ clss_embedding /= clss_embedding.norm()
+ return clss_embedding
+
+ templates = get_prompt_templates()
+ clss_embeddings = []
+ if prompt:
+ for clss in class_names:
+ txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff',''))
+ for template in templates]
+ clss_embeddings.append(extract_mean_emb(txts))
+ else:
+ for clss in class_names:
+ clss_embeddings.append(extract_mean_emb([clss]))
+
+ if add_bgd:
+ txts = ["A background in coco."]
+ clss_embeddings.append(extract_mean_emb(txts))
+
+ text_emb = torch.stack(clss_embeddings, dim=0)
+ setattr(self, '{}_text_embeddings'.format(name), text_emb)
+
+ def reset_text_embeddings(self, name='default'):
+ pass
+
+ def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):
+ if not token:
+ tokens = self.tokenizer(
+ txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
+ )
+ tokens = {key: value for key, value in tokens.items()}
+ else:
+ tokens = txts
+ token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)
+ ret = {"tokens": tokens,
+ "token_emb": token_emb,
+ "class_emb": class_emb,}
+ setattr(self, '{}_token_embeddings'.format(name), ret)
+ return ret
+
+ def forward_language(self, texts, norm=True):
+ if self.tokenizer_type == 'biomed-clip':
+ with torch.no_grad(): # Disable gradient calculation
+ outputs = self.biomed_encoder(*texts)
+ # Extract the last hidden state
+ x = outputs['last_hidden_state']
+ x = x[:, 0] # Get the [CLS] token's embeddings for all examples
+ else:
+ x = self.lang_encoder(*texts)
+ x = x['last_hidden_state']
+
+ if self.tokenizer_type == 'clip':
+ x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]
+ else:
+ x = x[:, 0]
+
+ x = x @ self.lang_proj
+ if norm:
+ x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)
+ return x
+
+ def forward_language_token(self, texts, norm=False):
+ if self.tokenizer_type == 'biomed-clip':
+ with torch.no_grad(): # Disable gradient calculation
+ outputs = self.biomed_encoder(*texts)
+ # Extract the last hidden state
+ token_x = outputs['last_hidden_state']
+ class_x = token_x[:, 0] # Get the [CLS] token's embeddings for all examples
+ else:
+ x = self.lang_encoder(*texts)
+ token_x = x['last_hidden_state']
+
+ if self.tokenizer_type == 'clip':
+ class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)]
+ else:
+ class_x = token_x[:, 0]
+
+ class_x = class_x @ self.lang_proj
+ token_x = token_x @ self.lang_proj
+
+ if norm:
+ class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7)
+ token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7)
+
+ return token_x, class_x
+
+ def compute_similarity(self, v_emb, name='default', fake=False):
+ if fake:
+ return None
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ t_emb = getattr(self, '{}_text_embeddings'.format(name))
+ output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
+ return output
+
+
+@register_model
+def get_language_model(cfg, **kwargs):
+ return LanguageEncoder(cfg)
\ No newline at end of file
diff --git a/modeling/modules/__init__.py b/modeling/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..83531335a4b0b1cf3c26e99ba117b355b3baff2c
--- /dev/null
+++ b/modeling/modules/__init__.py
@@ -0,0 +1,6 @@
+from .point_features import *
+from .position_encoding import *
+from .postprocessing import *
+from .attention import *
+from .criterion import *
+from .matcher import *
\ No newline at end of file
diff --git a/modeling/modules/attention.py b/modeling/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..b11c793aaf74b941b2388bd74f35cf0376f7ecff
--- /dev/null
+++ b/modeling/modules/attention.py
@@ -0,0 +1,487 @@
+import warnings
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
+from torch.nn.parameter import Parameter
+from torch.overrides import has_torch_function, handle_torch_function
+from torch.nn.functional import pad, linear, softmax, dropout
+
+
+def multi_head_attention_forward(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Tensor,
+ in_proj_bias: Tensor,
+ bias_k: Optional[Tensor],
+ bias_v: Optional[Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Tensor,
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Tensor] = None,
+ k_proj_weight: Optional[Tensor] = None,
+ v_proj_weight: Optional[Tensor] = None,
+ static_k: Optional[Tensor] = None,
+ static_v: Optional[Tensor] = None,
+) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ embed_dim_to_check: total dimension of the model.
+ num_heads: parallel attention heads.
+ in_proj_weight, in_proj_bias: input projection weight and bias.
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ dropout_p: probability of an element to be zeroed.
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
+ training: apply dropout if is ``True``.
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. This is an binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
+ and value in different forms. If false, in_proj_weight will be used, which is
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+ static_k, static_v: static key and value used for attention operators.
+
+
+ Shape:
+ Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+ will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+
+ Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
+ if has_torch_function(tens_ops):
+ return handle_torch_function(
+ multi_head_attention_forward,
+ tens_ops,
+ query,
+ key,
+ value,
+ embed_dim_to_check,
+ num_heads,
+ in_proj_weight,
+ in_proj_bias,
+ bias_k,
+ bias_v,
+ add_zero_attn,
+ dropout_p,
+ out_proj_weight,
+ out_proj_bias,
+ training=training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ use_separate_proj_weight=use_separate_proj_weight,
+ q_proj_weight=q_proj_weight,
+ k_proj_weight=k_proj_weight,
+ v_proj_weight=v_proj_weight,
+ static_k=static_k,
+ static_v=static_v,
+ )
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == embed_dim_to_check
+ # allow MHA to have different sizes for the feature dimension
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+ head_dim = embed_dim // num_heads
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+ scaling = float(head_dim) ** -0.5
+
+ if not use_separate_proj_weight:
+ if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
+ # self-attention
+ q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
+
+ elif key is value or torch.equal(key, value):
+ # encoder-decoder attention
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = linear(query, _w, _b)
+
+ if key is None:
+ assert value is None
+ k = None
+ v = None
+ else:
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ k, v = linear(key, _w, _b).chunk(2, dim=-1)
+
+ else:
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = linear(query, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = embed_dim * 2
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ k = linear(key, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim * 2
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ v = linear(value, _w, _b)
+ else:
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
+ len1, len2 = q_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == query.size(-1)
+
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
+ len1, len2 = k_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == key.size(-1)
+
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
+ len1, len2 = v_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == value.size(-1)
+
+ if in_proj_bias is not None:
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
+ else:
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias)
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias)
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias)
+ q = q * scaling
+
+ if attn_mask is not None:
+ assert (
+ attn_mask.dtype == torch.float32
+ or attn_mask.dtype == torch.float64
+ or attn_mask.dtype == torch.float16
+ or attn_mask.dtype == torch.uint8
+ or attn_mask.dtype == torch.bool
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
+ if attn_mask.dtype == torch.uint8:
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+ attn_mask = attn_mask.to(torch.bool)
+
+ if attn_mask.dim() == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
+ elif attn_mask.dim() == 3:
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
+ else:
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
+ # attn_mask's dim is 3 now.
+
+ # convert ByteTensor key_padding_mask to bool
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+ warnings.warn(
+ "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
+ )
+ key_padding_mask = key_padding_mask.to(torch.bool)
+
+ if bias_k is not None and bias_v is not None:
+ if static_k is None and static_v is None:
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+ else:
+ assert static_k is None, "bias cannot be added to static key."
+ assert static_v is None, "bias cannot be added to static value."
+ else:
+ assert bias_k is None
+ assert bias_v is None
+
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+ if static_k is not None:
+ assert static_k.size(0) == bsz * num_heads
+ assert static_k.size(2) == head_dim
+ k = static_k
+
+ if static_v is not None:
+ assert static_v.size(0) == bsz * num_heads
+ assert static_v.size(2) == head_dim
+ v = static_v
+
+ src_len = k.size(1)
+
+ if key_padding_mask is not None:
+ # assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+ else:
+ attn_output_weights += attn_mask
+
+ if key_padding_mask is not None:
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ float("-inf"),
+ )
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
+
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+ if need_weights:
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
+ else:
+ return attn_output, None
+
+
+# This class exists solely for Transformer; it has an annotation stating
+# that bias is never None, which appeases TorchScript
+class _LinearWithBias(nn.Linear):
+ bias: Tensor # type: ignore
+
+ def __init__(self, in_features: int, out_features: int) -> None:
+ super().__init__(in_features, out_features, bias=True) # type: ignore
+
+
+class MultiheadAttention(nn.Module):
+ r"""Allows the model to jointly attend to information
+ from different representation subspaces.
+ See `Attention Is All You Need `_
+
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
+
+ Args:
+ embed_dim: total dimension of the model.
+ num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+ bias: add bias as module parameter. Default: True.
+ add_bias_kv: add bias to the key and value sequences at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ kdim: total number of features in key. Default: None.
+ vdim: total number of features in value. Default: None.
+
+ Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
+ to :attr:`embed_dim` such that query, key, and value have the same
+ number of features.
+
+ Examples::
+
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+ """
+ bias_k: Optional[torch.Tensor]
+ bias_v: Optional[torch.Tensor]
+
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+ if self._qkv_same_embed_dim is False:
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+ self.register_parameter('in_proj_weight', None)
+ else:
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
+ self.register_parameter('q_proj_weight', None)
+ self.register_parameter('k_proj_weight', None)
+ self.register_parameter('v_proj_weight', None)
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+ self.out_proj = _LinearWithBias(embed_dim, embed_dim)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ if self._qkv_same_embed_dim:
+ xavier_uniform_(self.in_proj_weight)
+ else:
+ xavier_uniform_(self.q_proj_weight)
+ xavier_uniform_(self.k_proj_weight)
+ xavier_uniform_(self.v_proj_weight)
+
+ if self.in_proj_bias is not None:
+ constant_(self.in_proj_bias, 0.)
+ constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ xavier_normal_(self.bias_v)
+
+ def __setstate__(self, state):
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
+ if '_qkv_same_embed_dim' not in state:
+ state['_qkv_same_embed_dim'] = True
+
+ super(MultiheadAttention, self).__setstate__(state)
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. When given a binary mask and a value is True,
+ the corresponding value on the attention layer will be ignored. When given
+ a byte mask and a value is non-zero, the corresponding value on the attention
+ layer will be ignored
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+ Shapes for inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
+ source sequence length.
+
+ If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
+ length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend
+ the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+
+ Shapes for outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+ if not self._qkv_same_embed_dim:
+ return multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight)
+ else:
+ return multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask)
\ No newline at end of file
diff --git a/modeling/modules/criterion.py b/modeling/modules/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd9644c7093643447e4f3ba5eb1ee9a6c804a331
--- /dev/null
+++ b/modeling/modules/criterion.py
@@ -0,0 +1,874 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+"""
+MaskFormer criterion.
+"""
+import logging
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from detectron2.utils.comm import get_world_size
+from timm.loss import SoftTargetCrossEntropy
+from .point_features import (
+ get_uncertain_point_coords_with_randomness,
+ point_sample,
+)
+
+from ..language.loss import ql_multi_contrastive_loss, image_text_contrastive_loss_queue, vl_similarity, all_gather_grad
+from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list, _max_by_axis
+from ..utils import box_ops
+
+# from image2html.visualizer import VL
+
+
+def dice_loss(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ num_masks: float,
+ ):
+ """
+ Compute the DICE loss, similar to generalized IOU for masks
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ """
+ inputs = inputs.sigmoid()
+ inputs = inputs.flatten(1)
+ numerator = 2 * (inputs * targets).sum(-1)
+ denominator = inputs.sum(-1) + targets.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss.sum() / num_masks
+
+
+dice_loss_jit = torch.jit.script(
+ dice_loss
+) # type: torch.jit.ScriptModule
+
+
+def sigmoid_ce_loss(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ num_masks: float,
+ ):
+ """
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ Returns:
+ Loss tensor
+ """
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+
+ return loss.mean(1).sum() / num_masks
+
+
+sigmoid_ce_loss_jit = torch.jit.script(
+ sigmoid_ce_loss
+) # type: torch.jit.ScriptModule
+
+
+def calculate_uncertainty(logits):
+ """
+ We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
+ foreground class in `classes`.
+ Args:
+ logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
+ class-agnostic, where R is the total number of predicted masks in all images and C is
+ the number of foreground classes. The values are logits.
+ Returns:
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
+ the most uncertain locations having the highest uncertainty score.
+ """
+ assert logits.shape[1] == 1
+ gt_class_logits = logits.clone()
+ return -(torch.abs(gt_class_logits))
+
+
+class SetCriterion(nn.Module):
+ """This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, top_x_layers, losses,
+ num_points, oversample_ratio, importance_sample_ratio, grounding_weight):
+ """Create the criterion.
+ Parameters:
+ num_classes: number of object categories, omitting the special no-object category
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.eos_coef = eos_coef
+ self.top_x_layers = top_x_layers
+ self.losses = losses
+ empty_weight = torch.ones(self.num_classes + 1)
+ empty_weight[-1] = self.eos_coef
+ self.register_buffer("empty_weight", empty_weight)
+
+ # pointwise mask loss parameters
+ self.num_points = num_points
+ self.oversample_ratio = oversample_ratio
+ self.importance_sample_ratio = importance_sample_ratio
+
+ # grounding
+ self.grounding_weight = grounding_weight
+
+ def loss_labels(self, outputs, targets, indices, num_masks, layer_id, extra):
+ """Classification loss (NLL)
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ """
+ if layer_id > self.top_x_layers['mask']:
+ return {"loss_mask_ce_0": 0}
+
+ if indices is None or len(targets) == 0:
+ loss_ce = outputs['pred_logits'].sum() * 0.0
+ losses = {"loss_mask_ce_0": loss_ce}
+ return losses
+
+ assert "pred_logits" in outputs
+ src_logits = outputs["pred_logits"].type(self.empty_weight.dtype)
+
+ idx = self._get_src_permutation_idx(indices)
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+ target_classes = torch.full(
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+ )
+ target_classes[idx] = target_classes_o
+
+ if src_logits.shape[2] == self.num_classes+1:
+ empty_weight = torch.ones(self.num_classes + 1).to(src_logits.device).type(self.empty_weight.dtype)
+ empty_weight[-1] = self.eos_coef
+ else:
+ empty_weight = torch.ones(self.num_classes + 1000 + 1).to(src_logits.device).type(self.empty_weight.dtype)
+ empty_weight[self.num_classes] = self.eos_coef
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes)
+ losses = {"loss_mask_ce_0": loss_ce}
+ return losses
+
+ def loss_labels_openimage(self, outputs, targets, indices, num_masks, layer_id, extra):
+ """Classification loss (NLL)
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ """
+ if layer_id > self.top_x_layers['mask']:
+ return {"loss_openimage_ce_0": 0}
+
+ assert "pred_captions" in outputs
+
+ if indices is None or len(targets) == 0 or (len(targets) > 0 and len(targets[0]['labels']) == 0):
+ loss_ce = outputs['pred_captions'].sum() * 0.0
+ losses = {"loss_openimage_ce_0": loss_ce}
+ return losses
+
+ # compute i2t loss
+ loss_openimage_ce = 0
+ losses = {}
+ for b in range(len(indices)):
+ pred_logit = outputs["pred_logits"][b][indices[b][0]]
+ gt_logit = torch.zeros_like(pred_logit)
+ select_idx = torch.stack((torch.arange(len(indices[b][1])), indices[b][1])).tolist()
+ gt_logit[select_idx] = 1
+ loss_openimage_ce += torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1).mean()
+ loss_openimage_ce = loss_openimage_ce / len(indices)
+ losses.update({"loss_openimage_ce_0": loss_openimage_ce})
+ return losses
+
+ def loss_itc(self, outputs, targets, indices, num_masks, layer_id, extra):
+ if layer_id >= self.top_x_layers['retrieval']:
+ return {"loss_retrieval_decoder_0": 0}
+ t_emb = torch.cat([x['caption_proj'] for x in targets], dim=0)
+ v_emb = outputs['pred_captions'][:,-1]
+ loss_contrast = image_text_contrastive_loss_queue(v_emb, t_emb, extra['lang_encoder'], extra['training'])
+
+ # compute query-token contrastive loss
+ ttk_emb = torch.cat([x['caption_tokens'] for x in targets], dim=0)
+ ttk_mask = torch.cat([x['caption_mask'] for x in targets], dim=0).float()
+ ttk_mask = ttk_mask * torch.cumsum(ttk_mask, dim=1)
+ vtk_emb = outputs['pred_captions'][:,:-1]
+ keep = torch.cat([x['caption_mask'] for x in targets], dim=0).bool()
+
+ ttk_emb = ttk_emb / (ttk_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ vtk_emb = vtk_emb / (vtk_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ logit_scale = extra['lang_encoder'].logit_scale.exp().clamp(max=100)
+
+ # prepare gt
+ gt = (torch.eye(vtk_emb.shape[0]).type_as(ttk_mask).unsqueeze(-1) * ttk_mask.unsqueeze(0).repeat(vtk_emb.shape[0], 1, 1))[:,keep].flatten(1)
+ gt = gt / (gt.sum(1, keepdim=True) + 1e-7)
+ # compute i2t loss
+ logits = logit_scale * (vtk_emb @ ttk_emb[keep].transpose(0, 1)).mean(1)
+ loss_contrast_fine_vt = SoftTargetCrossEntropy()(logits, gt)
+ # loss_contrast_fine = loss_contrast_fine_vt # i2t only
+
+ # compute t2i loss
+ bs, nq, _ = vtk_emb.shape
+ logits = logit_scale * (ttk_emb @ vtk_emb.flatten(0,1).transpose(0, 1)).reshape(bs,-1,bs,nq).mean(dim=-1)[keep,:]
+ loss_contrast_fine_tv = SoftTargetCrossEntropy()(logits, gt.t())
+ # compute loss
+ loss_contrast_fine = (loss_contrast_fine_vt * 0.7 + loss_contrast_fine_tv * 0.3)
+
+ losses = {"loss_retrieval_decoder_0": loss_contrast + loss_contrast_fine * 0.5}
+ return losses
+
+ def loss_captionings(self, outputs, targets, indices, num_masks, layer_id, extra):
+ if layer_id >= self.top_x_layers['captioning']:
+ return {"loss_captioning_0": 0}
+
+ pred_captions_gen = outputs['pred_captionings'][:, :-1]
+ token_embs = extra['token_embedding'].weight
+ # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
+ # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
+ pred_captions_gen = pred_captions_gen @ token_embs.t()
+
+ # temperature = extra['lang_encoder'].logit_scale
+ # logit_scale = temperature.exp().clamp(max=100)
+
+ target_captions_gen = torch.cat([target['caption_tokenids'] for target in targets], 0)[:, 1:]
+ target_captions_gen_mask = torch.cat([target['caption_mask'] for target in targets], 0)[:, 1:]
+
+ # loss_caption = F.cross_entropy(pred_captions_gen.transpose(1,2) * logit_scale, target_captions_gen, reduction='none')
+ loss_caption = F.cross_entropy(pred_captions_gen.transpose(1,2), target_captions_gen, reduction='none')
+ loss_caption = (loss_caption * target_captions_gen_mask).sum() / (target_captions_gen_mask.sum() + 1)
+ losses = {"loss_captioning_0": loss_caption}
+ return losses
+
+ def loss_captions(self, outputs, targets, indices, num_masks, layer_id, extra):
+ if layer_id >= self.top_x_layers['caption']:
+ return {"loss_caption_0": 0}
+ matched_tokens = [m[0] for m in indices]
+ t_emb_class = torch.cat([extra['class_embeddings'][targets[bs]['labels'][m[1]]] for bs, m in enumerate(indices)])
+ t_hash_class = torch.cat([torch.tensor(targets[bs]['labels_hash'])[m[1]] for bs, m in enumerate(indices)])
+
+ # pred_captions denotes all unmatched object queries.
+ unmatched_pred_captions = []
+ matched_pred_captions = []
+ for idx, m in enumerate(matched_tokens):
+ unmatched_masks = torch.ones(outputs['pred_captions'].shape[1:-1]).bool()
+ matched_masks = torch.zeros(outputs['pred_captions'].shape[1:-1]).bool()
+
+ unmatched_masks[m] = False
+ matched_masks[m] = True
+
+ unmatched_pred_captions.append(outputs['pred_captions'][idx][unmatched_masks])
+ matched_pred_captions.append(outputs['pred_captions'][idx][matched_masks])
+
+ outputs['unmatched_pred_captions'] = unmatched_pred_captions
+ v_emb_class = torch.cat(matched_pred_captions)
+ v_emb_class = v_emb_class / (v_emb_class.norm(dim=-1, keepdim=True) + 1e-7)
+
+ indices = self.matcher(outputs, targets, mode="caption_womask", extra={'temperature':extra['lang_logit']})
+ src_idx = self._get_src_permutation_idx(indices)
+
+ t_emb = torch.cat([t['captions'][indices[bs][1]] for bs,t in enumerate(targets)])
+ t_hash = torch.cat([torch.tensor(t['captions_hash'])[indices[bs][1]] for bs,t in enumerate(targets)])
+
+ unmatched_pred_captions, _ = nested_tensor_from_tensor_list(unmatched_pred_captions).decompose()
+ v_emb = unmatched_pred_captions[src_idx]
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ loss_contrast = ql_multi_contrastive_loss(torch.cat((v_emb, v_emb_class)), torch.cat((t_emb, t_emb_class)), torch.cat((t_hash, t_hash_class)), temperature=extra['lang_logit'])
+ losses = {"loss_caption_0": loss_contrast}
+
+ return losses
+
+ def loss_masks(self, outputs, targets, indices, num_masks, layer_id, extra):
+ """Compute the losses related to the masks: the focal loss and the dice loss.
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+ """
+ if layer_id >= self.top_x_layers['mask']:
+ return {"loss_mask_bce_0": 0, "loss_mask_dice_0": 0}
+
+ assert "pred_masks" in outputs
+ if indices is None or len(targets) == 0:
+ loss = outputs['pred_masks'].sum() * 0.0
+ losses = {"loss_mask_bce_0": loss, "loss_mask_dice_0": loss}
+ return losses
+
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+ src_masks = outputs["pred_masks"]
+ src_masks = src_masks[src_idx]
+ masks = [t["masks"] for t in targets]
+ # TODO use valid to mask invalid areas due to padding in loss
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+ target_masks = target_masks.to(src_masks)
+ target_masks = target_masks[tgt_idx]
+ # No need to upsample predictions as we are using normalized coordinates :)
+ # N x 1 x H x W
+ src_masks = src_masks[:, None]
+ target_masks = target_masks[:, None]
+
+ with torch.no_grad():
+ # sample point_coords
+ point_coords = get_uncertain_point_coords_with_randomness(
+ src_masks,
+ lambda logits: calculate_uncertainty(logits),
+ self.num_points,
+ self.oversample_ratio,
+ self.importance_sample_ratio,
+ ).type(src_masks.dtype)
+ # get gt labels
+ point_labels = point_sample(
+ target_masks,
+ point_coords,
+ align_corners=False,
+ ).squeeze(1)
+
+ point_logits = point_sample(
+ src_masks,
+ point_coords,
+ align_corners=False,
+ ).squeeze(1)
+
+ losses = {
+ "loss_mask_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
+ "loss_mask_dice_0": dice_loss_jit(point_logits, point_labels, num_masks),
+ }
+
+ del src_masks
+ del target_masks
+ return losses
+
+ def loss_groundings(self, outputs, targets, indices, num_masks, layer_id, extra):
+ """Compute the losses related to the masks: the focal loss and the dice loss.
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+ """
+ assert "pred_gmasks" in outputs
+ assert "pred_gtexts" in outputs
+
+ if layer_id >= self.top_x_layers['grounding']:
+ return {"loss_grounding_bce_0": 0, "loss_grounding_dice_0": 0, "loss_grounding_ce_0": 0}
+
+ masks = [t["grounding_masks"] for t in targets]
+ if indices is None or None in masks:
+ loss = outputs['pred_gmasks'].sum() * 0.0
+ return {"loss_grounding_bce_0": loss, "loss_grounding_dice_0": loss, "loss_grounding_ce_0": loss}
+
+ pred_logits = []
+ for b in range(len(indices)):
+ t_emb = targets[b]['grounding_class_embs']
+ v_emb = outputs["pred_gtexts"][b]
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ out_prob = vl_similarity(v_emb, t_emb, temperature=extra['lang_logit'])
+ pred_logits += [out_prob]
+ outputs['pred_logits'] = pred_logits
+
+ indices = self.matcher(outputs, targets, mode='grounding', extra={'temperature':extra['lang_logit']})
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+
+ src_masks = outputs["pred_gmasks"]
+ src_masks = src_masks[src_idx]
+ # TODO use valid to mask invalid areas due to padding in loss
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+ target_masks = target_masks.to(src_masks)
+ target_masks = target_masks[tgt_idx]
+ # No need to upsample predictions as we are using normalized coordinates :)
+ # N x 1 x H x W
+ src_masks = src_masks[:, None]
+ target_masks = target_masks[:, None]
+
+ with torch.no_grad():
+ # sample point_coords
+ point_coords = get_uncertain_point_coords_with_randomness(
+ src_masks,
+ lambda logits: calculate_uncertainty(logits),
+ self.num_points,
+ self.oversample_ratio,
+ self.importance_sample_ratio,
+ ).type(src_masks.dtype)
+ # get gt labels
+ point_labels = point_sample(
+ target_masks,
+ point_coords,
+ align_corners=False,
+ ).squeeze(1)
+
+ point_logits = point_sample(
+ src_masks,
+ point_coords,
+ align_corners=False,
+ ).squeeze(1)
+
+ losses = {
+ "loss_grounding_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, len(src_masks)),
+ "loss_grounding_dice_0": dice_loss_jit(point_logits, point_labels, len(src_masks)),
+ }
+
+ # compute query-token contrastive loss
+ # ttk_emb = torch.cat([x['caption_tokens'] for x in targets], dim=0)
+ # ttk_mask = torch.cat([x['caption_mask'] for x in targets], dim=0).float()
+ # ttk_mask = ttk_mask * torch.cumsum(ttk_mask, dim=1)
+ # vtk_emb = outputs['pred_captions'][:,:-1]
+ # keep = torch.cat([x['caption_mask'] for x in targets], dim=0).bool()
+
+ # ttk_emb = ttk_emb / (ttk_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ # vtk_emb = vtk_emb / (vtk_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ # logit_scale = extra['lang_encoder'].logit_scale.exp().clamp(max=100)
+
+ # # prepare gt
+ # gt = (torch.eye(vtk_emb.shape[0]).type_as(ttk_mask).unsqueeze(-1) * ttk_mask.unsqueeze(0).repeat(vtk_emb.shape[0], 1, 1))[:,keep].flatten(1)
+ # gt = gt / (gt.sum(1, keepdim=True) + 1e-7)
+ # # compute i2t loss
+ # logits = logit_scale * (vtk_emb @ ttk_emb[keep].transpose(0, 1)).mean(1)
+ # loss_contrast_fine_vt = SoftTargetCrossEntropy()(logits, gt)
+ # # loss_contrast_fine = loss_contrast_fine_vt # i2t only
+
+ # # compute t2i loss
+ # bs, nq, _ = vtk_emb.shape
+ # logits = logit_scale * (ttk_emb @ vtk_emb.flatten(0,1).transpose(0, 1)).reshape(bs,-1,bs,nq).mean(dim=-1)[keep,:]
+ # loss_contrast_fine_tv = SoftTargetCrossEntropy()(logits, gt.t())
+ # # compute loss
+ # loss_contrast_fine = (loss_contrast_fine_vt * 0.7 + loss_contrast_fine_tv * 0.3)
+
+ # compute t2i loss
+ loss_grd_ce = 0
+ for b in range(len(indices)):
+ task = targets[b]['grounding_task']
+ pred_logit = outputs["pred_logits"][b]
+ gt_logit = torch.zeros_like(pred_logit)
+ select_idx = torch.stack((indices[b][0], indices[b][1])).tolist()
+ gt_logit[select_idx] = 1
+ t_hash = torch.tensor(targets[b]['grounding_hash'], device=gt_logit.device)
+ hash_table = torch.zeros((len(t_hash), len(t_hash)), device=gt_logit.device)
+ for idx in range(0, len(hash_table)):
+ hash_table[idx][t_hash==t_hash[idx]] = 1
+ hash_table = hash_table / hash_table.sum(-1, keepdim=True)
+ gt_logit = gt_logit @ hash_table
+ loss_grd_ce += self.grounding_weight[task]*torch.sum(-gt_logit.t() * F.log_softmax(pred_logit.t(), dim=-1), dim=-1).mean()
+ loss_grd_ce = loss_grd_ce / len(indices)
+ losses.update({"loss_grounding_ce_0": loss_grd_ce})
+ del src_masks
+ del target_masks
+ return losses
+
+ def loss_spatials(self, outputs, targets, indices, num_masks, layer_id, extra):
+ """Compute the losses related to the masks: the focal loss and the dice loss.
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+ """
+ assert "pred_smasks" in outputs
+ assert "pred_smaskembs" in outputs
+
+ if layer_id >= self.top_x_layers['spatial']:
+ loss = outputs['pred_smasks'].sum() * 0.0
+ loss_grd_ce = outputs["pred_smasks"].sum() * 0.0
+ return {"loss_spatial_bce_0": loss, "loss_spatial_dice_0": loss, "loss_spatial_ce_0": loss_grd_ce}
+
+ gt_masks = [x['gt_spatial_masks'] for x in targets]
+ # compute a keep index with batch size to avoid empty gt_masks
+ stack_gt_mask = torch.cat(gt_masks)
+ bs,_,_ = stack_gt_mask.shape
+ stack_gt_mask = stack_gt_mask.view(bs,-1).sum(dim=-1)
+ keep = stack_gt_mask > 0 # only keep sample contain positive mask
+
+ if keep.sum() == 0:
+ loss = outputs['pred_smasks'].sum() * 0.0
+ loss_grd_ce = outputs["pred_smasks"].sum() * 0.0
+ return {"loss_spatial_bce_0": loss, "loss_spatial_dice_0": loss, "loss_spatial_ce_0": loss_grd_ce}
+
+ # mask embedding logits
+ v_emb = outputs["pred_smaskembs"] # [bs, nq, 512]
+
+ # pos mask
+ s_emb = outputs["pred_pspatials"] # [bs, ns, 512]
+ pred_logits = v_emb @ s_emb.transpose(1,2)
+ outputs['pred_pos_logits'] = pred_logits # [bs, nq, 1]
+ indices = self.matcher(outputs, targets, mode='spatial', extra={})
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+
+ # pos class loss
+ pred_logit = torch.cat([o[:len(t['gt_spatial_masks'])] for o,t in zip(outputs["pred_pos_logits"].transpose(1,2), targets)])
+ gt_logit = torch.zeros_like(pred_logit)
+ gt_logit = gt_logit[keep]
+ _src_idx = [torch.arange(keep.sum(), device=src_idx[0].device), src_idx[1][keep.cpu()]]
+ gt_logit[_src_idx] = 1
+ pred_logit = pred_logit[keep]
+ loss_spa_ce_pos = torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1).mean()
+
+ # neg mask
+ # s_emb = outputs["pred_nspatials"] # [bs, ns, 512]
+ # neg_mask = (s_emb.sum(dim=list(range(1, len(s_emb.shape)))) != 0).float()[keep]
+ # pred_logits = v_emb @ s_emb.transpose(1,2)
+ # outputs['pred_neg_logits'] = pred_logits # [bs, nq, 1]
+ # indices = self.matcher(outputs, targets, mode='spatial_pn', extra=extra)
+ # src_idx = self._get_src_permutation_idx(indices)
+ # tgt_idx = self._get_tgt_permutation_idx(indices)
+ # src_masks_neg = outputs["pred_smasks"][src_idx][keep]
+ # src_masks_neg = src_masks_neg*(neg_mask[:,None,None])
+ # src_masks_neg = src_masks_neg.clip(0) * (-1)
+
+ # neg class loss
+ # pred_logit = outputs["pred_neg_logits"]
+ # gt_logit = torch.zeros_like(pred_logit)
+ # gt_logit[src_idx] = 1
+ # bs,_,ns = pred_logit[keep].shape
+ # pred_logit = pred_logit[keep].transpose(1,2).view(bs*ns,-1)
+ # gt_logit = gt_logit[keep].transpose(1,2).view(bs*ns,-1)
+ # loss_spa_ce_neg = (torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1)*neg_mask).sum() / (neg_mask.sum()+1e-6)
+
+ # recompute a keep index with matched tgt
+ stack_gt_mask = nn.utils.rnn.pad_sequence(gt_masks, padding_value=-1).transpose(0,1)[tgt_idx]
+ bs,_,_ = stack_gt_mask.shape
+ target_masks = stack_gt_mask
+ stack_gt_mask = stack_gt_mask.view(bs,-1).sum(dim=-1)
+ keep = stack_gt_mask > 0 # only keep sample contain positive mask
+ src_masks_pos = outputs["pred_smasks"][src_idx][keep]
+
+ # TODO use valid to mask invalid areas due to padding in loss
+ target_masks = target_masks.to(src_masks_pos)
+ target_masks = target_masks[keep]
+
+ # mul = extra['spatial_query_mode'][keep]
+ # src_masks_cur = src_masks_cur.clip(0) * mul[:,None,None]
+ # src_masks_cur = src_masks_cur
+
+ # if neg_mask[0] == 1:
+ # import cv2
+ # print(src_masks_pos.shape)
+ # print(src_masks_neg.shape)
+ # print(target_masks.shape)
+ # # import pdb; pdb.set_trace()
+ # v_pos_mask = (src_masks_pos[0].sigmoid() > 0.5).float().cpu().detach().numpy() * 255
+ # v_neg_mask = (_src_masks_neg[0].sigmoid() > 0.5).float().cpu().detach().numpy() * 255
+ # v_sum = ((src_masks_pos[0]-_src_masks_neg[0].clip(0)).sigmoid() > 0.5).float().cpu().detach().numpy() * 255
+ # v_gt = target_masks[0].float().cpu().detach().numpy() * 255
+
+ # cv2.imwrite('v_pos_mask.png', v_pos_mask)
+ # cv2.imwrite('v_neg_mask.png', v_neg_mask)
+ # cv2.imwrite('v_sum.png', v_sum)
+ # cv2.imwrite('v_gt.png', v_gt)
+ # import pdb; pdb.set_trace()
+
+ # src_masks = (src_masks_pos + src_masks_neg)[:, None]
+ src_masks = src_masks_pos[:, None]
+ target_masks = target_masks[:, None]
+
+ # debug visualization
+ # with torch.no_grad():
+ # import cv2
+ # import numpy as np
+
+ # v_src_masks = (F.interpolate(src_masks, size=target_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5).float().cpu().numpy()[:,0] * 255
+ # v_target_masks = target_masks.float().cpu().numpy()[:,0] * 255
+ # v_masks = np.concatenate([v_src_masks, v_target_masks], axis=2)
+
+ # for i in range(len(src_masks)):
+ # v1 = v_src_masks[i]
+ # v2 = v_target_masks[i]
+ # v = np.concatenate([v1,v2], axis=1)
+ # cv2.imwrite('v{}.png'.format(i), v)
+ # import pdb; pdb.set_trace()
+
+ # visualization
+ # VL.step()
+ # v_img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
+ # VL.add_image(v_img[:,:,::-1])
+ # candidate_masks = batched_inputs[0]['spatial_query']['rand_shape'].float().cpu().numpy()
+ # gt_masks = batched_inputs[0]['spatial_query']['gt_masks'].float().cpu().numpy()
+ # texts = ['cmask' for i in range(len(candidate_masks))]
+ # VL.overlay_obj_mask_to_image(v_img[:,:,::-1], candidate_masks, texts)
+ # texts = ['gmask' for i in range(len(candidate_masks))]
+ # VL.overlay_obj_mask_to_image(v_img[:,:,::-1], gt_masks, texts)
+
+ # import cv2
+ # for i in range(len(src_masks)):
+ # visual_src_mask_cur = (src_masks_cur[i].sigmoid()>0.5).detach().float().cpu().numpy() * 255
+ # visual_src_mask_mem = (src_masks_mem[i].sigmoid()>0.5).detach().float().cpu().numpy() * 255
+ # visual_src_mask = (src_masks[i,0].sigmoid()>0.5).detach().float().cpu().numpy() * 255
+ # visual_target_mask = (target_masks[i,0].sigmoid()>0.5).detach().float().cpu().numpy() * 255
+
+ # cv2.imwrite('visual_src_mask_cur_{}_{}.png'.format(i, mul[i].item()), visual_src_mask_cur)
+ # cv2.imwrite('visual_src_mask_mem_{}_{}.png'.format(i, mul[i].item()), visual_src_mask_mem)
+ # cv2.imwrite('visual_src_mask_{}_{}.png'.format(i, mul[i].item()), visual_src_mask)
+ # cv2.imwrite('visual_target_mask_{}_{}.png'.format(i, mul[i].item()), visual_target_mask)
+ # import pdb; pdb.set_trace()
+
+ with torch.no_grad():
+ # sample point_coords
+ point_coords = get_uncertain_point_coords_with_randomness(
+ src_masks,
+ lambda logits: calculate_uncertainty(logits),
+ self.num_points,
+ self.oversample_ratio,
+ self.importance_sample_ratio,
+ ).type(src_masks.dtype)
+ # get gt labels
+ point_labels = point_sample(
+ target_masks,
+ point_coords,
+ align_corners=False,
+ ).squeeze(1)
+
+ point_logits = point_sample(
+ src_masks,
+ point_coords,
+ align_corners=False,
+ ).squeeze(1)
+
+ num_masks = len(src_masks)
+ losses = {
+ "loss_spatial_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
+ "loss_spatial_dice_0": dice_loss_jit(point_logits, point_labels, num_masks),
+ }
+
+ # losses.update({"loss_spatial_ce_0": loss_spa_ce_pos + loss_spa_ce_neg})
+ losses.update({"loss_spatial_ce_0": loss_spa_ce_pos})
+
+ del src_masks
+ del target_masks
+ return losses
+
+ def loss_boxes(self, outputs, targets, indices, num_boxes, layer_id, extra):
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+ The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+ """
+ if layer_id >= self.top_x_layers['box']:
+ return {"loss_bbox_0": 0, "loss_giou_0": 0}
+
+ assert 'pred_boxes' in outputs
+
+ if indices is None or len(targets) == 0:
+ loss = outputs['pred_boxes'].sum() * 0.0
+ losses = {"loss_bbox_0": loss, "loss_giou_0": loss}
+ return losses
+
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+ src_boxes = outputs["pred_boxes"]
+ src_boxes = src_boxes[src_idx].sigmoid()
+
+ target_boxes = [t['boxes'] for t in targets]
+ max_size = _max_by_axis([list(box.shape) for box in target_boxes])
+ max_size = [len(target_boxes)] + max_size
+ empty_boxes = torch.zeros(max_size).to(src_boxes.device)
+ for idx, tar_box in enumerate(target_boxes):
+ empty_boxes[idx,:tar_box.shape[0],:] = tar_box
+ target_boxes = empty_boxes[tgt_idx]
+
+ # target_isthings = [t['is_things'] for t in targets]
+ # max_size = _max_by_axis([list(lab.shape) for lab in target_isthings])
+ # max_size = [len(target_isthings)] + max_size
+ # empty_lab = torch.zeros(max_size).to(src_boxes.device)
+
+ # for idx, tar_thing in enumerate(target_isthings):
+ # empty_lab[idx,:tar_thing.shape[0]] = tar_thing
+ # target_isthings = empty_lab[tgt_idx]
+
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+ losses = {}
+ losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes
+
+ loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
+ box_ops.box_cxcywh_to_xyxy(src_boxes),
+ box_ops.box_cxcywh_to_xyxy(target_boxes)))
+ losses['loss_giou_0'] = loss_giou.sum() / num_boxes
+ return losses
+
+ def _get_src_permutation_idx(self, indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ def _get_tgt_permutation_idx(self, indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+ def get_loss(self, loss, outputs, targets, indices, num_masks, layer_id, extra):
+ loss_map = {
+ 'labels': self.loss_labels,
+ 'masks': self.loss_masks,
+ 'boxes': self.loss_boxes,
+ 'captions': self.loss_captions,
+ 'retrievals': self.loss_itc,
+ 'captionings': self.loss_captionings,
+ 'groundings': self.loss_groundings,
+ 'labels_openimage': self.loss_labels_openimage,
+ 'spatials': self.loss_spatials,
+ }
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
+ return loss_map[loss](outputs, targets, indices, num_masks, layer_id, extra)
+
+ def forward(self, outputs, targets, extra=None):
+ """This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
+
+ # Retrieve the matching between the outputs of the last layer and the targets
+ indices = self.matcher(outputs_without_aux, targets)
+
+ # Compute the average number of target boxes accross all nodes, for normalization purposes
+ num_masks = sum(len(t["labels"]) for t in targets)
+ num_masks = torch.as_tensor(
+ [num_masks], dtype=torch.float, device=next(iter(outputs_without_aux.values())).device
+ )
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_masks)
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
+
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if "aux_outputs" in outputs:
+ # NOTE: we reverse the aux_outputs so that the first is the second last layer
+ for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
+ indices = self.matcher(aux_outputs, targets)
+ for loss in self.losses:
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
+ l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ return losses
+
+ def forward_vlp(self, outputs, targets, extra=None):
+ """This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ # Compute all the requested losses
+ losses = {}
+ num_masks = indices = None
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if "aux_outputs" in outputs:
+ # NOTE: we reverse the aux_outputs so that the first is the second last layer
+ for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
+ for loss in self.losses:
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
+ l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ return losses
+
+ def forward_grounding(self, outputs, targets, extra=None):
+ """This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ # Compute all the requested losses
+ losses = {}
+ indices = [[] for i in range(len(targets))]
+
+ # Compute the average number of target boxes accross all nodes, for normalization purposes
+ num_masks = sum(len(t["grounding_masks"]) for t in targets) + 1e-7
+ num_masks = torch.as_tensor(
+ [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
+ )
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_masks)
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
+
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if "aux_outputs" in outputs:
+ # NOTE: we reverse the aux_outputs so that the first is the second last layer
+ for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
+ for loss in self.losses:
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
+ l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ return losses
+
+ def forward_openimage(self, outputs, targets, extra=None):
+ """This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+ """
+ neg_class_emb = all_gather_grad(torch.cat([x['neg_class_emb'] for x in targets]))
+ neg_hash = all_gather_grad(torch.cat([x['neg_hash'] for x in targets]))
+
+ extra['neg_class_emb'] = neg_class_emb
+ extra['neg_hash'] = neg_hash
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
+
+ # Retrieve the matching between the outputs of the last layer and the targets
+ indices, pred_logits = self.matcher.openimage_forward(outputs_without_aux, targets, extra=extra)
+ outputs['pred_logits'] = pred_logits
+
+ # Compute the average number of target boxes accross all nodes, for normalization purposes
+ num_masks = sum(len(t["labels"]) for t in targets)
+ num_masks = torch.as_tensor(
+ [num_masks], dtype=torch.float, device=neg_class_emb.device
+ )
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_masks)
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
+
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if "aux_outputs" in outputs:
+ # NOTE: we reverse the aux_outputs so that the first is the second last layer
+ for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
+ indices, pred_logits = self.matcher.openimage_forward(aux_outputs, targets, extra=extra)
+ aux_outputs['pred_logits'] = pred_logits
+ for loss in self.losses:
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
+ l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ return losses
+
+ def __repr__(self):
+ head = "Criterion " + self.__class__.__name__
+ body = [
+ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
+ "losses: {}".format(self.losses),
+ "weight_dict: {}".format(self.weight_dict),
+ "num_classes: {}".format(self.num_classes),
+ "eos_coef: {}".format(self.eos_coef),
+ "num_points: {}".format(self.num_points),
+ "oversample_ratio: {}".format(self.oversample_ratio),
+ "importance_sample_ratio: {}".format(self.importance_sample_ratio),
+ ]
+ _repr_indent = 4
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
diff --git a/modeling/modules/matcher.py b/modeling/modules/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c8abb59f245b716744b2083da4242542265810b
--- /dev/null
+++ b/modeling/modules/matcher.py
@@ -0,0 +1,632 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
+"""
+Modules to compute the matching cost and solve the corresponding LSAP.
+"""
+import warnings
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+from torch.cuda.amp import autocast
+
+from .point_features import point_sample
+from ..language.loss import vl_similarity
+
+def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
+ """
+ Compute the DICE loss, similar to generalized IOU for masks
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ """
+ inputs = inputs.sigmoid()
+ inputs = inputs.flatten(1)
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss
+
+
+batch_dice_loss_jit = torch.jit.script(
+ batch_dice_loss
+) # type: torch.jit.ScriptModule
+
+
+def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
+ """
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ Returns:
+ Loss tensor
+ """
+ hw = inputs.shape[1]
+
+ pos = F.binary_cross_entropy_with_logits(
+ inputs, torch.ones_like(inputs), reduction="none"
+ )
+ neg = F.binary_cross_entropy_with_logits(
+ inputs, torch.zeros_like(inputs), reduction="none"
+ )
+
+ loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
+ "nc,mc->nm", neg, (1 - targets)
+ )
+
+ return loss / hw
+
+
+batch_sigmoid_ce_loss_jit = torch.jit.script(
+ batch_sigmoid_ce_loss
+) # type: torch.jit.ScriptModule
+
+
+class HungarianMatcher(nn.Module):
+ """This class computes an assignment between the targets and the predictions of the network
+
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+ while the others are un-matched (and thus treated as non-objects).
+ """
+
+ def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0, spatial_cost = None):
+ """Creates the matcher
+
+ Params:
+ cost_class: This is the relative weight of the classification error in the matching cost
+ cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
+ cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
+ """
+ super().__init__()
+ self.cost_class = cost_class
+ self.cost_mask = cost_mask
+ self.cost_dice = cost_dice
+
+ self.num_points = num_points
+ self.spatial_cost_class = cost_class
+ self.spatial_cost_mask = cost_mask
+ self.spatial_cost_dice = cost_dice
+ assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0"
+
+ @torch.no_grad()
+ def memory_efficient_forward(self, outputs, targets):
+ """More memory-friendly matching"""
+ bs, num_queries = outputs["pred_logits"].shape[:2]
+
+ if bs == 0 or len(targets) == 0:
+ return None
+
+ indices = []
+
+ # Iterate through batch size
+ for b in range(bs):
+ out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes]
+ tgt_ids = targets[b]["labels"]
+
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
+ cost_class = -out_prob[:, tgt_ids]
+
+ out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
+ # gt masks are already padded when preparing target
+ tgt_mask = targets[b]["masks"].to(out_mask)
+
+ out_mask = out_mask[:, None]
+ tgt_mask = tgt_mask[:, None]
+ # all masks share the same set of points for efficient matching!
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+ # get gt labels
+ tgt_mask = point_sample(
+ tgt_mask,
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ out_mask = point_sample(
+ out_mask,
+ point_coords.repeat(out_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ with autocast(enabled=False):
+ out_mask = out_mask.float()
+ tgt_mask = tgt_mask.float()
+ # Compute the focal loss between masks
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
+
+ # Compute the dice loss betwen masks
+ cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
+
+ # Final cost matrix
+ C = (
+ self.cost_mask * cost_mask
+ + self.cost_class * cost_class
+ + self.cost_dice * cost_dice
+ )
+ C = C.reshape(num_queries, -1).cpu()
+ if C.isnan().any():
+ C[C.isnan()] = 1e6 ### temporary fix
+ warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ raise
+ indices.append(linear_sum_assignment(C))
+
+ return [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+ for i, j in indices
+ ]
+
+ @torch.no_grad()
+ def openimage_forward(self, outputs, targets, extra):
+ """More memory-friendly matching"""
+ bs, num_queries = outputs["pred_captions"].shape[:2]
+ if bs == 0 or len(targets) == 0:
+ return None
+
+ neg_class_emb = extra['neg_class_emb']
+ neg_hash = extra['neg_hash']
+ _, unique_indices = np.unique(neg_hash.cpu().numpy(), return_index=True)
+ neg_class_emb = neg_class_emb[unique_indices]
+ neg_hash = neg_hash[unique_indices]
+
+ indices = []
+ pred_logits = []
+ # Iterate through batch size
+ for b in range(bs):
+ _pos_class_emb = targets[b]['pos_class_emb']
+ _pos_hash = targets[b]['pos_hash']
+ _neg_overlap_pos = ~(neg_hash[..., None] == _pos_hash).any(-1)
+ _neg_class_emb = neg_class_emb[_neg_overlap_pos]
+ t_emb = torch.cat((_pos_class_emb, _neg_class_emb))
+ v_emb = outputs["pred_captions"][b]
+ del _pos_class_emb
+ del _neg_class_emb
+
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ out_prob = vl_similarity(v_emb, t_emb, temperature=extra['lang_logit'])
+ pred_logits += [out_prob]
+ out_prob = out_prob.softmax(-1)
+ tgt_ids = targets[b]["labels"]
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
+ cost_class = -out_prob[:, tgt_ids]
+
+ out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
+ # gt masks are already padded when preparing target
+ tgt_mask = targets[b]["masks"].to(out_mask)
+
+ out_mask = out_mask[:, None]
+ tgt_mask = tgt_mask[:, None]
+ # all masks share the same set of points for efficient matching!
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+ # get gt labels
+ tgt_mask = point_sample(
+ tgt_mask,
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ out_mask = point_sample(
+ out_mask,
+ point_coords.repeat(out_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ with autocast(enabled=False):
+ out_mask = out_mask.float()
+ tgt_mask = tgt_mask.float()
+ # Compute the focal loss between masks
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
+
+ # Compute the dice loss betwen masks
+ cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
+
+ # Final cost matrix
+ C = (
+ self.cost_mask * cost_mask
+ + self.cost_class * cost_class
+ + self.cost_dice * cost_dice
+ )
+ C = C.reshape(num_queries, -1).cpu()
+ if C.isnan().any():
+ C[C.isnan()] = 1e6 ### temporary fix
+ warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ raise
+ indices.append(linear_sum_assignment(C))
+
+ return [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+ for i, j in indices
+ ], pred_logits
+
+ @torch.no_grad()
+ def grounding_forward(self, outputs, targets, extra):
+ """More memory-friendly matching"""
+ bs, num_queries = outputs["pred_gmasks"].shape[:2]
+
+ if bs == 0 or len(targets) == 0:
+ return None
+
+ indices = []
+ # Iterate through batch size
+ for b in range(bs):
+ out_prob = outputs["pred_logits"][b]
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
+ cost_class = -out_prob.softmax(dim=0)
+
+ out_mask = outputs["pred_gmasks"][b] # [num_queries, H_pred, W_pred]
+ # gt masks are already padded when preparing target
+ tgt_mask = targets[b]["grounding_masks"].to(out_mask)
+
+ out_mask = out_mask[:, None]
+ tgt_mask = tgt_mask[:, None]
+
+ # all masks share the same set of points for efficient matching!
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+ # get gt labels
+ tgt_mask = point_sample(
+ tgt_mask,
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ out_mask = point_sample(
+ out_mask,
+ point_coords.repeat(out_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ with autocast(enabled=False):
+ out_mask = out_mask.float()
+ tgt_mask = tgt_mask.float()
+ # Compute the focal loss between masks
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
+
+ # Compute the dice loss betwen masks
+ cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
+
+ # Final cost matrix
+ C = (
+ self.cost_mask * cost_mask
+ + self.cost_class * cost_class
+ + self.cost_dice * cost_dice
+ )
+ C = C.reshape(num_queries, -1).cpu()
+ if C.isnan().any():
+ C[C.isnan()] = 1e6 ### temporary fix
+ warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ raise
+ indices.append(linear_sum_assignment(C))
+
+ return [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+ for i, j in indices
+ ]
+
+ @torch.no_grad()
+ def spatial_forward(self, outputs, targets, extra):
+ """More memory-friendly matching"""
+ bs, num_queries = outputs["pred_smasks"].shape[:2]
+
+ if bs == 0 or len(targets) == 0:
+ return None
+
+ indices = []
+ # Iterate through batch size
+ for b in range(bs):
+ out_mask = outputs["pred_smasks"][b] # [num_queries, H_pred, W_pred]
+ # gt masks are already padded when preparing target
+ tgt_mask = targets[b]["gt_spatial_masks"].to(out_mask)
+ nd,ns = outputs["pred_pos_logits"][b].shape
+ index_masking = 1-torch.eye(ns, device=out_mask.device, dtype=tgt_mask.dtype).repeat_interleave(nd//ns,dim=0)
+ neg_masking = torch.zeros((nd,ns), device=out_mask.device, dtype=tgt_mask.dtype)
+ neg_masking.masked_fill_(index_masking.bool(), -float('inf'))
+ pos_masking = torch.zeros((nd,ns), device=out_mask.device, dtype=tgt_mask.dtype)
+ pos_masking.masked_fill_(index_masking.bool(), float('inf'))
+ out_prob = (outputs["pred_pos_logits"][b]+neg_masking)[:,:len(tgt_mask)] # remove redundant predictions for padding
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
+ cost_class = -out_prob.softmax(dim=0)
+
+ out_mask = out_mask[:, None]
+ tgt_mask = tgt_mask[:, None]
+
+ # all masks share the same set of points for efficient matching!
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+ # get gt labels
+ tgt_mask = point_sample(
+ tgt_mask,
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ out_mask = point_sample(
+ out_mask,
+ point_coords.repeat(out_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ with autocast(enabled=False):
+ out_mask = out_mask.float()
+ tgt_mask = tgt_mask.float()
+ # Compute the focal loss between masks
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) + pos_masking[:,:len(tgt_mask)]
+ # Compute the dice loss betwen masks
+ cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) + pos_masking[:,:len(tgt_mask)]
+
+ # Final cost matrix
+ C = (
+ self.spatial_cost_mask * cost_mask
+ + self.spatial_cost_class * cost_class
+ + self.spatial_cost_dice * cost_dice
+ )
+ C = C.reshape(num_queries, -1).cpu()
+ if C.isnan().any():
+ C[C.isnan()] = 1e6 ### temporary fix
+ warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ raise
+ indices.append(linear_sum_assignment(C))
+
+ return [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+ for i, j in indices
+ ]
+
+ @torch.no_grad()
+ def spatial_forward_pn(self, outputs, targets, extra):
+ """More memory-friendly matching"""
+ bs, num_queries = outputs["pred_smasks"].shape[:2]
+
+ if bs == 0 or len(targets) == 0:
+ return None
+
+ fp_mask = extra['false_positive_mask']
+ gt_mask = torch.stack([targets[b]["gt_spatial_masks"] for b in range(bs)])
+
+ indices = []
+ # Iterate through batch size
+ for b in range(bs):
+ out_prob = outputs["pred_neg_logits"][b]
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
+ cost_class = -out_prob.softmax(dim=0)
+
+ out_mask = outputs["pred_smasks"][b] # [num_queries, H_pred, W_pred]
+ tgt_mask = fp_mask[b].to(out_mask)
+ ign_mask = (gt_mask[b] | fp_mask[b]).to(out_mask)
+
+ out_mask = out_mask[:, None]
+ tgt_mask = tgt_mask[:, None]
+ ign_mask = ign_mask[:, None]
+
+ # all masks share the same set of points for efficient matching!
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+
+ # get gt labels
+ tgt_mask = point_sample(
+ tgt_mask,
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ out_mask = point_sample(
+ out_mask,
+ point_coords.repeat(out_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ ign_mask = point_sample(
+ ign_mask,
+ point_coords.repeat(ign_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ with autocast(enabled=False):
+ out_mask = out_mask.float()
+ tgt_mask = tgt_mask.float()
+ ign_mask = ign_mask.float()
+
+ # Compute the focal loss between masks
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask*ign_mask, tgt_mask*ign_mask)
+
+ # Compute the dice loss betwen masks
+ cost_dice = batch_dice_loss_jit(out_mask*ign_mask, tgt_mask*ign_mask)
+
+ # Final cost matrix
+ C = (
+ self.spatial_cost_mask * cost_mask
+ + self.spatial_cost_class * cost_class
+ + self.spatial_cost_dice * cost_dice
+ )
+ C = C.reshape(num_queries, -1).cpu()
+ if C.isnan().any():
+ C[C.isnan()] = 1e6 ### temporary fix
+ warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ raise
+ indices.append(linear_sum_assignment(C))
+
+ return [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+ for i, j in indices
+ ]
+
+ @torch.no_grad()
+ def caption_forward_womask(self, outputs, targets, extra):
+ """More memory-friendly matching"""
+ bs, _ = outputs["pred_logits"].shape[:2]
+
+ if bs == 0 or len(targets) == 0:
+ return None
+
+ indices = []
+ t_emb = torch.cat([t['captions'] for t in targets])
+ v_emb = outputs['unmatched_pred_captions']
+ caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])
+
+ # Iterate through batch size
+ for b in range(bs):
+ v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)
+ num_queries = len(v_emb[b])
+ out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]
+ tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]
+
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
+ cost_class = -out_prob[:, tgt_ids]
+
+ # Final cost matrix
+ C = (self.cost_class * cost_class)
+ C = C.reshape(num_queries, -1).cpu()
+ if C.isnan().any():
+ C[C.isnan()] = 1e6 ### temporary fix
+ warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ raise
+ indices.append(linear_sum_assignment(C))
+
+ return [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+ for i, j in indices
+ ]
+
+ @torch.no_grad()
+ def caption_forward_wmask(self, outputs, targets, extra):
+ """More memory-friendly matching"""
+ bs, _ = outputs["pred_logits"].shape[:2]
+
+ if bs == 0 or len(targets) == 0:
+ return None
+
+ indices = []
+ t_emb = torch.cat([t['captions'] for t in targets])
+ v_emb = outputs['unmatched_pred_captions']
+ caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])
+
+ # Iterate through batch size
+ for b in range(bs):
+ v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)
+ num_queries = len(v_emb[b])
+
+ out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]
+ tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]
+
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
+ cost_class = -out_prob[:, tgt_ids]
+
+ out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
+ # gt masks are already padded when preparing target
+ tgt_mask = targets[b]["masks"].to(out_mask)
+
+ out_mask = out_mask[:, None]
+ tgt_mask = tgt_mask[:, None]
+ # all masks share the same set of points for efficient matching!
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+ # get gt labels
+ tgt_mask = point_sample(
+ tgt_mask,
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ out_mask = point_sample(
+ out_mask,
+ point_coords.repeat(out_mask.shape[0], 1, 1),
+ align_corners=False,
+ ).squeeze(1)
+
+ with autocast(enabled=False):
+ out_mask = out_mask.float()
+ tgt_mask = tgt_mask.float()
+ # Compute the focal loss between masks
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
+
+ # Compute the dice loss betwen masks
+ cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
+
+ # Final cost matrix
+ C = (
+ self.cost_mask * cost_mask
+ + self.cost_class * cost_class
+ + self.cost_dice * cost_dice
+ )
+ C = C.reshape(num_queries, -1).cpu()
+ if C.isnan().any():
+ C[C.isnan()] = 1e6 ### temporary fix
+ warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ raise
+ indices.append(linear_sum_assignment(C))
+
+ return [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+ for i, j in indices
+ ]
+
+ @torch.no_grad()
+ def forward(self, outputs, targets, mode='default', extra={}):
+ """Performs the matching
+
+ Params:
+ outputs: This is a dict that contains at least these entries:
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+ "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
+
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+ objects in the target) containing the class labels
+ "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
+
+ Returns:
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ """
+ if mode == 'default':
+ return self.memory_efficient_forward(outputs, targets)
+ elif mode == 'grounding':
+ return self.grounding_forward(outputs, targets, extra)
+ elif mode == 'spatial':
+ return self.spatial_forward(outputs, targets, extra)
+ elif mode == 'spatial_pn':
+ return self.spatial_forward_pn(outputs, targets, extra)
+ elif mode == 'caption_womask':
+ return self.caption_forward_womask(outputs, targets, extra)
+ elif mode == 'caption_wmask':
+ return self.caption_forward_wmask(outputs, targets, extra)
+ else:
+ assert False, "Mode {} is not supported.".format(mode)
+
+ def __repr__(self, _repr_indent=4):
+ head = "Matcher " + self.__class__.__name__
+ body = [
+ "cost_class: {}".format(self.cost_class),
+ "cost_mask: {}".format(self.cost_mask),
+ "cost_dice: {}".format(self.cost_dice),
+ ]
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
diff --git a/modeling/modules/point_features.py b/modeling/modules/point_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..c770811ff7ef1acc6d628b560079e27fda3a347b
--- /dev/null
+++ b/modeling/modules/point_features.py
@@ -0,0 +1,261 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import torch
+from torch.nn import functional as F
+
+from detectron2.layers import cat, shapes_to_tensor
+from detectron2.structures import BitMasks, Boxes
+
+# from ..layers import cat, shapes_to_tensor
+# from ..structures import BitMasks, Boxes
+
+"""
+Shape shorthand in this module:
+
+ N: minibatch dimension size, i.e. the number of RoIs for instance segmenation or the
+ number of images for semantic segmenation.
+ R: number of ROIs, combined over all images, in the minibatch
+ P: number of points
+"""
+
+
+def point_sample(input, point_coords, **kwargs):
+ """
+ A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
+ Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
+ [0, 1] x [0, 1] square.
+
+ Args:
+ input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
+ point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
+ [0, 1] x [0, 1] normalized point coordinates.
+
+ Returns:
+ output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
+ features for points in `point_coords`. The features are obtained via bilinear
+ interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
+ """
+ add_dim = False
+ if point_coords.dim() == 3:
+ add_dim = True
+ point_coords = point_coords.unsqueeze(2)
+ output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
+ if add_dim:
+ output = output.squeeze(3)
+ return output
+
+
+def generate_regular_grid_point_coords(R, side_size, device):
+ """
+ Generate regular square grid of points in [0, 1] x [0, 1] coordinate space.
+
+ Args:
+ R (int): The number of grids to sample, one for each region.
+ side_size (int): The side size of the regular grid.
+ device (torch.device): Desired device of returned tensor.
+
+ Returns:
+ (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates
+ for the regular grids.
+ """
+ aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device)
+ r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False)
+ return r.view(1, -1, 2).expand(R, -1, -1)
+
+
+def get_uncertain_point_coords_with_randomness(
+ coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
+):
+ """
+ Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties
+ are calculated for each point using 'uncertainty_func' function that takes point's logit
+ prediction as input.
+ See PointRend paper for details.
+
+ Args:
+ coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
+ class-specific or class-agnostic prediction.
+ uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
+ contains logit predictions for P points and returns their uncertainties as a Tensor of
+ shape (N, 1, P).
+ num_points (int): The number of points P to sample.
+ oversample_ratio (int): Oversampling parameter.
+ importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
+ sampled points.
+ """
+ assert oversample_ratio >= 1
+ assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
+ num_boxes = coarse_logits.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device, dtype=coarse_logits.dtype)
+ point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
+ # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
+ # Calculating uncertainties of the coarse predictions first and sampling them for points leads
+ # to incorrect results.
+ # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
+ # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
+ # However, if we calculate uncertainties for the coarse predictions first,
+ # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
+ point_uncertainties = uncertainty_func(point_logits)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+ num_boxes, num_uncertain_points, 2
+ )
+ if num_random_points > 0:
+ point_coords = cat(
+ [
+ point_coords,
+ torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
+ ],
+ dim=1,
+ )
+ return point_coords
+
+
+def get_uncertain_point_coords_on_grid(uncertainty_map, num_points):
+ """
+ Find `num_points` most uncertain points from `uncertainty_map` grid.
+
+ Args:
+ uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
+ values for a set of points on a regular H x W grid.
+ num_points (int): The number of points P to select.
+
+ Returns:
+ point_indices (Tensor): A tensor of shape (N, P) that contains indices from
+ [0, H x W) of the most uncertain points.
+ point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized
+ coordinates of the most uncertain points from the H x W grid.
+ """
+ R, _, H, W = uncertainty_map.shape
+ h_step = 1.0 / float(H)
+ w_step = 1.0 / float(W)
+
+ num_points = min(H * W, num_points)
+ point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1]
+ point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device)
+ point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
+ point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
+ return point_indices, point_coords
+
+
+def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords):
+ """
+ Get features from feature maps in `features_list` that correspond to specific point coordinates
+ inside each bounding box from `boxes`.
+
+ Args:
+ features_list (list[Tensor]): A list of feature map tensors to get features from.
+ feature_scales (list[float]): A list of scales for tensors in `features_list`.
+ boxes (list[Boxes]): A list of I Boxes objects that contain R_1 + ... + R_I = R boxes all
+ together.
+ point_coords (Tensor): A tensor of shape (R, P, 2) that contains
+ [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
+
+ Returns:
+ point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled
+ from all features maps in feature_list for P sampled points for all R boxes in `boxes`.
+ point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level
+ coordinates of P points.
+ """
+ cat_boxes = Boxes.cat(boxes)
+ num_boxes = [b.tensor.size(0) for b in boxes]
+
+ point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
+ split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes)
+
+ point_features = []
+ for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image):
+ point_features_per_image = []
+ for idx_feature, feature_map in enumerate(features_list):
+ h, w = feature_map.shape[-2:]
+ scale = shapes_to_tensor([w, h]) / feature_scales[idx_feature]
+ point_coords_scaled = point_coords_wrt_image_per_image / scale.to(feature_map.device)
+ point_features_per_image.append(
+ point_sample(
+ feature_map[idx_img].unsqueeze(0),
+ point_coords_scaled.unsqueeze(0),
+ align_corners=False,
+ )
+ .squeeze(0)
+ .transpose(1, 0)
+ )
+ point_features.append(cat(point_features_per_image, dim=1))
+
+ return cat(point_features, dim=0), point_coords_wrt_image
+
+
+def get_point_coords_wrt_image(boxes_coords, point_coords):
+ """
+ Convert box-normalized [0, 1] x [0, 1] point cooordinates to image-level coordinates.
+
+ Args:
+ boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes.
+ coordinates.
+ point_coords (Tensor): A tensor of shape (R, P, 2) that contains
+ [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
+
+ Returns:
+ point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains
+ image-normalized coordinates of P sampled points.
+ """
+ with torch.no_grad():
+ point_coords_wrt_image = point_coords.clone()
+ point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * (
+ boxes_coords[:, None, 2] - boxes_coords[:, None, 0]
+ )
+ point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * (
+ boxes_coords[:, None, 3] - boxes_coords[:, None, 1]
+ )
+ point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0]
+ point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1]
+ return point_coords_wrt_image
+
+
+def sample_point_labels(instances, point_coords):
+ """
+ Sample point labels from ground truth mask given point_coords.
+
+ Args:
+ instances (list[Instances]): A list of N Instances, where N is the number of images
+ in the batch. So, i_th elememt of the list contains R_i objects and R_1 + ... + R_N is
+ equal to R. The ground-truth gt_masks in each instance will be used to compute labels.
+ points_coords (Tensor): A tensor of shape (R, P, 2), where R is the total number of
+ instances and P is the number of points for each instance. The coordinates are in
+ the absolute image pixel coordinate space, i.e. [0, H] x [0, W].
+
+ Returns:
+ Tensor: A tensor of shape (R, P) that contains the labels of P sampled points.
+ """
+ with torch.no_grad():
+ gt_mask_logits = []
+ point_coords_splits = torch.split(
+ point_coords, [len(instances_per_image) for instances_per_image in instances]
+ )
+ for i, instances_per_image in enumerate(instances):
+ if len(instances_per_image) == 0:
+ continue
+ assert isinstance(
+ instances_per_image.gt_masks, BitMasks
+ ), "Point head works with GT in 'bitmask' format. Set INPUT.MASK_FORMAT to 'bitmask'."
+
+ gt_bit_masks = instances_per_image.gt_masks.tensor
+ h, w = instances_per_image.gt_masks.image_size
+ scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device)
+ points_coord_grid_sample_format = point_coords_splits[i] / scale
+ gt_mask_logits.append(
+ point_sample(
+ gt_bit_masks.to(torch.float32).unsqueeze(1),
+ points_coord_grid_sample_format,
+ align_corners=False,
+ ).squeeze(1)
+ )
+
+ point_labels = cat(gt_mask_logits)
+ return point_labels
diff --git a/modeling/modules/position_encoding.py b/modeling/modules/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..f49e05d57be26d70d150422b57b24fefac88bf06
--- /dev/null
+++ b/modeling/modules/position_encoding.py
@@ -0,0 +1,64 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
+"""
+Various positional encodings for the transformer.
+"""
+import math
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=x.dtype)
+ x_embed = not_mask.cumsum(2, dtype=x.dtype)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def __repr__(self, _repr_indent=4):
+ head = "Positional encoding " + self.__class__.__name__
+ body = [
+ "num_pos_feats: {}".format(self.num_pos_feats),
+ "temperature: {}".format(self.temperature),
+ "normalize: {}".format(self.normalize),
+ "scale: {}".format(self.scale),
+ ]
+ # _repr_indent = 4
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
diff --git a/modeling/modules/postprocessing.py b/modeling/modules/postprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..93e00439921b34300eb52b3ede8622ebf6afb63e
--- /dev/null
+++ b/modeling/modules/postprocessing.py
@@ -0,0 +1,122 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import torch
+from torch.nn import functional as F
+
+from detectron2.structures import Instances, ROIMasks
+
+
+# perhaps should rename to "resize_instance"
+def detector_postprocess(
+ results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5
+):
+ """
+ Resize the output instances.
+ The input images are often resized when entering an object detector.
+ As a result, we often need the outputs of the detector in a different
+ resolution from its inputs.
+
+ This function will resize the raw outputs of an R-CNN detector
+ to produce outputs according to the desired output resolution.
+
+ Args:
+ results (Instances): the raw outputs from the detector.
+ `results.image_size` contains the input image resolution the detector sees.
+ This object might be modified in-place.
+ output_height, output_width: the desired output resolution.
+
+ Returns:
+ Instances: the resized output from the model, based on the output resolution
+ """
+ if isinstance(output_width, torch.Tensor):
+ # This shape might (but not necessarily) be tensors during tracing.
+ # Converts integer tensors to float temporaries to ensure true
+ # division is performed when computing scale_x and scale_y.
+ output_width_tmp = output_width.float()
+ output_height_tmp = output_height.float()
+ new_size = torch.stack([output_height, output_width])
+ else:
+ new_size = (output_height, output_width)
+ output_width_tmp = output_width
+ output_height_tmp = output_height
+
+ scale_x, scale_y = (
+ output_width_tmp / results.image_size[1],
+ output_height_tmp / results.image_size[0],
+ )
+ results = Instances(new_size, **results.get_fields())
+
+ if results.has("pred_boxes"):
+ output_boxes = results.pred_boxes
+ elif results.has("proposal_boxes"):
+ output_boxes = results.proposal_boxes
+ else:
+ output_boxes = None
+ assert output_boxes is not None, "Predictions must contain boxes!"
+
+ output_boxes.scale(scale_x, scale_y)
+ output_boxes.clip(results.image_size)
+
+ results = results[output_boxes.nonempty()]
+
+ if results.has("pred_masks"):
+ if isinstance(results.pred_masks, ROIMasks):
+ roi_masks = results.pred_masks
+ else:
+ # pred_masks is a tensor of shape (N, 1, M, M)
+ roi_masks = ROIMasks(results.pred_masks[:, 0, :, :])
+ results.pred_masks = roi_masks.to_bitmasks(
+ results.pred_boxes, output_height, output_width, mask_threshold
+ ).tensor # TODO return ROIMasks/BitMask object in the future
+
+ if results.has("pred_keypoints"):
+ results.pred_keypoints[:, :, 0] *= scale_x
+ results.pred_keypoints[:, :, 1] *= scale_y
+
+ return results
+
+def bbox_postprocess(result, input_size, img_size, output_height, output_width):
+ """
+ result: [xc,yc,w,h] range [0,1] to [x1,y1,x2,y2] range [0,w], [0,h]
+ """
+ if result is None:
+ return None
+
+ scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device)
+ result = result.sigmoid() * scale
+ x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2
+ h,w = img_size
+
+ x1 = x1.clamp(min=0, max=w)
+ y1 = y1.clamp(min=0, max=h)
+ x2 = x2.clamp(min=0, max=w)
+ y2 = y2.clamp(min=0, max=h)
+
+ box = torch.stack([x1,y1,x2,y2]).permute(1,0)
+ scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device)
+ box = box*scale
+ return box
+
+def sem_seg_postprocess(result, img_size, output_height, output_width):
+ """
+ Return semantic segmentation predictions in the original resolution.
+
+ The input images are often resized when entering semantic segmentor. Moreover, in same
+ cases, they also padded inside segmentor to be divisible by maximum network stride.
+ As a result, we often need the predictions of the segmentor in a different
+ resolution from its inputs.
+
+ Args:
+ result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W),
+ where C is the number of classes, and H, W are the height and width of the prediction.
+ img_size (tuple): image size that segmentor is taking as input.
+ output_height, output_width: the desired output resolution.
+
+ Returns:
+ semantic segmentation prediction (Tensor): A tensor of the shape
+ (C, output_height, output_width) that contains per-pixel soft predictions.
+ """
+ result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1)
+ result = F.interpolate(
+ result, size=(output_height, output_width), mode="bicubic", align_corners=False, antialias=True
+ )[0]
+ return result
diff --git a/modeling/utils/__init__.py b/modeling/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..807342aef0c01d0516b53fbd73cf5c877399ed73
--- /dev/null
+++ b/modeling/utils/__init__.py
@@ -0,0 +1,4 @@
+from .config import *
+from .misc import *
+from .interactive import *
+from .attention import *
\ No newline at end of file
diff --git a/modeling/utils/attention.py b/modeling/utils/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa3d7b04953b51b62b80317a3a2d56a11e06ad39
--- /dev/null
+++ b/modeling/utils/attention.py
@@ -0,0 +1,485 @@
+from typing import Callable, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch.nn import Parameter
+from torch.nn.modules.linear import Linear
+from torch.nn.init import xavier_uniform_, constant_
+from torch.overrides import (
+ has_torch_function, has_torch_function_unary, has_torch_function_variadic,
+ handle_torch_function)
+
+Tensor = torch.Tensor
+
+class _LinearWithBias(Linear):
+ bias: Tensor # type: ignore
+
+ def __init__(self, in_features: int, out_features: int) -> None:
+ super().__init__(in_features, out_features, bias=True) # type: ignore
+
+def multi_head_attention_forward(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Tensor,
+ in_proj_bias: Tensor,
+ bias_k: Optional[Tensor],
+ bias_v: Optional[Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Tensor,
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Tensor] = None,
+ k_proj_weight: Optional[Tensor] = None,
+ v_proj_weight: Optional[Tensor] = None,
+ static_k: Optional[Tensor] = None,
+ static_v: Optional[Tensor] = None,
+) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ embed_dim_to_check: total dimension of the model.
+ num_heads: parallel attention heads.
+ in_proj_weight, in_proj_bias: input projection weight and bias.
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ dropout_p: probability of an element to be zeroed.
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
+ training: apply dropout if is ``True``.
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. This is an binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
+ and value in different forms. If false, in_proj_weight will be used, which is
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+ static_k, static_v: static key and value used for attention operators.
+
+
+ Shape:
+ Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+ will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+
+ Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
+ if has_torch_function(tens_ops):
+ return handle_torch_function(
+ multi_head_attention_forward,
+ tens_ops,
+ query,
+ key,
+ value,
+ embed_dim_to_check,
+ num_heads,
+ in_proj_weight,
+ in_proj_bias,
+ bias_k,
+ bias_v,
+ add_zero_attn,
+ dropout_p,
+ out_proj_weight,
+ out_proj_bias,
+ training=training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ use_separate_proj_weight=use_separate_proj_weight,
+ q_proj_weight=q_proj_weight,
+ k_proj_weight=k_proj_weight,
+ v_proj_weight=v_proj_weight,
+ static_k=static_k,
+ static_v=static_v,
+ )
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == embed_dim_to_check
+ # allow MHA to have different sizes for the feature dimension
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+ head_dim = embed_dim // num_heads
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+ scaling = float(head_dim) ** -0.5
+
+ if not use_separate_proj_weight:
+ if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
+ # self-attention
+ q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
+
+ elif key is value or torch.equal(key, value):
+ # encoder-decoder attention
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = F.linear(query, _w, _b)
+
+ if key is None:
+ assert value is None
+ k = None
+ v = None
+ else:
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
+
+ else:
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = F.linear(query, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = embed_dim * 2
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ k = F.linear(key, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim * 2
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ v = F.linear(value, _w, _b)
+ else:
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
+ len1, len2 = q_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == query.size(-1)
+
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
+ len1, len2 = k_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == key.size(-1)
+
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
+ len1, len2 = v_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == value.size(-1)
+
+ if in_proj_bias is not None:
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
+ else:
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
+ q = q * scaling
+
+ if attn_mask is not None:
+ assert (
+ attn_mask.dtype == torch.float32
+ or attn_mask.dtype == torch.float64
+ or attn_mask.dtype == torch.float16
+ or attn_mask.dtype == torch.uint8
+ or attn_mask.dtype == torch.bool
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
+ if attn_mask.dtype == torch.uint8:
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+ attn_mask = attn_mask.to(torch.bool)
+
+ if attn_mask.dim() == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
+ elif attn_mask.dim() == 3:
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
+ else:
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
+ # attn_mask's dim is 3 now.
+
+ # convert ByteTensor key_padding_mask to bool
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+ warnings.warn(
+ "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
+ )
+ key_padding_mask = key_padding_mask.to(torch.bool)
+
+ if bias_k is not None and bias_v is not None:
+ if static_k is None and static_v is None:
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+ else:
+ assert static_k is None, "bias cannot be added to static key."
+ assert static_v is None, "bias cannot be added to static value."
+ else:
+ assert bias_k is None
+ assert bias_v is None
+
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+ if static_k is not None:
+ assert static_k.size(0) == bsz * num_heads
+ assert static_k.size(2) == head_dim
+ k = static_k
+
+ if static_v is not None:
+ assert static_v.size(0) == bsz * num_heads
+ assert static_v.size(2) == head_dim
+ v = static_v
+
+ src_len = k.size(1)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+ else:
+ attn_output_weights += attn_mask
+
+ if key_padding_mask is not None:
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float("-inf"),
+ )
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+
+ attn_output_weights = F.softmax(attn_output_weights, dim=-1).nan_to_num()
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
+
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
+
+ if need_weights:
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
+ else:
+ return attn_output, None
+
+
+class MultiheadAttention(torch.nn.Module):
+ r"""Allows the model to jointly attend to information
+ from different representation subspaces.
+ See `Attention Is All You Need `_
+
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
+
+ Args:
+ embed_dim: total dimension of the model.
+ num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+ bias: add bias as module parameter. Default: True.
+ add_bias_kv: add bias to the key and value sequences at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ kdim: total number of features in key. Default: None.
+ vdim: total number of features in value. Default: None.
+
+ Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
+ to :attr:`embed_dim` such that query, key, and value have the same
+ number of features.
+
+ Examples::
+
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+ """
+ bias_k: Optional[torch.Tensor]
+ bias_v: Optional[torch.Tensor]
+
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+ if self._qkv_same_embed_dim is False:
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+ self.register_parameter('in_proj_weight', None)
+ else:
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
+ self.register_parameter('q_proj_weight', None)
+ self.register_parameter('k_proj_weight', None)
+ self.register_parameter('v_proj_weight', None)
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+ self.out_proj = _LinearWithBias(embed_dim, embed_dim)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ if self._qkv_same_embed_dim:
+ xavier_uniform_(self.in_proj_weight)
+ else:
+ xavier_uniform_(self.q_proj_weight)
+ xavier_uniform_(self.k_proj_weight)
+ xavier_uniform_(self.v_proj_weight)
+
+ if self.in_proj_bias is not None:
+ constant_(self.in_proj_bias, 0.)
+ constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ xavier_normal_(self.bias_v)
+
+ def __setstate__(self, state):
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
+ if '_qkv_same_embed_dim' not in state:
+ state['_qkv_same_embed_dim'] = True
+
+ super(MultiheadAttention, self).__setstate__(state)
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. When given a binary mask and a value is True,
+ the corresponding value on the attention layer will be ignored. When given
+ a byte mask and a value is non-zero, the corresponding value on the attention
+ layer will be ignored
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+ Shapes for inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
+ source sequence length.
+
+ If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
+ length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend
+ the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+
+ Shapes for outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+ if not self._qkv_same_embed_dim:
+ return multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight)
+ else:
+ return multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask)
\ No newline at end of file
diff --git a/modeling/utils/box_ops.py b/modeling/utils/box_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bae398cec9a8f9748c6763ddd72deae8a72a207
--- /dev/null
+++ b/modeling/utils/box_ops.py
@@ -0,0 +1,93 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Utilities for bounding box manipulation and GIoU.
+"""
+import torch
+from torchvision.ops.boxes import box_area
+
+
+def box_cxcywh_to_xyxy(x):
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+ x0, y0, x1, y1 = x.unbind(-1)
+ b = [(x0 + x1) / 2, (y0 + y1) / 2,
+ (x1 - x0), (y1 - y0)]
+ return torch.stack(b, dim=-1)
+
+def box_xywh_to_xyxy(x):
+ x0, y0, x1, y1 = x.unbind(-1)
+ b = [x0, y0, (x0 + x1), (y0 + y1)]
+ return torch.stack(b, dim=-1)
+
+
+# modified from torchvision to also return the union
+def box_iou(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / (union+1e-6)
+ return iou, union
+
+
+def generalized_box_iou(boxes1, boxes2):
+ """
+ Generalized IoU from https://giou.stanford.edu/
+
+ The boxes should be in [x0, y0, x1, y1] format
+
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
+ and M = len(boxes2)
+ """
+ # degenerate boxes gives inf / nan results
+ # so do an early check
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+ iou, union = box_iou(boxes1, boxes2)
+
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ area = wh[:, :, 0] * wh[:, :, 1]
+
+ return iou - (area - union) / (area+1e-6)
+
+
+def masks_to_boxes(masks):
+ """Compute the bounding boxes around the provided masks
+
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
+
+ Returns a [N, 4] tensors, with the boxes in xyxy format
+ """
+ if masks.numel() == 0:
+ return torch.zeros((0, 4), device=masks.device)
+
+ h, w = masks.shape[-2:]
+
+ y = torch.arange(0, h, dtype=torch.float)
+ x = torch.arange(0, w, dtype=torch.float)
+ y, x = torch.meshgrid(y, x)
+
+ x_mask = (masks * x.unsqueeze(0))
+ x_max = x_mask.flatten(1).max(-1)[0]
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ y_mask = (masks * y.unsqueeze(0))
+ y_max = y_mask.flatten(1).max(-1)[0]
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
\ No newline at end of file
diff --git a/modeling/utils/config.py b/modeling/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..766bb386498f0f034485a19027d5b30b0b6d20ff
--- /dev/null
+++ b/modeling/utils/config.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import functools
+import inspect
+
+def configurable(init_func=None, *, from_config=None):
+ """
+ Decorate a function or a class's __init__ method so that it can be called
+ with a :class:`CfgNode` object using a :func:`from_config` function that translates
+ :class:`CfgNode` to arguments.
+
+ Examples:
+ ::
+ # Usage 1: Decorator on __init__:
+ class A:
+ @configurable
+ def __init__(self, a, b=2, c=3):
+ pass
+
+ @classmethod
+ def from_config(cls, cfg): # 'cfg' must be the first argument
+ # Returns kwargs to be passed to __init__
+ return {"a": cfg.A, "b": cfg.B}
+
+ a1 = A(a=1, b=2) # regular construction
+ a2 = A(cfg) # construct with a cfg
+ a3 = A(cfg, b=3, c=4) # construct with extra overwrite
+
+ # Usage 2: Decorator on any function. Needs an extra from_config argument:
+ @configurable(from_config=lambda cfg: {"a: cfg.A, "b": cfg.B})
+ def a_func(a, b=2, c=3):
+ pass
+
+ a1 = a_func(a=1, b=2) # regular call
+ a2 = a_func(cfg) # call with a cfg
+ a3 = a_func(cfg, b=3, c=4) # call with extra overwrite
+
+ Args:
+ init_func (callable): a class's ``__init__`` method in usage 1. The
+ class must have a ``from_config`` classmethod which takes `cfg` as
+ the first argument.
+ from_config (callable): the from_config function in usage 2. It must take `cfg`
+ as its first argument.
+ """
+
+ if init_func is not None:
+ assert (
+ inspect.isfunction(init_func)
+ and from_config is None
+ and init_func.__name__ == "__init__"
+ ), "Incorrect use of @configurable. Check API documentation for examples."
+
+ @functools.wraps(init_func)
+ def wrapped(self, *args, **kwargs):
+ try:
+ from_config_func = type(self).from_config
+ except AttributeError as e:
+ raise AttributeError(
+ "Class with @configurable must have a 'from_config' classmethod."
+ ) from e
+ if not inspect.ismethod(from_config_func):
+ raise TypeError("Class with @configurable must have a 'from_config' classmethod.")
+
+ if _called_with_cfg(*args, **kwargs):
+ explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
+ init_func(self, **explicit_args)
+ else:
+ init_func(self, *args, **kwargs)
+
+ return wrapped
+
+ else:
+ if from_config is None:
+ return configurable # @configurable() is made equivalent to @configurable
+ assert inspect.isfunction(
+ from_config
+ ), "from_config argument of configurable must be a function!"
+
+ def wrapper(orig_func):
+ @functools.wraps(orig_func)
+ def wrapped(*args, **kwargs):
+ if _called_with_cfg(*args, **kwargs):
+ explicit_args = _get_args_from_config(from_config, *args, **kwargs)
+ return orig_func(**explicit_args)
+ else:
+ return orig_func(*args, **kwargs)
+
+ wrapped.from_config = from_config
+ return wrapped
+
+ return wrapper
+
+def _called_with_cfg(*args, **kwargs):
+ """
+ Returns:
+ bool: whether the arguments contain CfgNode and should be considered
+ forwarded to from_config.
+ """
+ from omegaconf import DictConfig
+
+ if len(args) and isinstance(args[0], (dict)):
+ return True
+ if isinstance(kwargs.pop("cfg", None), (dict)):
+ return True
+ # `from_config`'s first argument is forced to be "cfg".
+ # So the above check covers all cases.
+ return False
+
+def _get_args_from_config(from_config_func, *args, **kwargs):
+ """
+ Use `from_config` to obtain explicit arguments.
+
+ Returns:
+ dict: arguments to be used for cls.__init__
+ """
+ signature = inspect.signature(from_config_func)
+ if list(signature.parameters.keys())[0] != "cfg":
+ if inspect.isfunction(from_config_func):
+ name = from_config_func.__name__
+ else:
+ name = f"{from_config_func.__self__}.from_config"
+ raise TypeError(f"{name} must take 'cfg' as the first argument!")
+ support_var_arg = any(
+ param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
+ for param in signature.parameters.values()
+ )
+ if support_var_arg: # forward all arguments to from_config, if from_config accepts them
+ ret = from_config_func(*args, **kwargs)
+ else:
+ # forward supported arguments to from_config
+ supported_arg_names = set(signature.parameters.keys())
+ extra_kwargs = {}
+ for name in list(kwargs.keys()):
+ if name not in supported_arg_names:
+ extra_kwargs[name] = kwargs.pop(name)
+ ret = from_config_func(*args, **kwargs)
+ # forward the other arguments to __init__
+ ret.update(extra_kwargs)
+ return ret
\ No newline at end of file
diff --git a/modeling/utils/interactive.py b/modeling/utils/interactive.py
new file mode 100644
index 0000000000000000000000000000000000000000..164ddeaaf0f4e09187569b3f12fbbb68ea95dd12
--- /dev/null
+++ b/modeling/utils/interactive.py
@@ -0,0 +1,49 @@
+import os
+import copy
+import math
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+
+def rand_sample(x, divisor, max_len):
+ # non_zero_pos_point = [rand_sample((m.nonzero()/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
+ if len(x.nonzero()) == 0:
+ return x.nonzero().t()
+
+ non_zero_point_index = (x.nonzero()/divisor).t()
+ mask_ids = non_zero_point_index[0].unique().long()
+
+ # compute probability for each samle
+ probs = torch.zeros_like(non_zero_point_index[0])
+ for idx in mask_ids:
+ prob = 1./(len(mask_ids)*((non_zero_point_index[0:1]==idx).sum()))
+ probs[non_zero_point_index[0]==idx] = prob
+
+ indices = torch.multinomial(probs, num_samples=min(max_len, len(probs)), replacement=False).sort()[0]
+ non_zero_point_index = non_zero_point_index[:,indices]
+ return non_zero_point_index # [n, 512]
+
+def rand_sample_plain(x, max_len):
+ if x.shape[1] <= max_len:
+ return x
+ else:
+ rand_idx = torch.randperm(x.shape[1])[:max_len]
+ return x[:,rand_idx]
+
+def prepare_features(x, num_feature_levels, pe_layer, input_proj, level_embed):
+ src = []
+ pos = []
+ size_list = []
+
+ # disable mask, it does not affect performance
+ for i in range(num_feature_levels):
+ size_list.append(x[i].shape[-2:])
+ pos.append(pe_layer(x[i], None).flatten(2))
+ src.append(input_proj[i](x[i]).flatten(2) + level_embed.weight[i][None, :, None])
+
+ # flatten NxCxHxW to HWxNxC
+ pos[-1] = pos[-1].permute(2, 0, 1)
+ src[-1] = src[-1].permute(2, 0, 1)
+ return src, pos, size_list
\ No newline at end of file
diff --git a/modeling/utils/misc.py b/modeling/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f3fd93d9424ee23b9cb8851b322c0a923edcac6
--- /dev/null
+++ b/modeling/utils/misc.py
@@ -0,0 +1,328 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
+
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+from typing import List, Optional, Tuple, Any
+
+import torch
+import torchvision
+from torch import nn, Tensor, device
+import torch.distributed as dist
+import torch.nn.functional as F
+
+from detectron2.layers import cat, shapes_to_tensor
+
+from utilities.constants import *
+
+
+def pad_arbitrary_tensors(tensors, padding_value=0.):
+ max_len = torch.stack([torch.tensor(x.shape) for x in tensors]).max(dim=0)[0]
+ padded_tensor = torch.empty([len(tensors)] + max_len.tolist(), device=tensors[0].device).fill_(padding_value)
+ for i, x in enumerate(tensors):
+ padded_tensor[i, :x.shape[0], :x.shape[1]] = x
+ return padded_tensor
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], : img.shape[2]] = False
+ elif tensor_list[0].ndim == 2:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(txt.shape) for txt in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, l = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, l), dtype=torch.bool, device=device)
+ for txt, pad_txt, m in zip(tensor_list, tensor, mask):
+ pad_txt[: txt.shape[0], : txt.shape[1]] = txt
+ m[: txt.shape[1]] = False
+ else:
+ raise ValueError("not supported")
+ return NestedTensor(tensor, mask)
+
+def _collate_and_pad_divisibility(tensor_list: list, div=32):
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(
+ torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32)
+ ).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ c,h,w = max_size
+ pad_h = (div - h % div) if h % div != 0 else 0
+ pad_w = (div - w % div) if w % div != 0 else 0
+ max_size = (c,h+pad_h,w+pad_w)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ return padded_imgs
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
+ ).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+# TODO: add background to
+def get_class_names(name):
+ if name is None:
+ return None
+ elif 'refcoco' in name:
+ return ["background"]
+ elif 'biomed' in name:
+ return BIOMED_CLASSES + ["background"]
+ elif 'med_sam' in name:
+ ### MedSAM class names
+ medsam_classes = ['liver', 'lung', 'pancreas', 'stomach', 'heart', 'gallbladder', 'prostate', 'brain ventricles', 'cerebellum',
+ 'left heart ventricle', 'right heart ventricle', 'vessel', 'polyp', 'surgical tool', 'pleural effusion', 'infection', 'gland', 'tumor']
+ return medsam_classes + ["background"]
+ elif 'coco' in name:
+ return COCO_PANOPTIC_CLASSES + ["background"]
+ elif 'ade20k_full' in name:
+ return ADE20K_847 + ["background"]
+ elif 'ade' in name:
+ return ADE_PANOPTIC_CLASSES + ["background"]
+ elif 'scannet_41' in name:
+ return SCAN_40 + ["background"]
+ elif 'scannet_21' in name:
+ return SCAN_20 + ["background"]
+ elif 'sun' in name:
+ return SUN_RGBD_37 + ["background"]
+ elif 'voc' in name:
+ return PASCAL_CLASSES + ["background"]
+ elif name == 'cityscapes_fine_sem_seg_val':
+ return CITYSCAPES + ["background"]
+ elif name == 'cityscapes_fine_instance_seg_val':
+ return CITYSCAPES_THING + ["background"]
+ elif name in ['cityscapes_fine_panoptic_val']:
+ return CITYSCAPES + ["background"]
+ elif name == 'bdd10k_val_sem_seg':
+ return BDD_SEM + ["background"]
+ elif name == 'bdd10k_40_panoptic_val':
+ return BDD_PANO + ["background"]
+ elif 'vlp' in name:
+ return ["background"]
+ else:
+ assert False, "text dataset name {} is not defined".format(name)
+
+def get_iou(gt_masks, pred_masks, ignore_label=-1):
+ rev_ignore_mask = ~(gt_masks == ignore_label)
+ gt_masks = gt_masks.bool()
+ n,h,w = gt_masks.shape
+ intersection = ((gt_masks & pred_masks) & rev_ignore_mask).reshape(n,h*w).sum(dim=-1)
+ union = ((gt_masks | pred_masks) & rev_ignore_mask).reshape(n,h*w).sum(dim=-1)
+ ious = (intersection / union)
+ return ious
+
+class Spatial_ImageList(object):
+ """
+ Structure that holds a list of images (of possibly
+ varying sizes) as a single tensor.
+ This works by padding the images to the same size.
+ The original sizes of each image is stored in `image_sizes`.
+
+ Attributes:
+ image_sizes (list[tuple[int, int]]): each tuple is (h, w).
+ During tracing, it becomes list[Tensor] instead.
+ """
+
+ def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, int]]):
+ """
+ Arguments:
+ tensor (Tensor): of shape (N, H, W) or (N, C_1, ..., C_K, H, W) where K >= 1
+ image_sizes (list[tuple[int, int]]): Each tuple is (h, w). It can
+ be smaller than (H, W) due to padding.
+ """
+ self.tensor = tensor
+ self.image_sizes = image_sizes
+
+ def __len__(self) -> int:
+ return len(self.image_sizes)
+
+ def __getitem__(self, idx) -> torch.Tensor:
+ """
+ Access the individual image in its original size.
+
+ Args:
+ idx: int or slice
+
+ Returns:
+ Tensor: an image of shape (H, W) or (C_1, ..., C_K, H, W) where K >= 1
+ """
+ size = self.image_sizes[idx]
+ return self.tensor[idx, ..., : size[0], : size[1]]
+
+ @torch.jit.unused
+ def to(self, *args: Any, **kwargs: Any) -> "Spatial_ImageList":
+ cast_tensor = self.tensor.to(*args, **kwargs)
+ return Spatial_ImageList(cast_tensor, self.image_sizes)
+
+ @property
+ def device(self) -> device:
+ return self.tensor.device
+
+ @staticmethod
+ def from_tensors(
+ tensors: List[torch.Tensor], size_divisibility: int = 0, pad_value: float = 0.0
+ ) -> "Spatial_ImageList":
+ """
+ Args:
+ tensors: a tuple or list of `torch.Tensor`, each of shape (Hi, Wi) or
+ (C_1, ..., C_K, Hi, Wi) where K >= 1. The Tensors will be padded
+ to the same shape with `pad_value`.
+ size_divisibility (int): If `size_divisibility > 0`, add padding to ensure
+ the common height and width is divisible by `size_divisibility`.
+ This depends on the model and many models need a divisibility of 32.
+ pad_value (float): value to pad
+
+ Returns:
+ an `Spatial_ImageList`.
+ """
+ assert len(tensors) > 0
+ assert isinstance(tensors, (tuple, list))
+ for t in tensors:
+ assert isinstance(t, torch.Tensor), type(t)
+
+ image_sizes = [(im.shape[-3], im.shape[-2], im.shape[-1]) for im in tensors]
+
+ image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes]
+ max_size = torch.stack(image_sizes_tensor).max(0).values
+
+ if size_divisibility > 1:
+ stride = size_divisibility
+ # the last two dims are H,W, both subject to divisibility requirement
+ max_size[-2:] = (max_size[-2:] + (stride - 1)).div(stride, rounding_mode="floor") * stride
+
+ # handle weirdness of scripting and tracing ...
+ if torch.jit.is_scripting():
+ max_size: List[int] = max_size.to(dtype=torch.long).tolist()
+ else:
+ if torch.jit.is_tracing():
+ image_sizes = image_sizes_tensor
+
+ if len(tensors) == 1:
+ # This seems slightly (2%) faster.
+ # TODO: check whether it's faster for multiple images as well
+ image_size = image_sizes[0]
+ padding_size = [0, max_size[-1] - image_size[2], 0, max_size[-2] - image_size[1]]
+ batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)
+ else:
+ # max_size can be a tensor in tracing mode, therefore convert to list
+ batch_shape = [len(tensors)] + list(tensors[0].shape[:-3]) + list(max_size)
+ batched_imgs = tensors[0].new_full(batch_shape, pad_value)
+ for img, pad_img in zip(tensors, batched_imgs):
+ pad_img[:img.shape[-3],:img.shape[-2],:img.shape[-1]].copy_(img)
+
+ return Spatial_ImageList(batched_imgs.contiguous(), image_sizes)
\ No newline at end of file
diff --git a/modeling/vision/backbone/__init__.py b/modeling/vision/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..50f543f58aa1e6331fd07e545d7590048e1ea03a
--- /dev/null
+++ b/modeling/vision/backbone/__init__.py
@@ -0,0 +1,14 @@
+from .focal import *
+from .focal_dw import *
+from .davit import *
+from .vit import *
+from .backbone import *
+from .build import *
+
+
+def build_backbone(config, **kwargs):
+ model_name = config['MODEL']['BACKBONE']['NAME']
+ if not is_model(model_name):
+ raise ValueError(f'Unkown model: {model_name}')
+
+ return model_entrypoints(model_name)(config, **kwargs)
\ No newline at end of file
diff --git a/modeling/vision/backbone/backbone.py b/modeling/vision/backbone/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d97ac219a1a22ffd24aafb542526b111651d046
--- /dev/null
+++ b/modeling/vision/backbone/backbone.py
@@ -0,0 +1,53 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import torch.nn as nn
+
+from detectron2.modeling import ShapeSpec
+
+# from ..layers import ShapeSpec
+
+__all__ = ["Backbone"]
+
+
+class Backbone(nn.Module):
+ """
+ Abstract base class for network backbones.
+ """
+
+ def __init__(self):
+ """
+ The `__init__` method of any subclass can specify its own set of arguments.
+ """
+ super().__init__()
+
+ def forward(self):
+ """
+ Subclasses must override this method, but adhere to the same return type.
+
+ Returns:
+ dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
+ """
+ pass
+
+ @property
+ def size_divisibility(self) -> int:
+ """
+ Some backbones require the input height and width to be divisible by a
+ specific integer. This is typically true for encoder / decoder type networks
+ with lateral connection (e.g., FPN) for which feature maps need to match
+ dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
+ input size divisibility is required.
+ """
+ return 0
+
+ def output_shape(self):
+ """
+ Returns:
+ dict[str->ShapeSpec]
+ """
+ # this is a backward-compatible default
+ return {
+ name: ShapeSpec(
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+ )
+ for name in self._out_features
+ }
diff --git a/modeling/vision/backbone/build.py b/modeling/vision/backbone/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c5f809acf8008fcf11be46be90608bfd819f0d
--- /dev/null
+++ b/modeling/vision/backbone/build.py
@@ -0,0 +1,14 @@
+_model_entrypoints = {}
+
+
+def register_backbone(fn):
+ module_name_split = fn.__module__.split('.')
+ model_name = module_name_split[-1]
+ _model_entrypoints[model_name] = fn
+ return fn
+
+def model_entrypoints(model_name):
+ return _model_entrypoints[model_name]
+
+def is_model(model_name):
+ return model_name in _model_entrypoints
\ No newline at end of file
diff --git a/modeling/vision/backbone/common.py b/modeling/vision/backbone/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96
--- /dev/null
+++ b/modeling/vision/backbone/common.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from typing import Type
+
+
+class MLPBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ mlp_dim: int,
+ act: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ super().__init__()
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+ self.act = act()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
diff --git a/modeling/vision/backbone/davit.py b/modeling/vision/backbone/davit.py
new file mode 100644
index 0000000000000000000000000000000000000000..976448f423b893f903b89e2eda6b90b4c015cd8a
--- /dev/null
+++ b/modeling/vision/backbone/davit.py
@@ -0,0 +1,624 @@
+import os
+import itertools
+import logging
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from collections import OrderedDict
+
+from einops import rearrange
+from timm.models.layers import DropPath, trunc_normal_
+
+from detectron2.utils.file_io import PathManager
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+from .build import register_backbone
+
+logger = logging.getLogger(__name__)
+
+
+
+class MySequential(nn.Sequential):
+ def forward(self, *inputs):
+ for module in self._modules.values():
+ if type(inputs) == tuple:
+ inputs = module(*inputs)
+ else:
+ inputs = module(inputs)
+ return inputs
+
+
+class PreNorm(nn.Module):
+ def __init__(self, norm, fn, drop_path=None):
+ super().__init__()
+ self.norm = norm
+ self.fn = fn
+ self.drop_path = drop_path
+
+ def forward(self, x, *args, **kwargs):
+ shortcut = x
+ if self.norm != None:
+ x, size = self.fn(self.norm(x), *args, **kwargs)
+ else:
+ x, size = self.fn(x, *args, **kwargs)
+
+ if self.drop_path:
+ x = self.drop_path(x)
+
+ x = shortcut + x
+
+ return x, size
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.net = nn.Sequential(OrderedDict([
+ ("fc1", nn.Linear(in_features, hidden_features)),
+ ("act", act_layer()),
+ ("fc2", nn.Linear(hidden_features, out_features))
+ ]))
+
+ def forward(self, x, size):
+ return self.net(x), size
+
+
+class DepthWiseConv2d(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ kernel_size,
+ padding,
+ stride,
+ bias=True,
+ ):
+ super().__init__()
+ self.dw = nn.Conv2d(
+ dim_in, dim_in,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=dim_in,
+ stride=stride,
+ bias=bias
+ )
+
+ def forward(self, x, size):
+ B, N, C = x.shape
+ H, W = size
+ assert N == H * W
+
+ x = self.dw(x.transpose(1, 2).view(B, C, H, W))
+ size = (x.size(-2), x.size(-1))
+ x = x.flatten(2).transpose(1, 2)
+ return x, size
+
+
+class ConvEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(
+ self,
+ patch_size=7,
+ in_chans=3,
+ embed_dim=64,
+ stride=4,
+ padding=2,
+ norm_layer=None,
+ pre_norm=True
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim,
+ kernel_size=patch_size,
+ stride=stride,
+ padding=padding
+ )
+
+ dim_norm = in_chans if pre_norm else embed_dim
+ self.norm = norm_layer(dim_norm) if norm_layer else None
+
+ self.pre_norm = pre_norm
+
+ def forward(self, x, size):
+ H, W = size
+ if len(x.size()) == 3:
+ if self.norm and self.pre_norm:
+ x = self.norm(x)
+ x = rearrange(
+ x, 'b (h w) c -> b c h w',
+ h=H, w=W
+ )
+
+ x = self.proj(x)
+
+ _, _, H, W = x.shape
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ if self.norm and not self.pre_norm:
+ x = self.norm(x)
+
+ return x, (H, W)
+
+
+class ChannelAttention(nn.Module):
+
+ def __init__(self, dim, groups=8, qkv_bias=True):
+ super().__init__()
+
+ self.groups = groups
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x, size):
+ B, N, C = x.shape
+
+ qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * (N ** -0.5)
+ attention = q.transpose(-1, -2) @ k
+ attention = attention.softmax(dim=-1)
+ x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ return x, size
+
+
+class ChannelBlock(nn.Module):
+
+ def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
+ drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ conv_at_attn=True, conv_at_ffn=True):
+ super().__init__()
+
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
+ self.channel_attn = PreNorm(
+ norm_layer(dim),
+ ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
+ drop_path
+ )
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
+ self.ffn = PreNorm(
+ norm_layer(dim),
+ Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
+ drop_path
+ )
+
+ def forward(self, x, size):
+ if self.conv1:
+ x, size = self.conv1(x, size)
+ x, size = self.channel_attn(x, size)
+
+ if self.conv2:
+ x, size = self.conv2(x, size)
+ x, size = self.ffn(x, size)
+
+ return x, size
+
+
+def window_partition(x, window_size: int):
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size: int, H: int, W: int):
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ def __init__(self, dim, num_heads, window_size, qkv_bias=True):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, size):
+
+ H, W = size
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ x = window_partition(x, self.window_size)
+ x = x.view(-1, self.window_size * self.window_size, C)
+
+ # W-MSA/SW-MSA
+ # attn_windows = self.attn(x_windows)
+
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+ attn = self.softmax(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+
+ # merge windows
+ x = x.view(
+ -1, self.window_size, self.window_size, C
+ )
+ x = window_reverse(x, self.window_size, Hp, Wp)
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ return x, size
+
+
+class SpatialBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
+ super().__init__()
+
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
+ self.window_attn = PreNorm(
+ norm_layer(dim),
+ WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
+ drop_path
+ )
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
+ self.ffn = PreNorm(
+ norm_layer(dim),
+ Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
+ drop_path
+ )
+
+ def forward(self, x, size):
+ if self.conv1:
+ x, size = self.conv1(x, size)
+ x, size = self.window_attn(x, size)
+
+ if self.conv2:
+ x, size = self.conv2(x, size)
+ x, size = self.ffn(x, size)
+ return x, size
+
+
+class DaViT(nn.Module):
+ """ DaViT: Dual-Attention Transformer
+
+ Args:
+ img_size (int): Image size, Default: 224.
+ in_chans (int): Number of input image channels. Default: 3.
+ num_classes (int): Number of classes for classification head. Default: 1000.
+ patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
+ patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
+ patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
+ patch_prenorm (tuple(bool)): If True, perform norm before convlution layer. Default: (True, False, False, False).
+ embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
+ num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (4, 8, 12, 16).
+ num_groups (tuple(int)): Number of channel groups in different stages. Default: (4, 8, 12, 16).
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ enable_checkpoint (bool): If True, enable checkpointing. Default: False.
+ conv_at_attn (bool): If True, performe depthwise convolution before attention layer. Default: True.
+ conv_at_ffn (bool): If True, performe depthwise convolution before ffn layer. Default: True.
+ """
+
+ def __init__(
+ self,
+ img_size=224,
+ in_chans=3,
+ num_classes=1000,
+ depths=(1, 1, 3, 1),
+ patch_size=(7, 2, 2, 2),
+ patch_stride=(4, 2, 2, 2),
+ patch_padding=(3, 0, 0, 0),
+ patch_prenorm=(False, False, False, False),
+ embed_dims=(64, 128, 192, 256),
+ num_heads=(3, 6, 12, 24),
+ num_groups=(3, 6, 12, 24),
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ enable_checkpoint=False,
+ conv_at_attn=True,
+ conv_at_ffn=True,
+ out_indices=[],
+ ):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.num_groups = num_groups
+ self.num_stages = len(self.embed_dims)
+ self.enable_checkpoint = enable_checkpoint
+ assert self.num_stages == len(self.num_heads) == len(self.num_groups)
+
+ num_stages = len(embed_dims)
+ self.img_size = img_size
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)]
+
+
+ depth_offset = 0
+ convs = []
+ blocks = []
+ for i in range(num_stages):
+ conv_embed = ConvEmbed(
+ patch_size=patch_size[i],
+ stride=patch_stride[i],
+ padding=patch_padding[i],
+ in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
+ embed_dim=self.embed_dims[i],
+ norm_layer=norm_layer,
+ pre_norm=patch_prenorm[i]
+ )
+ convs.append(conv_embed)
+
+ print(f'=> Depth offset in stage {i}: {depth_offset}')
+ block = MySequential(
+ *[
+ MySequential(OrderedDict([
+ (
+ 'spatial_block', SpatialBlock(
+ embed_dims[i],
+ num_heads[i],
+ window_size,
+ drop_path_rate=dpr[depth_offset+j*2],
+ qkv_bias=qkv_bias,
+ mlp_ratio=mlp_ratio,
+ conv_at_attn=conv_at_attn,
+ conv_at_ffn=conv_at_ffn,
+ )
+ ),
+ (
+ 'channel_block', ChannelBlock(
+ embed_dims[i],
+ num_groups[i],
+ drop_path_rate=dpr[depth_offset+j*2+1],
+ qkv_bias=qkv_bias,
+ mlp_ratio=mlp_ratio,
+ conv_at_attn=conv_at_attn,
+ conv_at_ffn=conv_at_ffn,
+ )
+ )
+ ])) for j in range(depths[i])
+ ]
+ )
+ blocks.append(block)
+ depth_offset += depths[i]*2
+
+ self.convs = nn.ModuleList(convs)
+ self.blocks = nn.ModuleList(blocks)
+
+ self.out_indices = out_indices
+ # self.norms = norm_layer(self.embed_dims[-1])
+ # self.avgpool = nn.AdaptiveAvgPool1d(1)
+ # self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
+ self.apply(self._init_weights)
+
+ @property
+ def dim_out(self):
+ return self.embed_dims[-1]
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, std=0.02)
+ for name, _ in m.named_parameters():
+ if name in ['bias']:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.weight, 1.0)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1.0)
+ nn.init.constant_(m.bias, 0)
+
+ def _try_remap_keys(self, pretrained_dict):
+ remap_keys = {
+ "conv_embeds": "convs",
+ "main_blocks": "blocks",
+ "0.cpe.0.proj": "spatial_block.conv1.fn.dw",
+ "0.attn": "spatial_block.window_attn.fn",
+ "0.cpe.1.proj": "spatial_block.conv2.fn.dw",
+ "0.mlp": "spatial_block.ffn.fn.net",
+ "1.cpe.0.proj": "channel_block.conv1.fn.dw",
+ "1.attn": "channel_block.channel_attn.fn",
+ "1.cpe.1.proj": "channel_block.conv2.fn.dw",
+ "1.mlp": "channel_block.ffn.fn.net",
+ "0.norm1": "spatial_block.window_attn.norm",
+ "0.norm2": "spatial_block.ffn.norm",
+ "1.norm1": "channel_block.channel_attn.norm",
+ "1.norm2": "channel_block.ffn.norm"
+ }
+
+ full_key_mappings = {}
+ for k in pretrained_dict.keys():
+ old_k = k
+ for remap_key in remap_keys.keys():
+ if remap_key in k:
+ print(f'=> Repace {remap_key} with {remap_keys[remap_key]}')
+ k = k.replace(remap_key, remap_keys[remap_key])
+
+ full_key_mappings[old_k] = k
+
+ return full_key_mappings
+
+ def from_state_dict(self, pretrained_dict, pretrained_layers=[], verbose=True):
+ model_dict = self.state_dict()
+ stripped_key = lambda x: x[14:] if x.startswith('image_encoder.') else x
+ full_key_mappings = self._try_remap_keys(pretrained_dict)
+
+ pretrained_dict = {
+ stripped_key(full_key_mappings[k]): v for k, v in pretrained_dict.items()
+ if stripped_key(full_key_mappings[k]) in model_dict.keys()
+ }
+ need_init_state_dict = {}
+ for k, v in pretrained_dict.items():
+ need_init = (
+ k.split('.')[0] in pretrained_layers
+ or pretrained_layers[0] == '*'
+ )
+ if need_init:
+ if verbose:
+ print(f'=> init {k} from pretrained state dict')
+
+ need_init_state_dict[k] = v
+ self.load_state_dict(need_init_state_dict, strict=False)
+
+ def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
+ if os.path.isfile(pretrained):
+ print(f'=> loading pretrained model {pretrained}')
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
+
+ self.from_state_dict(pretrained_dict, pretrained_layers, verbose)
+
+ def forward_features(self, x):
+ input_size = (x.size(2), x.size(3))
+
+ outs = {}
+ for i, (conv, block) in enumerate(zip(self.convs, self.blocks)):
+ x, input_size = conv(x, input_size)
+ if self.enable_checkpoint:
+ x, input_size = checkpoint.checkpoint(block, x, input_size)
+ else:
+ x, input_size = block(x, input_size)
+ if i in self.out_indices:
+ out = x.view(-1, *input_size, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
+ outs["res{}".format(i + 2)] = out
+
+ if len(self.out_indices) == 0:
+ outs["res5"] = x.view(-1, *input_size, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
+
+ return outs
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ # x = self.head(x)
+ return x
+
+class D2DaViT(DaViT, Backbone):
+ def __init__(self, cfg, input_shape):
+
+ spec = cfg['BACKBONE']['DAVIT']
+
+ super().__init__(
+ num_classes=0,
+ depths=spec['DEPTHS'],
+ embed_dims=spec['DIM_EMBED'],
+ num_heads=spec['NUM_HEADS'],
+ num_groups=spec['NUM_GROUPS'],
+ patch_size=spec['PATCH_SIZE'],
+ patch_stride=spec['PATCH_STRIDE'],
+ patch_padding=spec['PATCH_PADDING'],
+ patch_prenorm=spec['PATCH_PRENORM'],
+ drop_path_rate=spec['DROP_PATH_RATE'],
+ img_size=input_shape,
+ window_size=spec.get('WINDOW_SIZE', 7),
+ enable_checkpoint=spec.get('ENABLE_CHECKPOINT', False),
+ conv_at_attn=spec.get('CONV_AT_ATTN', True),
+ conv_at_ffn=spec.get('CONV_AT_FFN', True),
+ out_indices=spec.get('OUT_INDICES', []),
+ )
+
+ self._out_features = cfg['BACKBONE']['DAVIT']['OUT_FEATURES']
+
+ self._out_feature_strides = {
+ "res2": 4,
+ "res3": 8,
+ "res4": 16,
+ "res5": 32,
+ }
+ self._out_feature_channels = {
+ "res2": self.embed_dims[0],
+ "res3": self.embed_dims[1],
+ "res4": self.embed_dims[2],
+ "res5": self.embed_dims[3],
+ }
+
+ def forward(self, x):
+ """
+ Args:
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+ Returns:
+ dict[str->Tensor]: names and the corresponding features
+ """
+ assert (
+ x.dim() == 4
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+ outputs = {}
+ y = super().forward(x)
+
+ for k in y.keys():
+ if k in self._out_features:
+ outputs[k] = y[k]
+ return outputs
+
+ def output_shape(self):
+ return {
+ name: ShapeSpec(
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+ )
+ for name in self._out_features
+ }
+
+ @property
+ def size_divisibility(self):
+ return 32
+
+@register_backbone
+def get_davit_backbone(cfg):
+ davit = D2DaViT(cfg['MODEL'], 224)
+
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
+ logger.info(f'=> init from {filename}')
+ davit.from_pretrained(
+ filename,
+ cfg['MODEL']['BACKBONE']['DAVIT'].get('PRETRAINED_LAYERS', ['*']),
+ cfg['VERBOSE'])
+
+ return davit
\ No newline at end of file
diff --git a/modeling/vision/backbone/focal.py b/modeling/vision/backbone/focal.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb03baa4d0d4e0289d33035f2aa991473572e4e8
--- /dev/null
+++ b/modeling/vision/backbone/focal.py
@@ -0,0 +1,692 @@
+# --------------------------------------------------------
+# FocalNet for Semantic Segmentation
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Jianwei Yang
+# --------------------------------------------------------
+import math
+import time
+import numpy as np
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from detectron2.utils.file_io import PathManager
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+from .build import register_backbone
+
+logger = logging.getLogger(__name__)
+
+class Mlp(nn.Module):
+ """ Multilayer perceptron."""
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+class FocalModulation(nn.Module):
+ """ Focal Modulation
+
+ Args:
+ dim (int): Number of input channels.
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ focal_level (int): Number of focal levels
+ focal_window (int): Focal window size at focal level 1
+ focal_factor (int, default=2): Step to increase the focal window
+ use_postln (bool, default=False): Whether use post-modulation layernorm
+ """
+
+ def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):
+
+ super().__init__()
+ self.dim = dim
+
+ # specific args for focalv3
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+ self.focal_factor = focal_factor
+ self.use_postln_in_modulation = use_postln_in_modulation
+ self.scaling_modulator = scaling_modulator
+
+ self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)
+ self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
+
+ self.act = nn.GELU()
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.focal_layers = nn.ModuleList()
+
+ if self.use_postln_in_modulation:
+ self.ln = nn.LayerNorm(dim)
+
+ for k in range(self.focal_level):
+ kernel_size = self.focal_factor*k + self.focal_window
+ self.focal_layers.append(
+ nn.Sequential(
+ nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim,
+ padding=kernel_size//2, bias=False),
+ nn.GELU(),
+ )
+ )
+
+ def forward(self, x):
+ """ Forward function.
+
+ Args:
+ x: input features with shape of (B, H, W, C)
+ """
+ B, nH, nW, C = x.shape
+ x = self.f(x)
+ x = x.permute(0, 3, 1, 2).contiguous()
+ q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)
+
+ ctx_all = 0
+ for l in range(self.focal_level):
+ ctx = self.focal_layers[l](ctx)
+ ctx_all = ctx_all + ctx*gates[:, l:l+1]
+ ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
+ ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]
+
+ if self.scaling_modulator:
+ ctx_all = ctx_all / (self.focal_level + 1)
+
+ x_out = q * self.h(ctx_all)
+ x_out = x_out.permute(0, 2, 3, 1).contiguous()
+ if self.use_postln_in_modulation:
+ x_out = self.ln(x_out)
+ x_out = self.proj(x_out)
+ x_out = self.proj_drop(x_out)
+ return x_out
+
+class FocalModulationBlock(nn.Module):
+ """ Focal Modulation Block.
+
+ Args:
+ dim (int): Number of input channels.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ focal_level (int): number of focal levels
+ focal_window (int): focal kernel size at level 1
+ """
+
+ def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ focal_level=2, focal_window=9,
+ use_postln=False, use_postln_in_modulation=False,
+ scaling_modulator=False,
+ use_layerscale=False,
+ layerscale_value=1e-4):
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.focal_window = focal_window
+ self.focal_level = focal_level
+ self.use_postln = use_postln
+ self.use_layerscale = use_layerscale
+
+ self.norm1 = norm_layer(dim)
+ self.modulation = FocalModulation(
+ dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.H = None
+ self.W = None
+
+ self.gamma_1 = 1.0
+ self.gamma_2 = 1.0
+ if self.use_layerscale:
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+
+ def forward(self, x):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ if not self.use_postln:
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # FM
+ x = self.modulation(x).view(B, H * W, C)
+ if self.use_postln:
+ x = self.norm1(x)
+
+ # FFN
+ x = shortcut + self.drop_path(self.gamma_1 * x)
+
+ if self.use_postln:
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+
+ return x
+
+class BasicLayer(nn.Module):
+ """ A basic focal modulation layer for one stage.
+
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ focal_level (int): Number of focal levels
+ focal_window (int): Focal window size at focal level 1
+ use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ """
+
+ def __init__(self,
+ dim,
+ depth,
+ mlp_ratio=4.,
+ drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ focal_window=9,
+ focal_level=2,
+ use_conv_embed=False,
+ use_postln=False,
+ use_postln_in_modulation=False,
+ scaling_modulator=False,
+ use_layerscale=False,
+ use_checkpoint=False
+ ):
+ super().__init__()
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ FocalModulationBlock(
+ dim=dim,
+ mlp_ratio=mlp_ratio,
+ drop=drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ focal_window=focal_window,
+ focal_level=focal_level,
+ use_postln=use_postln,
+ use_postln_in_modulation=use_postln_in_modulation,
+ scaling_modulator=scaling_modulator,
+ use_layerscale=use_layerscale,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ patch_size=2,
+ in_chans=dim, embed_dim=2*dim,
+ use_conv_embed=use_conv_embed,
+ norm_layer=norm_layer,
+ is_stem=False
+ )
+
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ if self.downsample is not None:
+ x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
+ x_down = self.downsample(x_reshaped)
+ x_down = x_down.flatten(2).transpose(1, 2)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False
+ is_stem (bool): Is the stem block or not.
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ if use_conv_embed:
+ # if we choose to use conv embedding, then we treat the stem and non-stem differently
+ if is_stem:
+ kernel_size = 7; padding = 2; stride = 4
+ else:
+ kernel_size = 3; padding = 1; stride = 2
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+ else:
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+class FocalNet(nn.Module):
+ """ FocalNet backbone.
+
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in absolute postion embedding. Default 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ drop_rate (float): Dropout rate.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ focal_levels (Sequence[int]): Number of focal levels at four stages
+ focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
+ use_conv_embed (bool): Whether use overlapped convolution for patch embedding
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ pretrain_img_size=1600,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ mlp_ratio=4.,
+ drop_rate=0.,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ patch_norm=True,
+ out_indices=[0, 1, 2, 3],
+ frozen_stages=-1,
+ focal_levels=[2,2,2,2],
+ focal_windows=[9,9,9,9],
+ use_conv_embed=False,
+ use_postln=False,
+ use_postln_in_modulation=False,
+ scaling_modulator=False,
+ use_layerscale=False,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None,
+ use_conv_embed=use_conv_embed, is_stem=True)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ mlp_ratio=mlp_ratio,
+ drop=drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
+ focal_window=focal_windows[i_layer],
+ focal_level=focal_levels[i_layer],
+ use_conv_embed=use_conv_embed,
+ use_postln=use_postln,
+ use_postln_in_modulation=use_postln_in_modulation,
+ scaling_modulator=scaling_modulator,
+ use_layerscale=use_layerscale,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ if isinstance(pretrained, str):
+ self.apply(_init_weights)
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ self.apply(_init_weights)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
+ model_dict = self.state_dict()
+
+ missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]
+ logger.info(f'=> Missed keys {missed_dict}')
+ unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]
+ logger.info(f'=> Unexpected keys {unexpected_dict}')
+
+ pretrained_dict = {
+ k: v for k, v in pretrained_dict.items()
+ if k in model_dict.keys()
+ }
+
+ need_init_state_dict = {}
+ for k, v in pretrained_dict.items():
+ need_init = (
+ (
+ k.split('.')[0] in pretrained_layers
+ or pretrained_layers[0] == '*'
+ )
+ and 'relative_position_index' not in k
+ and 'attn_mask' not in k
+ )
+
+ if need_init:
+ # if verbose:
+ # logger.info(f'=> init {k} from {pretrained}')
+
+ if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size():
+ table_pretrained = v
+ table_current = model_dict[k]
+ fsize1 = table_pretrained.shape[2]
+ fsize2 = table_current.shape[2]
+
+ # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv
+ if fsize1 < fsize2:
+ table_pretrained_resized = torch.zeros(table_current.shape)
+ table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained
+ v = table_pretrained_resized
+ elif fsize1 > fsize2:
+ table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2]
+ v = table_pretrained_resized
+
+
+ if ("modulation.f" in k or "pre_conv" in k):
+ table_pretrained = v
+ table_current = model_dict[k]
+ if table_pretrained.shape != table_current.shape:
+ if len(table_pretrained.shape) == 2:
+ dim = table_pretrained.shape[1]
+ assert table_current.shape[1] == dim
+ L1 = table_pretrained.shape[0]
+ L2 = table_current.shape[0]
+
+ if L1 < L2:
+ table_pretrained_resized = torch.zeros(table_current.shape)
+ # copy for linear project
+ table_pretrained_resized[:2*dim] = table_pretrained[:2*dim]
+ # copy for global token gating
+ table_pretrained_resized[-1] = table_pretrained[-1]
+ # copy for first multiple focal levels
+ table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]
+ # reassign pretrained weights
+ v = table_pretrained_resized
+ elif L1 > L2:
+ raise NotImplementedError
+ elif len(table_pretrained.shape) == 1:
+ dim = table_pretrained.shape[0]
+ L1 = table_pretrained.shape[0]
+ L2 = table_current.shape[0]
+ if L1 < L2:
+ table_pretrained_resized = torch.zeros(table_current.shape)
+ # copy for linear project
+ table_pretrained_resized[:dim] = table_pretrained[:dim]
+ # copy for global token gating
+ table_pretrained_resized[-1] = table_pretrained[-1]
+ # copy for first multiple focal levels
+ # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]
+ # reassign pretrained weights
+ v = table_pretrained_resized
+ elif L1 > L2:
+ raise NotImplementedError
+
+ need_init_state_dict[k] = v
+
+ self.load_state_dict(need_init_state_dict, strict=False)
+
+
+ def forward(self, x):
+ """Forward function."""
+ tic = time.time()
+ x = self.patch_embed(x)
+ Wh, Ww = x.size(2), x.size(3)
+
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = {}
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs["res{}".format(i + 2)] = out
+
+ if len(self.out_indices) == 0:
+ outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+
+ toc = time.time()
+ return outs
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(FocalNet, self).train(mode)
+ self._freeze_stages()
+
+
+class D2FocalNet(FocalNet, Backbone):
+ def __init__(self, cfg, input_shape):
+
+ pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE']
+ patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE']
+ in_chans = 3
+ embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM']
+ depths = cfg['BACKBONE']['FOCAL']['DEPTHS']
+ mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO']
+ drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE']
+ drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE']
+ norm_layer = nn.LayerNorm
+ patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM']
+ use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT']
+ out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES']
+ scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False)
+
+ super().__init__(
+ pretrain_img_size,
+ patch_size,
+ in_chans,
+ embed_dim,
+ depths,
+ mlp_ratio,
+ drop_rate,
+ drop_path_rate,
+ norm_layer,
+ patch_norm,
+ out_indices,
+ focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'],
+ focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'],
+ use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'],
+ use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'],
+ use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'],
+ scaling_modulator=scaling_modulator,
+ use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'],
+ use_checkpoint=use_checkpoint,
+ )
+
+ self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES']
+
+ self._out_feature_strides = {
+ "res2": 4,
+ "res3": 8,
+ "res4": 16,
+ "res5": 32,
+ }
+ self._out_feature_channels = {
+ "res2": self.num_features[0],
+ "res3": self.num_features[1],
+ "res4": self.num_features[2],
+ "res5": self.num_features[3],
+ }
+
+ def forward(self, x):
+ """
+ Args:
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+ Returns:
+ dict[str->Tensor]: names and the corresponding features
+ """
+ assert (
+ x.dim() == 4
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+ outputs = {}
+ y = super().forward(x)
+ for k in y.keys():
+ if k in self._out_features:
+ outputs[k] = y[k]
+ return outputs
+
+ def output_shape(self):
+ return {
+ name: ShapeSpec(
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+ )
+ for name in self._out_features
+ }
+
+ @property
+ def size_divisibility(self):
+ return 32
+
+@register_backbone
+def get_focal_backbone(cfg):
+ focal = D2FocalNet(cfg['MODEL'], 224)
+
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
+ logger.info(f'=> init from {filename}')
+ with PathManager.open(filename, "rb") as f:
+ ckpt = torch.load(f)['model']
+ focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])
+
+ return focal
\ No newline at end of file
diff --git a/modeling/vision/backbone/focal_dw.py b/modeling/vision/backbone/focal_dw.py
new file mode 100644
index 0000000000000000000000000000000000000000..a54c2116fd4780d948396b6dd72840421266cd50
--- /dev/null
+++ b/modeling/vision/backbone/focal_dw.py
@@ -0,0 +1,789 @@
+# --------------------------------------------------------
+# FocalNet for Semantic Segmentation
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Jianwei Yang
+# --------------------------------------------------------
+import math
+import time
+import numpy as np
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from detectron2.utils.file_io import PathManager
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+from .build import register_backbone
+
+logger = logging.getLogger(__name__)
+
+class Mlp(nn.Module):
+ """ Multilayer perceptron."""
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+class FocalModulation(nn.Module):
+ """ Focal Modulation
+
+ Args:
+ dim (int): Number of input channels.
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ focal_level (int): Number of focal levels
+ focal_window (int): Focal window size at focal level 1
+ focal_factor (int, default=2): Step to increase the focal window
+ use_postln (bool, default=False): Whether use post-modulation layernorm
+ """
+
+ def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False):
+
+ super().__init__()
+ self.dim = dim
+
+ # specific args for focalv3
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+ self.focal_factor = focal_factor
+ self.use_postln_in_modulation = use_postln_in_modulation
+ self.scaling_modulator = scaling_modulator
+
+ self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)
+ self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
+
+ self.act = nn.GELU()
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.focal_layers = nn.ModuleList()
+
+ if self.use_postln_in_modulation:
+ self.ln = nn.LayerNorm(dim)
+
+ for k in range(self.focal_level):
+ kernel_size = self.focal_factor*k + self.focal_window
+ self.focal_layers.append(
+ nn.Sequential(
+ nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim,
+ padding=kernel_size//2, bias=False),
+ nn.GELU(),
+ )
+ )
+
+ def forward(self, x):
+ """ Forward function.
+
+ Args:
+ x: input features with shape of (B, H, W, C)
+ """
+ B, nH, nW, C = x.shape
+ x = self.f(x)
+ x = x.permute(0, 3, 1, 2).contiguous()
+ q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)
+
+ ctx_all = 0
+ for l in range(self.focal_level):
+ ctx = self.focal_layers[l](ctx)
+ ctx_all = ctx_all + ctx*gates[:, l:l+1]
+ ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
+ ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]
+
+ if self.scaling_modulator:
+ ctx_all = ctx_all / (self.focal_level + 1)
+
+ x_out = q * self.h(ctx_all)
+ x_out = x_out.permute(0, 2, 3, 1).contiguous()
+ if self.use_postln_in_modulation:
+ x_out = self.ln(x_out)
+ x_out = self.proj(x_out)
+ x_out = self.proj_drop(x_out)
+ return x_out
+
+class FocalModulationBlock(nn.Module):
+ """ Focal Modulation Block.
+
+ Args:
+ dim (int): Number of input channels.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ focal_level (int): number of focal levels
+ focal_window (int): focal kernel size at level 1
+ """
+
+ def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ focal_level=2, focal_window=9,
+ use_postln=False, use_postln_in_modulation=False,
+ scaling_modulator=False,
+ use_layerscale=False,
+ layerscale_value=1e-4):
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.focal_window = focal_window
+ self.focal_level = focal_level
+ self.use_postln = use_postln
+ self.use_layerscale = use_layerscale
+
+ self.dw1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
+ self.norm1 = norm_layer(dim)
+ self.modulation = FocalModulation(
+ dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator
+ )
+
+ self.dw2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.H = None
+ self.W = None
+
+ self.gamma_1 = 1.0
+ self.gamma_2 = 1.0
+ if self.use_layerscale:
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+
+ def forward(self, x):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
+ x = x + self.dw1(x)
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)
+
+ shortcut = x
+ if not self.use_postln:
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # FM
+ x = self.modulation(x).view(B, H * W, C)
+ x = shortcut + self.drop_path(self.gamma_1 * x)
+ if self.use_postln:
+ x = self.norm1(x)
+
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
+ x = x + self.dw2(x)
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)
+
+ if not self.use_postln:
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_2 * self.mlp(x))
+ x = self.norm2(x)
+
+ return x
+
+class BasicLayer(nn.Module):
+ """ A basic focal modulation layer for one stage.
+
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ focal_level (int): Number of focal levels
+ focal_window (int): Focal window size at focal level 1
+ use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ """
+
+ def __init__(self,
+ dim,
+ depth,
+ mlp_ratio=4.,
+ drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ focal_window=9,
+ focal_level=2,
+ use_conv_embed=False,
+ use_postln=False,
+ use_postln_in_modulation=False,
+ scaling_modulator=False,
+ use_layerscale=False,
+ use_checkpoint=False,
+ use_pre_norm=False,
+ ):
+ super().__init__()
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ FocalModulationBlock(
+ dim=dim,
+ mlp_ratio=mlp_ratio,
+ drop=drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ focal_window=focal_window,
+ focal_level=focal_level,
+ use_postln=use_postln,
+ use_postln_in_modulation=use_postln_in_modulation,
+ scaling_modulator=scaling_modulator,
+ use_layerscale=use_layerscale,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ patch_size=2,
+ in_chans=dim, embed_dim=2*dim,
+ use_conv_embed=use_conv_embed,
+ norm_layer=norm_layer,
+ is_stem=False,
+ use_pre_norm=use_pre_norm
+ )
+
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ if self.downsample is not None:
+ x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
+ x_down = self.downsample(x_reshaped)
+ x_down = x_down.flatten(2).transpose(1, 2)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+# class PatchEmbed(nn.Module):
+# r""" Image to Patch Embedding
+
+# Args:
+# img_size (int): Image size. Default: 224.
+# patch_size (int): Patch token size. Default: 4.
+# in_chans (int): Number of input image channels. Default: 3.
+# embed_dim (int): Number of linear projection output channels. Default: 96.
+# norm_layer (nn.Module, optional): Normalization layer. Default: None
+# """
+
+# def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96,
+# use_conv_embed=False, norm_layer=None, is_stem=False, use_pre_norm=False):
+# super().__init__()
+# patch_size = to_2tuple(patch_size)
+# patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+# self.img_size = img_size
+# self.patch_size = patch_size
+# self.patches_resolution = patches_resolution
+# self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+# self.in_chans = in_chans
+# self.embed_dim = embed_dim
+# self.use_pre_norm = use_pre_norm
+
+# if use_conv_embed:
+# # if we choose to use conv embedding, then we treat the stem and non-stem differently
+# if is_stem:
+# kernel_size = 7; padding = 3; stride = 4
+# else:
+# kernel_size = 3; padding = 1; stride = 2
+# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+# else:
+# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+# if self.use_pre_norm:
+# if norm_layer is not None:
+# self.norm = norm_layer(in_chans)
+# else:
+# self.norm = None
+# else:
+# if norm_layer is not None:
+# self.norm = norm_layer(embed_dim)
+# else:
+# self.norm = None
+
+# def forward(self, x):
+# B, C, H, W = x.shape
+# # FIXME look at relaxing size constraints
+# assert H == self.img_size[0] and W == self.img_size[1], \
+# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+
+# if self.use_pre_norm:
+# if self.norm is not None:
+# x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
+# x = self.norm(x).transpose(1, 2).view(B, C, H, W)
+# x = self.proj(x).flatten(2).transpose(1, 2)
+# else:
+# x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+# if self.norm is not None:
+# x = self.norm(x)
+# return x
+
+# def flops(self):
+# Ho, Wo = self.patches_resolution
+# flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+# if self.norm is not None:
+# flops += Ho * Wo * self.embed_dim
+# return flops
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False
+ is_stem (bool): Is the stem block or not.
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False, use_pre_norm=False):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.use_pre_norm = use_pre_norm
+
+ if use_conv_embed:
+ # if we choose to use conv embedding, then we treat the stem and non-stem differently
+ if is_stem:
+ kernel_size = 7; padding = 3; stride = 4
+ else:
+ kernel_size = 3; padding = 1; stride = 2
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+ else:
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ if self.use_pre_norm:
+ if norm_layer is not None:
+ self.norm = norm_layer(in_chans)
+ else:
+ self.norm = None
+ else:
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ B, C, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ if self.use_pre_norm:
+ if self.norm is not None:
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
+ x = self.norm(x).transpose(1, 2).view(B, C, H, W)
+ x = self.proj(x)
+ else:
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+class FocalNet(nn.Module):
+ """ FocalNet backbone.
+
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in absolute postion embedding. Default 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ drop_rate (float): Dropout rate.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ focal_levels (Sequence[int]): Number of focal levels at four stages
+ focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
+ use_conv_embed (bool): Whether use overlapped convolution for patch embedding
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ pretrain_img_size=1600,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ mlp_ratio=4.,
+ drop_rate=0.,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ patch_norm=True,
+ out_indices=[0, 1, 2, 3],
+ frozen_stages=-1,
+ focal_levels=[2,2,2,2],
+ focal_windows=[9,9,9,9],
+ use_pre_norms=[False, False, False, False],
+ use_conv_embed=False,
+ use_postln=False,
+ use_postln_in_modulation=False,
+ scaling_modulator=False,
+ use_layerscale=False,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None,
+ use_conv_embed=use_conv_embed, is_stem=True, use_pre_norm=False)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ mlp_ratio=mlp_ratio,
+ drop=drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
+ focal_window=focal_windows[i_layer],
+ focal_level=focal_levels[i_layer],
+ use_pre_norm=use_pre_norms[i_layer],
+ use_conv_embed=use_conv_embed,
+ use_postln=use_postln,
+ use_postln_in_modulation=use_postln_in_modulation,
+ scaling_modulator=scaling_modulator,
+ use_layerscale=use_layerscale,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+ # self.norm = norm_layer(num_features[-1])
+
+ # add a norm layer for each output
+ for i_layer in self.out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ if isinstance(pretrained, str):
+ self.apply(_init_weights)
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ self.apply(_init_weights)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
+ model_dict = self.state_dict()
+
+ missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict]
+ logger.info(f'=> Missed keys {missed_dict}')
+ unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict]
+ logger.info(f'=> Unexpected keys {unexpected_dict}')
+
+ pretrained_dict = {
+ k: v for k, v in pretrained_dict.items()
+ if k in model_dict.keys()
+ }
+
+ need_init_state_dict = {}
+ for k, v in pretrained_dict.items():
+ need_init = (
+ (
+ k.split('.')[0] in pretrained_layers
+ or pretrained_layers[0] == '*'
+ )
+ and 'relative_position_index' not in k
+ and 'attn_mask' not in k
+ )
+
+ if need_init:
+ # if verbose:
+ # logger.info(f'=> init {k} from {pretrained}')
+
+ if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size():
+ table_pretrained = v
+ table_current = model_dict[k]
+ fsize1 = table_pretrained.shape[2]
+ fsize2 = table_current.shape[2]
+
+ # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv
+ if fsize1 < fsize2:
+ table_pretrained_resized = torch.zeros(table_current.shape)
+ table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained
+ v = table_pretrained_resized
+ elif fsize1 > fsize2:
+ table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2]
+ v = table_pretrained_resized
+
+
+ if ("modulation.f" in k or "pre_conv" in k):
+ table_pretrained = v
+ table_current = model_dict[k]
+ if table_pretrained.shape != table_current.shape:
+ if len(table_pretrained.shape) == 2:
+ dim = table_pretrained.shape[1]
+ assert table_current.shape[1] == dim
+ L1 = table_pretrained.shape[0]
+ L2 = table_current.shape[0]
+
+ if L1 < L2:
+ table_pretrained_resized = torch.zeros(table_current.shape)
+ # copy for linear project
+ table_pretrained_resized[:2*dim] = table_pretrained[:2*dim]
+ # copy for global token gating
+ table_pretrained_resized[-1] = table_pretrained[-1]
+ # copy for first multiple focal levels
+ table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]
+ # reassign pretrained weights
+ v = table_pretrained_resized
+ elif L1 > L2:
+ raise NotImplementedError
+ elif len(table_pretrained.shape) == 1:
+ dim = table_pretrained.shape[0]
+ L1 = table_pretrained.shape[0]
+ L2 = table_current.shape[0]
+ if L1 < L2:
+ table_pretrained_resized = torch.zeros(table_current.shape)
+ # copy for linear project
+ table_pretrained_resized[:dim] = table_pretrained[:dim]
+ # copy for global token gating
+ table_pretrained_resized[-1] = table_pretrained[-1]
+ # copy for first multiple focal levels
+ # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1]
+ # reassign pretrained weights
+ v = table_pretrained_resized
+ elif L1 > L2:
+ raise NotImplementedError
+
+ need_init_state_dict[k] = v
+
+ self.load_state_dict(need_init_state_dict, strict=False)
+
+
+ def forward(self, x):
+ """Forward function."""
+ tic = time.time()
+ x = self.patch_embed(x)
+ Wh, Ww = x.size(2), x.size(3)
+
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = {}
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs["res{}".format(i + 2)] = out
+
+ if len(self.out_indices) == 0:
+ outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+
+ toc = time.time()
+ return outs
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(FocalNet, self).train(mode)
+ self._freeze_stages()
+
+
+class D2FocalNet(FocalNet, Backbone):
+ def __init__(self, cfg, input_shape):
+
+ pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE']
+ patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE']
+ in_chans = 3
+ embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM']
+ depths = cfg['BACKBONE']['FOCAL']['DEPTHS']
+ mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO']
+ drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE']
+ drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE']
+ norm_layer = nn.LayerNorm
+ patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM']
+ use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT']
+ out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES']
+ scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False)
+
+ super().__init__(
+ pretrain_img_size,
+ patch_size,
+ in_chans,
+ embed_dim,
+ depths,
+ mlp_ratio,
+ drop_rate,
+ drop_path_rate,
+ norm_layer,
+ patch_norm,
+ out_indices,
+ focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'],
+ focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'],
+ use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'],
+ use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'],
+ use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'],
+ scaling_modulator=scaling_modulator,
+ use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'],
+ use_checkpoint=use_checkpoint,
+ )
+
+ self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES']
+
+ self._out_feature_strides = {
+ "res2": 4,
+ "res3": 8,
+ "res4": 16,
+ "res5": 32,
+ }
+ self._out_feature_channels = {
+ "res2": self.num_features[0],
+ "res3": self.num_features[1],
+ "res4": self.num_features[2],
+ "res5": self.num_features[3],
+ }
+
+ def forward(self, x):
+ """
+ Args:
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+ Returns:
+ dict[str->Tensor]: names and the corresponding features
+ """
+ assert (
+ x.dim() == 4
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+ outputs = {}
+ y = super().forward(x)
+ for k in y.keys():
+ if k in self._out_features:
+ outputs[k] = y[k]
+ return outputs
+
+ def output_shape(self):
+ return {
+ name: ShapeSpec(
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+ )
+ for name in self._out_features
+ }
+
+ @property
+ def size_divisibility(self):
+ return 32
+
+@register_backbone
+def get_focal_backbone(cfg):
+ focal = D2FocalNet(cfg['MODEL'], 224)
+
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
+ logger.info(f'=> init from {filename}')
+ with PathManager.open(filename, "rb") as f:
+ ckpt = torch.load(f)['model']
+ focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])
+
+ return focal
\ No newline at end of file
diff --git a/modeling/vision/backbone/vit.py b/modeling/vision/backbone/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..af9bf5d7e0e3610b751a042bdf8d509d49b5e3c2
--- /dev/null
+++ b/modeling/vision/backbone/vit.py
@@ -0,0 +1,590 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Type
+from functools import partial
+
+from .common import LayerNorm2d, MLPBlock
+
+from detectron2.utils.file_io import PathManager
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+from .build import register_backbone
+
+logger = logging.getLogger(__name__)
+
+# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
+class ImageEncoderViT(nn.Module):
+ def __init__(
+ self,
+ img_size: int = 1024,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ out_chans: int = 256,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_abs_pos: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ global_attn_indexes: Tuple[int, ...] = (),
+ ) -> None:
+ """
+ Args:
+ img_size (int): Input image size.
+ patch_size (int): Patch size.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ depth (int): Depth of ViT.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_abs_pos (bool): If True, use absolute positional embeddings.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks.
+ global_attn_indexes (list): Indexes for blocks using global attention.
+ """
+ super().__init__()
+ self.img_size = img_size
+
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+
+ self.pos_embed: Optional[nn.Parameter] = None
+ if use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+ )
+
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ block = Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ window_size=window_size if i not in global_attn_indexes else 0,
+ input_size=(img_size // patch_size, img_size // patch_size),
+ )
+ self.blocks.append(block)
+
+ self.neck = nn.Sequential(
+ nn.Conv2d(
+ embed_dim,
+ out_chans,
+ kernel_size=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ nn.Conv2d(
+ out_chans,
+ out_chans,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.neck(x.permute(0, 3, 1, 2))
+
+ return x
+
+
+class Block(nn.Module):
+ """Transformer blocks with support of window attention and residual propagation blocks"""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks. If it equals 0, then
+ use global attention.
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
+ positional parameter size.
+ """
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ input_size=input_size if window_size == 0 else (window_size, window_size),
+ )
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+ self.window_size = window_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x
+ x = self.norm1(x)
+ # Window partition
+ if self.window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, self.window_size)
+
+ x = self.attn(x)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+ x = shortcut + x
+ x = x + self.mlp(self.norm2(x))
+
+ return x
+
+
+class Attention(nn.Module):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
+ positional parameter size.
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ assert (
+ input_size is not None
+ ), "Input size must be provided if using relative positional encoding."
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, H, W, _ = x.shape
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+
+ if self.use_rel_pos:
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+ x = self.proj(x)
+
+ return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows, (Hp, Wp)
+
+
+def window_unpartition(
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+ """
+ Window unpartition into original sequences and removing padding.
+ Args:
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+ Args:
+ q_size (int): size of query q.
+ k_size (int): size of key k.
+ rel_pos (Tensor): relative position embeddings (L, C).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+ attn: torch.Tensor,
+ q: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+ Args:
+ attn (Tensor): attention map.
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+ Returns:
+ attn (Tensor): attention map with added relative positional embeddings.
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+ attn = (
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+ ).view(B, q_h * q_w, k_h * k_w)
+
+ return attn
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int] = (0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
+
+class SimpleFPN(nn.Module):
+ def __init__(self, in_dim=768, out_dims=[128, 256, 512, 1024]):
+ super().__init__()
+ self.down_4_chan = max(out_dims[0]*2, in_dim // 2)
+ self.down_4 = nn.Sequential(
+ nn.ConvTranspose2d(in_dim, self.down_4_chan, 2, stride=2),
+ nn.GroupNorm(1, self.down_4_chan),
+ nn.GELU(),
+ nn.ConvTranspose2d(self.down_4_chan, self.down_4_chan // 2, 2, stride=2),
+ nn.GroupNorm(1, self.down_4_chan // 2),
+ nn.Conv2d(self.down_4_chan // 2, out_dims[0], 1),
+ nn.GroupNorm(1, out_dims[0]),
+ nn.GELU()
+ )
+ self.down_8_chan = max(out_dims[1], in_dim // 2)
+ self.down_8 = nn.Sequential(
+ nn.ConvTranspose2d(in_dim, self.down_8_chan, 2, stride=2),
+ nn.GroupNorm(1, self.down_8_chan),
+ nn.Conv2d(self.down_8_chan, out_dims[1], 1),
+ nn.GroupNorm(1, out_dims[1]),
+ nn.GELU()
+ )
+ self.down_16 = nn.Sequential(
+ nn.Conv2d(in_dim, out_dims[2], 1),
+ nn.GroupNorm(1, out_dims[2]),
+ nn.GELU()
+ )
+ self.down_32_chan = max(out_dims[3], in_dim * 2)
+ self.down_32 = nn.Sequential(
+ nn.Conv2d(in_dim, self.down_32_chan, 2, stride=2),
+ nn.GroupNorm(1, self.down_32_chan),
+ nn.Conv2d(self.down_32_chan, out_dims[3], 1),
+ nn.GroupNorm(1, out_dims[3]),
+ nn.GELU()
+ )
+
+ self.init_weights()
+
+ def init_weights(self):
+ # TODO
+ pass
+
+ def forward(self, x):
+ x_down_4 = self.down_4(x)
+ x_down_8 = self.down_8(x)
+ x_down_16 = self.down_16(x)
+ x_down_32 = self.down_32(x)
+
+ return {
+ 'res2': x_down_4,
+ 'res3': x_down_8,
+ 'res4': x_down_16,
+ 'res5': x_down_32
+ }
+
+
+class D2ViT(ImageEncoderViT, Backbone):
+ def __init__(self, cfg, input_shape):
+ size = cfg['BACKBONE']['VIT']['SIZE']
+ if size == "base":
+ encoder_depth = 12
+ encoder_embed_dim = 768
+ encoder_num_heads = 12
+ encoder_global_attn_indexes = [2, 5, 8, 11]
+ neck_in_dim=768
+ neck_out_dims=[128, 256, 512, 1024]
+ elif size == "large":
+ encoder_embed_dim = 1024
+ encoder_depth = 24
+ encoder_num_heads = 16
+ encoder_global_attn_indexes = [5, 11, 17, 23]
+ neck_in_dim=1024
+ neck_out_dims=[128, 256, 512, 1024]
+ elif size == "huge":
+ encoder_embed_dim = 1280
+ encoder_depth = 32
+ encoder_num_heads = 16
+ encoder_global_attn_indexes = [7, 15, 23, 31]
+ neck_in_dim=1280
+ neck_out_dims=[128, 256, 512, 1024]
+
+ prompt_embed_dim = 256
+ image_size = 1024
+ vit_patch_size = 16
+ image_embedding_size = image_size // vit_patch_size
+
+ super().__init__(
+ depth=encoder_depth,
+ embed_dim=encoder_embed_dim,
+ img_size=image_size,
+ mlp_ratio=4,
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+ num_heads=encoder_num_heads,
+ patch_size=vit_patch_size,
+ qkv_bias=True,
+ use_rel_pos=True,
+ global_attn_indexes=encoder_global_attn_indexes,
+ window_size=14,
+ out_chans=prompt_embed_dim,
+ )
+
+ self.neck = SimpleFPN(in_dim=neck_in_dim, out_dims=neck_out_dims)
+
+ self._out_features = cfg['BACKBONE']['VIT']['OUT_FEATURES']
+
+ self._out_feature_strides = {
+ "res2": 4,
+ "res3": 8,
+ "res4": 16,
+ "res5": 32,
+ }
+ self._out_feature_channels = {
+ "res2": neck_out_dims[0],
+ "res3": neck_out_dims[1],
+ "res4": neck_out_dims[2],
+ "res5": neck_out_dims[3],
+ }
+
+ def forward(self, x):
+ """
+ Args:
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+ Returns:
+ dict[str->Tensor]: names and the corresponding features
+ """
+ assert (
+ x.dim() == 4
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+ outputs = {}
+ y = super().forward(x)
+ for k in y.keys():
+ if k in self._out_features:
+ outputs[k] = y[k]
+ return outputs
+
+ def output_shape(self):
+ return {
+ name: ShapeSpec(
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+ )
+ for name in self._out_features
+ }
+
+ def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
+ model_dict = self.state_dict()
+ pretrained_dict = pretrained_dict['model'] if 'model' in pretrained_dict else pretrained_dict
+ pretrained_dict = {k.replace('image_encoder.', ''):v for k,v in pretrained_dict.items()}
+ pretrained_dict = {
+ k: v for k, v in pretrained_dict.items()
+ if k in model_dict.keys()
+ }
+ need_init_state_dict = {}
+ for k, v in pretrained_dict.items():
+ need_init = (
+ (
+ k.split('.')[0] in pretrained_layers
+ or pretrained_layers[0] == '*'
+ )
+ and 'relative_position_index' not in k
+ and 'attn_mask' not in k
+ )
+ if need_init:
+ need_init_state_dict[k] = v
+ logger.info(f'=> loaded keys {need_init_state_dict.keys()}')
+ unloaded_keys = set(model_dict.keys()) - set(need_init_state_dict.keys())
+ logger.info(f'=> unloaded keys {unloaded_keys}')
+ self.load_state_dict(need_init_state_dict, strict=False)
+
+ @property
+ def size_divisibility(self):
+ return 32
+
+@register_backbone
+def get_vit_backbone(cfg):
+ vit = D2ViT(cfg['MODEL'], 224)
+
+ if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True:
+ filename = cfg['MODEL']['BACKBONE']['PRETRAINED']
+ assert os.path.isfile(filename), f"=> no checkpoint found at '{filename}'"
+ logger.info(f'=> init from {filename}')
+ with PathManager.open(filename, "rb") as f:
+ ckpt = torch.load(f)
+ vit.load_weights(ckpt, cfg['MODEL']['BACKBONE']['VIT'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE'])
+
+ return vit
diff --git a/modeling/vision/encoder/__init__.py b/modeling/vision/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..89af46390df7beeb8b74585408fe95ab9511d444
--- /dev/null
+++ b/modeling/vision/encoder/__init__.py
@@ -0,0 +1,15 @@
+from .transformer_encoder_fpn import *
+try:
+ from .transformer_encoder_deform import *
+except:
+ print('Deformable Transformer Encoder is not available.')
+from .build import *
+
+
+def build_encoder(config, *args, **kwargs):
+ model_name = config['MODEL']['ENCODER']['NAME']
+
+ if not is_model(model_name):
+ raise ValueError(f'Unkown model: {model_name}')
+
+ return model_entrypoints(model_name)(config, *args, **kwargs)
\ No newline at end of file
diff --git a/modeling/vision/encoder/build.py b/modeling/vision/encoder/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced42d145632d1e4963aad2881653bb8d90e34ce
--- /dev/null
+++ b/modeling/vision/encoder/build.py
@@ -0,0 +1,14 @@
+_model_entrypoints = {}
+
+
+def register_encoder(fn):
+ module_name_split = fn.__module__.split('.')
+ model_name = module_name_split[-1]
+ _model_entrypoints[model_name] = fn
+ return fn
+
+def model_entrypoints(model_name):
+ return _model_entrypoints[model_name]
+
+def is_model(model_name):
+ return model_name in _model_entrypoints
\ No newline at end of file
diff --git a/modeling/vision/encoder/ops/=3.8 b/modeling/vision/encoder/ops/=3.8
new file mode 100644
index 0000000000000000000000000000000000000000..40bd241508a2addc6ce1109f9f52e98c73a9d79a
--- /dev/null
+++ b/modeling/vision/encoder/ops/=3.8
@@ -0,0 +1,56 @@
+Collecting package metadata (current_repodata.json): ...working... done
+Solving environment: ...working... done
+
+## Package Plan ##
+
+ environment location: /home/theodorezhao/miniconda3/envs/seem
+
+ added / updated specs:
+ - python
+
+
+The following packages will be downloaded:
+
+ package | build
+ ---------------------------|-----------------
+ expat-2.5.0 | h6a678d5_0 172 KB
+ pip-23.3 | py312h06a4308_0 3.3 MB
+ python-3.12.0 | h996f2a0_0 35.0 MB
+ setuptools-68.0.0 | py312h06a4308_0 1.2 MB
+ wheel-0.37.1 | pyhd3eb1b0_0 33 KB
+ ------------------------------------------------------------
+ Total: 39.6 MB
+
+The following NEW packages will be INSTALLED:
+
+ _libgcc_mutex pkgs/main/linux-64::_libgcc_mutex-0.1-main
+ _openmp_mutex pkgs/main/linux-64::_openmp_mutex-5.1-1_gnu
+ bzip2 pkgs/main/linux-64::bzip2-1.0.8-h7b6447c_0
+ ca-certificates pkgs/main/linux-64::ca-certificates-2023.08.22-h06a4308_0
+ expat pkgs/main/linux-64::expat-2.5.0-h6a678d5_0
+ ld_impl_linux-64 pkgs/main/linux-64::ld_impl_linux-64-2.38-h1181459_1
+ libffi pkgs/main/linux-64::libffi-3.4.4-h6a678d5_0
+ libgcc-ng pkgs/main/linux-64::libgcc-ng-11.2.0-h1234567_1
+ libgomp pkgs/main/linux-64::libgomp-11.2.0-h1234567_1
+ libstdcxx-ng pkgs/main/linux-64::libstdcxx-ng-11.2.0-h1234567_1
+ libuuid pkgs/main/linux-64::libuuid-1.41.5-h5eee18b_0
+ ncurses pkgs/main/linux-64::ncurses-6.4-h6a678d5_0
+ openssl pkgs/main/linux-64::openssl-3.0.11-h7f8727e_2
+ pip pkgs/main/linux-64::pip-23.3-py312h06a4308_0
+ python pkgs/main/linux-64::python-3.12.0-h996f2a0_0
+ readline pkgs/main/linux-64::readline-8.2-h5eee18b_0
+ setuptools pkgs/main/linux-64::setuptools-68.0.0-py312h06a4308_0
+ sqlite pkgs/main/linux-64::sqlite-3.41.2-h5eee18b_0
+ tk pkgs/main/linux-64::tk-8.6.12-h1ccaba5_0
+ tzdata pkgs/main/noarch::tzdata-2023c-h04d1e81_0
+ wheel pkgs/main/noarch::wheel-0.37.1-pyhd3eb1b0_0
+ xz pkgs/main/linux-64::xz-5.4.2-h5eee18b_0
+ zlib pkgs/main/linux-64::zlib-1.2.13-h5eee18b_0
+
+
+Proceed ([y]/n)?
+
+Downloading and Extracting Packages: ...working... done
+Preparing transaction: ...working... done
+Verifying transaction: ...working... done
+Executing transaction: ...working... done
diff --git a/modeling/vision/encoder/ops/functions/__init__.py b/modeling/vision/encoder/ops/functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b06b5ac538b63bdb9a6c82e4635b95bb5491d5b
--- /dev/null
+++ b/modeling/vision/encoder/ops/functions/__init__.py
@@ -0,0 +1,13 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from .ms_deform_attn_func import MSDeformAttnFunction
+
diff --git a/modeling/vision/encoder/ops/functions/ms_deform_attn_func.py b/modeling/vision/encoder/ops/functions/ms_deform_attn_func.py
new file mode 100644
index 0000000000000000000000000000000000000000..94a36ab85b7c5f9ecee342db91a5d5731740740f
--- /dev/null
+++ b/modeling/vision/encoder/ops/functions/ms_deform_attn_func.py
@@ -0,0 +1,72 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+try:
+ import MultiScaleDeformableAttention as MSDA
+except ModuleNotFoundError as e:
+ info_string = (
+ "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
+ "\t`cd mask2former/modeling/pixel_decoder/ops`\n"
+ "\t`sh make.sh`\n"
+ )
+ raise ModuleNotFoundError(info_string)
+
+
+class MSDeformAttnFunction(Function):
+ @staticmethod
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
+ ctx.im2col_step = im2col_step
+ output = MSDA.ms_deform_attn_forward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+ grad_value, grad_sampling_loc, grad_attn_weight = \
+ MSDA.ms_deform_attn_backward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
+
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
+ # for debug and test only,
+ # need to use cuda version instead
+ N_, S_, M_, D_ = value.shape
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+ # N_*M_, D_, Lq_, P_
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
+ mode='bilinear', padding_mode='zeros', align_corners=False)
+ sampling_value_list.append(sampling_value_l_)
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
+ return output.transpose(1, 2).contiguous()
diff --git a/modeling/vision/encoder/ops/make.sh b/modeling/vision/encoder/ops/make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c4ce889dba583d6e8c391d6a03a5338b6b8ac5e5
--- /dev/null
+++ b/modeling/vision/encoder/ops/make.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+python setup.py build install --user
diff --git a/modeling/vision/encoder/ops/modules/__init__.py b/modeling/vision/encoder/ops/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fdbf03359958f3d67ab00f879bf6b61a6c8f06a
--- /dev/null
+++ b/modeling/vision/encoder/ops/modules/__init__.py
@@ -0,0 +1,12 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from .ms_deform_attn import MSDeformAttn
diff --git a/modeling/vision/encoder/ops/modules/ms_deform_attn.py b/modeling/vision/encoder/ops/modules/ms_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b4c42ea504a0859ccadd72646919c941e72f73
--- /dev/null
+++ b/modeling/vision/encoder/ops/modules/ms_deform_attn.py
@@ -0,0 +1,125 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import warnings
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_, constant_
+
+from ..functions import MSDeformAttnFunction
+from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch
+
+
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+ return (n & (n-1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
+ """
+ Multi-Scale Deformable Attention Module
+ :param d_model hidden dimension
+ :param n_levels number of feature levels
+ :param n_heads number of attention heads
+ :param n_points number of sampling points per attention head per feature level
+ """
+ super().__init__()
+ if d_model % n_heads != 0:
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+ _d_per_head = d_model // n_heads
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+ if not _is_power_of_2(_d_per_head):
+ warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+ "which is more efficient in our CUDA implementation.")
+
+ self.im2col_step = 128
+
+ self.d_model = d_model
+ self.n_levels = n_levels
+ self.n_heads = n_heads
+ self.n_points = n_points
+
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.output_proj = nn.Linear(d_model, d_model)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ constant_(self.sampling_offsets.weight.data, 0.)
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
+ for i in range(self.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.)
+ constant_(self.attention_weights.bias.data, 0.)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.)
+
+ def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
+ """
+ :param query (N, Length_{query}, C)
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
+
+ :return output (N, Length_{query}, C)
+ """
+ N, Len_q, _ = query.shape
+ N, Len_in, _ = input_flatten.shape
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
+ # N, Len_q, n_heads, n_levels, n_points, 2
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
+ sampling_locations = reference_points[:, :, None, :, None, :] \
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ else:
+ raise ValueError(
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
+ try:
+ output = MSDeformAttnFunction.apply(
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
+ except:
+ # CPU
+ output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
+ # # For FLOPs calculation only
+ # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
+ output = self.output_proj(output)
+ return output
diff --git a/modeling/vision/encoder/ops/setup.py b/modeling/vision/encoder/ops/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b57ad313ac8f9b6586892142da8ba943e516cec
--- /dev/null
+++ b/modeling/vision/encoder/ops/setup.py
@@ -0,0 +1,78 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+def get_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ extensions_dir = os.path.join(this_dir, "src")
+
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+ sources = main_file + source_cpu
+ extension = CppExtension
+ extra_compile_args = {"cxx": []}
+ define_macros = []
+
+ # Force cuda since torch ask for a device, not if cuda is in fact available.
+ if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
+ extension = CUDAExtension
+ sources += source_cuda
+ define_macros += [("WITH_CUDA", None)]
+ extra_compile_args["nvcc"] = [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ else:
+ if CUDA_HOME is None:
+ raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
+ else:
+ raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')
+
+ sources = [os.path.join(extensions_dir, s) for s in sources]
+ include_dirs = [extensions_dir]
+ ext_modules = [
+ extension(
+ "MultiScaleDeformableAttention",
+ sources,
+ include_dirs=include_dirs,
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+ return ext_modules
+
+setup(
+ name="MultiScaleDeformableAttention",
+ version="1.0",
+ author="Weijie Su",
+ url="https://github.com/fundamentalvision/Deformable-DETR",
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+ packages=find_packages(exclude=("configs", "tests",)),
+ ext_modules=get_extensions(),
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp b/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..48757e2b0156b2c1513b615d2a17e5aee5172ae7
--- /dev/null
+++ b/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,46 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#include
+
+#include
+#include
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
diff --git a/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.h b/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..51bb27e9ee828f967e8aa854c2d55574040c6d7e
--- /dev/null
+++ b/modeling/vision/encoder/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,38 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#pragma once
+#include
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
+
diff --git a/modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.cu b/modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0c465dab3d636dfd6a44523c63f148b6e15084d9
--- /dev/null
+++ b/modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,158 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#include
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include
+#include
+#include
+#include
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
\ No newline at end of file
diff --git a/modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.h b/modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..4f0658e8668a11f0e7d71deff9adac71884f2e87
--- /dev/null
+++ b/modeling/vision/encoder/ops/src/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,35 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#pragma once
+#include
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
diff --git a/modeling/vision/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh b/modeling/vision/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..c04e0d4ab97d25c1756fcd8d08dd1e5a6d280b7c
--- /dev/null
+++ b/modeling/vision/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1332 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
+template
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel
+ <<>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+template
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
\ No newline at end of file
diff --git a/modeling/vision/encoder/ops/src/ms_deform_attn.h b/modeling/vision/encoder/ops/src/ms_deform_attn.h
new file mode 100644
index 0000000000000000000000000000000000000000..2f80a1b294c55b37d13bb3558ff7aeadba3b37de
--- /dev/null
+++ b/modeling/vision/encoder/ops/src/ms_deform_attn.h
@@ -0,0 +1,67 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/modeling/vision/encoder/ops/src/vision.cpp b/modeling/vision/encoder/ops/src/vision.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4a08821e0121a77556aa7a263ec8ebfa928b13b6
--- /dev/null
+++ b/modeling/vision/encoder/ops/src/vision.cpp
@@ -0,0 +1,21 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
diff --git a/modeling/vision/encoder/ops/test.py b/modeling/vision/encoder/ops/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..be764f06db923662af64e8fdf813f416d9c0e09c
--- /dev/null
+++ b/modeling/vision/encoder/ops/test.py
@@ -0,0 +1,92 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import gradcheck
+
+from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
+N, M, D = 1, 2, 2
+Lq, L, P = 2, 2, 2
+shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
+level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
+S = sum([(H*W).item() for H, W in shapes])
+
+
+torch.manual_seed(3)
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_double():
+ value = torch.rand(N, S, M, D) * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2)
+ attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_float():
+ value = torch.rand(N, S, M, D) * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2)
+ attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+
+ value = torch.rand(N, S, M, channels) * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2)
+ attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ func = MSDeformAttnFunction.apply
+
+ value.requires_grad = grad_value
+ sampling_locations.requires_grad = grad_sampling_loc
+ attention_weights.requires_grad = grad_attn_weight
+
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
+
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
+
+
+if __name__ == '__main__':
+ check_forward_equal_with_pytorch_double()
+ check_forward_equal_with_pytorch_float()
+
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
+ check_gradient_numerical(channels, True, True, True)
+
+
+
diff --git a/modeling/vision/encoder/transformer_blocks.py b/modeling/vision/encoder/transformer_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..54134f34556b32c98401be2eb862e539ccb812d4
--- /dev/null
+++ b/modeling/vision/encoder/transformer_blocks.py
@@ -0,0 +1,370 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
+"""
+Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+ * positional encodings are passed in MHattention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=False,
+ ):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(
+ decoder_layer,
+ num_decoder_layers,
+ decoder_norm,
+ return_intermediate=return_intermediate_dec,
+ )
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, mask, query_embed, pos_embed):
+ # flatten NxCxHxW to HWxNxC
+ bs, c, h, w = src.shape
+ src = src.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+ if mask is not None:
+ mask = mask.flatten(1)
+
+ tgt = torch.zeros_like(query_embed)
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+ hs = self.decoder(
+ tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
+ )
+ return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ output = src
+
+ for layer in self.layers:
+ output = layer(
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
+ )
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ output = tgt
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ query_pos=query_pos,
+ )
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(src, pos)
+
+ src2 = self.self_attn(
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+ )[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_pre(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+ )[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )
+ return self.forward_post(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
diff --git a/modeling/vision/encoder/transformer_encoder_deform.py b/modeling/vision/encoder/transformer_encoder_deform.py
new file mode 100644
index 0000000000000000000000000000000000000000..40e4c360f0f397060fa833cf4a244284c4300d4c
--- /dev/null
+++ b/modeling/vision/encoder/transformer_encoder_deform.py
@@ -0,0 +1,378 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import numpy as np
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+from torch.cuda.amp import autocast
+
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from .ops.modules import MSDeformAttn
+from .build import register_encoder
+from .transformer_blocks import _get_clones, _get_activation_fn
+from ...utils import configurable
+from ...modules import PositionEmbeddingSine
+
+
+# MSDeformAttn Transformer encoder in deformable detr
+class MSDeformAttnTransformerEncoderOnly(nn.Module):
+ def __init__(self, d_model=256, nhead=8,
+ num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
+ activation="relu",
+ num_feature_levels=4, enc_n_points=4,
+ ):
+ super().__init__()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, enc_n_points)
+ self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers)
+
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+ normal_(self.level_embed)
+
+ def get_valid_ratio(self, mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def forward(self, srcs, pos_embeds):
+ masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]
+ # prepare input for encoder
+ src_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ src = src.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ src_flatten.append(src)
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+ # encoder
+ memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
+ return memory, spatial_shapes, level_start_index
+
+
+class MSDeformAttnTransformerEncoderLayer(nn.Module):
+ def __init__(self,
+ d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4):
+ super().__init__()
+
+ # self attention
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout2 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout3 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, src):
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+ src = src + self.dropout3(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
+ # self attention
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ # ffn
+ src = self.forward_ffn(src)
+
+ return src
+
+
+class MSDeformAttnTransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
+ output = src
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
+ for _, layer in enumerate(self.layers):
+ output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
+
+ return output
+
+
+# @SEM_SEG_HEADS_REGISTRY.register()
+class MSDeformAttnPixelDecoder(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ input_shape: Dict[str, ShapeSpec],
+ *,
+ transformer_dropout: float,
+ transformer_nheads: int,
+ transformer_dim_feedforward: int,
+ transformer_enc_layers: int,
+ conv_dim: int,
+ mask_dim: int,
+ norm: Optional[Union[str, Callable]] = None,
+ # deformable transformer encoder args
+ transformer_in_features: List[str],
+ common_stride: int,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ input_shape: shapes (channels and stride) of the input features
+ transformer_dropout: dropout probability in transformer
+ transformer_nheads: number of heads in transformer
+ transformer_dim_feedforward: dimension of feedforward network
+ transformer_enc_layers: number of transformer encoder layers
+ conv_dims: number of output channels for the intermediate conv layers.
+ mask_dim: number of output channels for the final conv layer.
+ norm (str or callable): normalization for all conv layers
+ """
+ super().__init__()
+ transformer_input_shape = {
+ k: v for k, v in input_shape.items() if k in transformer_in_features
+ }
+
+ # this is the input shape of pixel decoder
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
+ self.feature_strides = [v.stride for k, v in input_shape]
+ self.feature_channels = [v.channels for k, v in input_shape]
+
+ # this is the input shape of transformer encoder (could use less features than pixel decoder
+ transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)
+ self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5"
+ transformer_in_channels = [v.channels for k, v in transformer_input_shape]
+ self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers
+
+ self.transformer_num_feature_levels = len(self.transformer_in_features)
+ if self.transformer_num_feature_levels > 1:
+ input_proj_list = []
+ # from low resolution to high resolution (res5 -> res2)
+ for in_channels in transformer_in_channels[::-1]:
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, conv_dim, kernel_size=1),
+ nn.GroupNorm(32, conv_dim),
+ ))
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+ self.input_proj = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),
+ nn.GroupNorm(32, conv_dim),
+ )])
+
+ for proj in self.input_proj:
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
+ nn.init.constant_(proj[0].bias, 0)
+
+ self.transformer = MSDeformAttnTransformerEncoderOnly(
+ d_model=conv_dim,
+ dropout=transformer_dropout,
+ nhead=transformer_nheads,
+ dim_feedforward=transformer_dim_feedforward,
+ num_encoder_layers=transformer_enc_layers,
+ num_feature_levels=self.transformer_num_feature_levels,
+ )
+ N_steps = conv_dim // 2
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+ self.mask_dim = mask_dim
+ # use 1x1 conv instead
+ self.mask_features = Conv2d(
+ conv_dim,
+ mask_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ weight_init.c2_xavier_fill(self.mask_features)
+
+ self.maskformer_num_feature_levels = 3 # always use 3 scales
+ self.common_stride = common_stride
+
+ # extra fpn levels
+ stride = min(self.transformer_feature_strides)
+ self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))
+
+ lateral_convs = []
+ output_convs = []
+
+ use_bias = norm == ""
+ for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
+ lateral_norm = get_norm(norm, conv_dim)
+ output_norm = get_norm(norm, conv_dim)
+
+ lateral_conv = Conv2d(
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
+ )
+ output_conv = Conv2d(
+ conv_dim,
+ conv_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias,
+ norm=output_norm,
+ activation=F.relu,
+ )
+ weight_init.c2_xavier_fill(lateral_conv)
+ weight_init.c2_xavier_fill(output_conv)
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
+ self.add_module("layer_{}".format(idx + 1), output_conv)
+
+ lateral_convs.append(lateral_conv)
+ output_convs.append(output_conv)
+ # Place convs into top-down order (from low to high resolution)
+ # to make the top-down computation in forward clearer.
+ self.lateral_convs = lateral_convs[::-1]
+ self.output_convs = output_convs[::-1]
+
+ @classmethod
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+ ret = {}
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ ret["input_shape"] = {
+ k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
+ }
+ ret["conv_dim"] = enc_cfg['CONVS_DIM']
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
+ ret["norm"] = enc_cfg['NORM']
+ ret["transformer_dropout"] = dec_cfg['DROPOUT']
+ ret["transformer_nheads"] = dec_cfg['NHEADS']
+ # ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+ ret["transformer_dim_feedforward"] = 1024 # use 1024 for deformable transformer encoder
+ ret[
+ "transformer_enc_layers"
+ ] = enc_cfg['TRANSFORMER_ENC_LAYERS'] # a separate config
+ ret["transformer_in_features"] = enc_cfg['DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES']
+ ret["common_stride"] = enc_cfg['COMMON_STRIDE']
+ return ret
+
+ @autocast(enabled=False)
+ def forward_features(self, features):
+ srcs = []
+ pos = []
+ # Reverse feature maps into top-down order (from low to high resolution)
+ for idx, f in enumerate(self.transformer_in_features[::-1]):
+ x = features[f].float() # deformable detr does not support half precision
+ srcs.append(self.input_proj[idx](x))
+ pos.append(self.pe_layer(x))
+
+
+ y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
+ bs = y.shape[0]
+
+ split_size_or_sections = [None] * self.transformer_num_feature_levels
+ for i in range(self.transformer_num_feature_levels):
+ if i < self.transformer_num_feature_levels - 1:
+ split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
+ else:
+ split_size_or_sections[i] = y.shape[1] - level_start_index[i]
+ y = torch.split(y, split_size_or_sections, dim=1)
+
+ out = []
+ multi_scale_features = []
+ num_cur_levels = 0
+ for i, z in enumerate(y):
+ out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
+
+ # append `out` with extra FPN levels
+ # Reverse feature maps into top-down order (from low to high resolution)
+ for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
+ x = features[f].float()
+ lateral_conv = self.lateral_convs[idx]
+ output_conv = self.output_convs[idx]
+ cur_fpn = lateral_conv(x)
+ # Following FPN implementation, we use nearest upsampling here
+ y = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
+ y = output_conv(y)
+ out.append(y)
+
+ for o in out:
+ if num_cur_levels < self.maskformer_num_feature_levels:
+ multi_scale_features.append(o)
+ num_cur_levels += 1
+
+ return self.mask_features(out[-1]), out[0], multi_scale_features
+
+
+
+@register_encoder
+def get_transformer_encoder_deform(cfg, input_shape):
+ """
+ Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`.
+ """
+ model = MSDeformAttnPixelDecoder(cfg, input_shape)
+ forward_features = getattr(model, "forward_features", None)
+ if not callable(forward_features):
+ raise ValueError(
+ "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
+ f"Please implement forward_features for {name} to only return mask features."
+ )
+ return model
\ No newline at end of file
diff --git a/modeling/vision/encoder/transformer_encoder_fpn.py b/modeling/vision/encoder/transformer_encoder_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac0e02b033d741a22489c5c296c2f5ae1350ce47
--- /dev/null
+++ b/modeling/vision/encoder/transformer_encoder_fpn.py
@@ -0,0 +1,323 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import numpy as np
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+from torch.cuda.amp import autocast
+
+import fvcore.nn.weight_init as weight_init
+from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm
+
+from .build import register_encoder
+from .transformer_blocks import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn
+from ...modules import PositionEmbeddingSine
+from ...utils import configurable
+
+
+# This is a modified FPN decoder.
+class BasePixelDecoder(nn.Module):
+ def __init__(
+ self,
+ input_shape: Dict[str, ShapeSpec],
+ *,
+ conv_dim: int,
+ mask_dim: int,
+ mask_on: bool,
+ norm: Optional[Union[str, Callable]] = None,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ input_shape: shapes (channels and stride) of the input features
+ conv_dims: number of output channels for the intermediate conv layers.
+ mask_dim: number of output channels for the final conv layer.
+ norm (str or callable): normalization for all conv layers
+ """
+ super().__init__()
+
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
+ feature_channels = [v.channels for k, v in input_shape]
+
+ lateral_convs = []
+ output_convs = []
+
+ use_bias = norm == ""
+ for idx, in_channels in enumerate(feature_channels):
+ if idx == len(self.in_features) - 1:
+ output_norm = get_norm(norm, conv_dim)
+ output_conv = Conv2d(
+ in_channels,
+ conv_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias,
+ norm=output_norm,
+ activation=F.relu,
+ )
+ weight_init.c2_xavier_fill(output_conv)
+ self.add_module("layer_{}".format(idx + 1), output_conv)
+
+ lateral_convs.append(None)
+ output_convs.append(output_conv)
+ else:
+ lateral_norm = get_norm(norm, conv_dim)
+ output_norm = get_norm(norm, conv_dim)
+
+ lateral_conv = Conv2d(
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
+ )
+ output_conv = Conv2d(
+ conv_dim,
+ conv_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias,
+ norm=output_norm,
+ activation=F.relu,
+ )
+ weight_init.c2_xavier_fill(lateral_conv)
+ weight_init.c2_xavier_fill(output_conv)
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
+ self.add_module("layer_{}".format(idx + 1), output_conv)
+
+ lateral_convs.append(lateral_conv)
+ output_convs.append(output_conv)
+ # Place convs into top-down order (from low to high resolution)
+ # to make the top-down computation in forward clearer.
+ self.lateral_convs = lateral_convs[::-1]
+ self.output_convs = output_convs[::-1]
+
+ self.mask_on = mask_on
+ if self.mask_on:
+ self.mask_dim = mask_dim
+ self.mask_features = Conv2d(
+ conv_dim,
+ mask_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ weight_init.c2_xavier_fill(self.mask_features)
+
+ self.maskformer_num_feature_levels = 3 # always use 3 scales
+
+ @classmethod
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+ enc_cfg = cfg['MODEL']['ENCODER']
+ ret = {}
+ ret["input_shape"] = {
+ k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
+ }
+ ret["conv_dim"] = enc_cfg['CONVS_DIM']
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
+ ret["norm"] = enc_cfg['NORM']
+ return ret
+
+ def forward_features(self, features):
+ multi_scale_features = []
+ num_cur_levels = 0
+ # Reverse feature maps into top-down order (from low to high resolution)
+ for idx, f in enumerate(self.in_features[::-1]):
+ x = features[f]
+ lateral_conv = self.lateral_convs[idx]
+ output_conv = self.output_convs[idx]
+ if lateral_conv is None:
+ y = output_conv(x)
+ else:
+ cur_fpn = lateral_conv(x)
+ # Following FPN implementation, we use nearest upsampling here
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+ y = output_conv(y)
+ if num_cur_levels < self.maskformer_num_feature_levels:
+ multi_scale_features.append(y)
+ num_cur_levels += 1
+
+ mask_features = self.mask_features(y) if self.mask_on else None
+ return mask_features, None, multi_scale_features
+
+ def forward(self, features, targets=None):
+ logger = logging.getLogger(__name__)
+ logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
+ return self.forward_features(features)
+
+
+class TransformerEncoderOnly(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, mask, pos_embed):
+ # flatten NxCxHxW to HWxNxC
+ bs, c, h, w = src.shape
+ src = src.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ if mask is not None:
+ mask = mask.flatten(1)
+
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+ return memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+# This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map.
+class TransformerEncoderPixelDecoder(BasePixelDecoder):
+ @configurable
+ def __init__(
+ self,
+ input_shape: Dict[str, ShapeSpec],
+ *,
+ transformer_dropout: float,
+ transformer_nheads: int,
+ transformer_dim_feedforward: int,
+ transformer_enc_layers: int,
+ transformer_pre_norm: bool,
+ conv_dim: int,
+ mask_dim: int,
+ mask_on: int,
+ norm: Optional[Union[str, Callable]] = None,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ input_shape: shapes (channels and stride) of the input features
+ transformer_dropout: dropout probability in transformer
+ transformer_nheads: number of heads in transformer
+ transformer_dim_feedforward: dimension of feedforward network
+ transformer_enc_layers: number of transformer encoder layers
+ transformer_pre_norm: whether to use pre-layernorm or not
+ conv_dims: number of output channels for the intermediate conv layers.
+ mask_dim: number of output channels for the final conv layer.
+ norm (str or callable): normalization for all conv layers
+ """
+ super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on)
+
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
+ feature_strides = [v.stride for k, v in input_shape]
+ feature_channels = [v.channels for k, v in input_shape]
+
+ in_channels = feature_channels[len(self.in_features) - 1]
+ self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
+ weight_init.c2_xavier_fill(self.input_proj)
+ self.transformer = TransformerEncoderOnly(
+ d_model=conv_dim,
+ dropout=transformer_dropout,
+ nhead=transformer_nheads,
+ dim_feedforward=transformer_dim_feedforward,
+ num_encoder_layers=transformer_enc_layers,
+ normalize_before=transformer_pre_norm,
+ )
+ N_steps = conv_dim // 2
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+ # update layer
+ use_bias = norm == ""
+ output_norm = get_norm(norm, conv_dim)
+ output_conv = Conv2d(
+ conv_dim,
+ conv_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=use_bias,
+ norm=output_norm,
+ activation=F.relu,
+ )
+ weight_init.c2_xavier_fill(output_conv)
+ delattr(self, "layer_{}".format(len(self.in_features)))
+ self.add_module("layer_{}".format(len(self.in_features)), output_conv)
+ self.output_convs[0] = output_conv
+
+ @classmethod
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+ enc_cfg = cfg['MODEL']['ENCODER']
+ dec_cfg = cfg['MODEL']['DECODER']
+
+ ret = super().from_config(cfg, input_shape)
+ ret["transformer_dropout"] = dec_cfg['DROPOUT']
+ ret["transformer_nheads"] = dec_cfg['NHEADS']
+ ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
+ ret["transformer_enc_layers"] = enc_cfg['TRANSFORMER_ENC_LAYERS'] # a separate config
+ ret["transformer_pre_norm"] = dec_cfg['PRE_NORM']
+
+ ret['mask_on'] = cfg['MODEL']['DECODER']['MASK']
+ return ret
+
+ def forward_features(self, features):
+ multi_scale_features = []
+ num_cur_levels = 0
+
+ # Reverse feature maps into top-down order (from low to high resolution)
+ for idx, f in enumerate(self.in_features[::-1]):
+ x = features[f]
+ lateral_conv = self.lateral_convs[idx]
+ output_conv = self.output_convs[idx]
+ if lateral_conv is None:
+ transformer = self.input_proj(x)
+ pos = self.pe_layer(x)
+ transformer = self.transformer(transformer, None, pos)
+ y = output_conv(transformer)
+ # save intermediate feature as input to Transformer decoder
+ transformer_encoder_features = transformer
+ else:
+ cur_fpn = lateral_conv(x)
+ # Following FPN implementation, we use nearest upsampling here
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+ y = output_conv(y)
+ if num_cur_levels < self.maskformer_num_feature_levels:
+ multi_scale_features.append(y)
+ num_cur_levels += 1
+
+ mask_features = self.mask_features(y) if self.mask_on else None
+ return mask_features, transformer_encoder_features, multi_scale_features
+
+ def forward(self, features, targets=None):
+ logger = logging.getLogger(__name__)
+ logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
+ return self.forward_features(features)
+
+
+
+@register_encoder
+def get_transformer_encoder_fpn(cfg, input_shape):
+ """
+ Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`.
+ """
+ model = TransformerEncoderPixelDecoder(cfg, input_shape)
+ forward_features = getattr(model, "forward_features", None)
+ if not callable(forward_features):
+ raise ValueError(
+ "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
+ f"Please implement forward_features for {name} to only return mask features."
+ )
+ return model
\ No newline at end of file
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7faeae3edeb61a5d0ea0c010115f1ccaa989ca2a
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,2 @@
+mpich
+libmpich-dev
\ No newline at end of file
diff --git a/pipeline/XDecoderPipeline.py b/pipeline/XDecoderPipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ea586db8e7eb24e8f55cfa9df3b1904a1717f4e
--- /dev/null
+++ b/pipeline/XDecoderPipeline.py
@@ -0,0 +1,213 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import logging
+import time
+import datetime
+import json
+import os
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from typing import Tuple, Dict, List, Union
+from infinibatch import iterators
+
+from trainer.default_trainer import DefaultTrainer
+
+from detectron2.evaluation import inference_on_dataset
+from detectron2.utils.logger import log_every_n_seconds
+from detectron2.data import MetadataCatalog
+
+from modeling import build_model
+from modeling.utils import get_class_names
+from modeling.BaseModel import BaseModel
+from datasets import build_evaluator, build_eval_dataloader, build_train_dataloader
+from utilities.distributed import is_main_process
+from utilities.constants import COCO_PANOPTIC_CLASSES
+from trainer.utils.misc import move_batch_to_device, cast_batch_to_half
+
+from .utils.misc import hook_metadata, hook_switcher, hook_opt
+
+logger = logging.getLogger(__name__)
+
+
+class XDecoderPipeline:
+ def __init__(self, opt):
+ self._opt = opt
+ print(self._opt['RESUME_FROM'])
+
+ def initialize_model(self):
+ model_name = "default"
+ model = build_model(self._opt)
+ model.train()
+
+ if is_main_process():
+ logger.info(model)
+
+ raw_models = {model_name: BaseModel(self._opt, model)}
+ return raw_models
+
+ def get_dataloaders(
+ self, trainer: DefaultTrainer,
+ dataset_label: str,
+ is_evaluation: bool
+ ) -> Union[DataLoader, iterators.CheckpointableIterator]:
+ distributed = self._opt['world_size'] > 1
+ if is_evaluation:
+ if not hasattr(self, 'valid_loader'):
+ dataloaders = build_eval_dataloader(self._opt)
+ self.valid_loader = dataloaders
+ else:
+ dataloaders = self.valid_loader
+ idx = 0 if dataset_label=='dev' else self._opt['DATASETS']['TEST'].index(dataset_label)
+ dataloader = dataloaders[idx]
+ self.evaluator = build_evaluator(self._opt, self._opt['DATASETS']['TEST'][idx], self._opt['SAVE_DIR'])
+ else:
+ if not hasattr(self, 'train_loader'):
+ dataloader = build_train_dataloader(self._opt)
+ self.train_loader = dataloader
+ logger.info(f'num of train samples: {len(dataloader)}')
+ else:
+ dataloader = self.train_loader
+
+ # temp solution for lr scheduler
+ steps_total = len(self.train_loader)
+ steps_acc = self._opt['GRADIENT_ACCUMULATE_STEP']
+ steps_update = steps_total // steps_acc
+ self._opt["LR_SCHEDULER_PARAMS"]["steps_update_per_epoch"] = steps_update
+ return dataloader
+
+ @staticmethod
+ def forward_func(trainer, batch):
+ loss = trainer.models['default'](batch)
+ return loss
+
+ def forward_step(
+ self,
+ trainer: DefaultTrainer,
+ batch,
+ grad_acc_batches: List,
+ grad_acc_index: int,
+ is_distributed: bool,
+ ) -> Tuple[Dict[str, float], Dict[str, int], Dict]:
+ loss_info, sample_size_info, extra_info = {}, {}, {}
+ batch = move_batch_to_device(batch, self._opt['device'])
+ if self._opt['FP16']:
+ # in FP16 mode, DeepSpeed casts the model to FP16, so the input needs to be manually casted to FP16
+ batch = cast_batch_to_half(batch)
+ loss = trainer.compute_loss(self.forward_func, batch)
+ loss_info = {k: v.detach().item() for k,v in loss.items()}
+ sample_size_info = {'num_samples': len(batch)}
+ loss = sum(loss for loss in loss.values())
+ trainer.backward_loss(loss, model_names=['default'])
+ trainer.update_model(model_name='default')
+ return loss_info, sample_size_info, extra_info
+
+ def evaluate_model(
+ self,
+ trainer: DefaultTrainer,
+ save_folder,
+ ) -> Tuple[Dict, Dict[str, float], bool]:
+
+ model = trainer.raw_models['default'].eval()
+ self._opt = hook_opt(self._opt)
+ dataset_names = self._opt['DATASETS']['TEST']
+ scores = {}
+ summary = {}
+
+ for dataset_label in dataset_names:
+ torch.cuda.empty_cache()
+ eval_batch_gen = self.get_dataloaders(trainer, dataset_label, is_evaluation=True)
+ self.evaluator.reset()
+ with torch.no_grad():
+ names = get_class_names(dataset_label)
+ if self._opt['MODEL']['ENCODER']['BINARY_CLASSES']:
+ names = ['target', 'background']
+ model.model.metadata = MetadataCatalog.get(dataset_label)
+ model.model.metadata = hook_metadata(model.model.metadata, dataset_label)
+ eval_type = model.model.metadata.evaluator_type
+ if 'background' in names:
+ model.model.sem_seg_head.num_classes = len(names) - 1
+ model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(names, is_eval=True)
+ hook_switcher(model, dataset_label)
+ total = len(eval_batch_gen)
+ num_warmup = min(5, total - 1)
+ start_time = time.perf_counter()
+ total_data_time = 0
+ total_compute_time = 0
+ total_eval_time = 0
+ start_data_time = time.perf_counter()
+
+ for idx, batch in enumerate(eval_batch_gen):
+ total_data_time += time.perf_counter() - start_data_time
+ if idx == num_warmup:
+ start_time = time.perf_counter()
+ total_data_time = 0
+ total_compute_time = 0
+ total_eval_time = 0
+
+ start_compute_time = time.perf_counter()
+ batch = move_batch_to_device(batch, self._opt['device'])
+ if self._opt['FP16']:
+ # in FP16 mode, DeepSpeed casts the model to FP16, so the input needs to be manually casted to FP16
+ batch = cast_batch_to_half(batch)
+
+ outputs = model(batch, mode=eval_type)
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ total_compute_time += time.perf_counter() - start_compute_time
+ start_eval_time = time.perf_counter()
+
+ self.evaluator.process(batch, outputs)
+ total_eval_time += time.perf_counter() - start_eval_time
+
+ iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
+ data_seconds_per_iter = total_data_time / iters_after_start
+ compute_seconds_per_iter = total_compute_time / iters_after_start
+ eval_seconds_per_iter = total_eval_time / iters_after_start
+ total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
+
+ if is_main_process() and (idx >= num_warmup * 2 or compute_seconds_per_iter > 5):
+ eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
+ log_every_n_seconds(
+ logging.INFO,
+ (
+ f"Task {dataset_label}. "
+ f"Inference done {idx + 1}/{total}. "
+ f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
+ f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
+ f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
+ f"Total: {total_seconds_per_iter:.4f} s/iter. "
+ f"ETA={eta}"
+ ),
+ n=5,
+ )
+ start_data_time = time.perf_counter()
+
+ results = self.evaluator.evaluate()
+ model.model.sem_seg_head.predictor.lang_encoder.reset_text_embeddings()
+
+ if is_main_process():
+ scores["{}/{}".format(dataset_label, eval_type)] = results
+
+ # set back to training stat.
+ model.model.sem_seg_head.num_classes = self._opt['MODEL']['ENCODER']['NUM_CLASSES']
+ model.model.metadata = MetadataCatalog.get(self._opt['DATASETS']['TRAIN'][0])
+ # save scores
+ if is_main_process():
+ model_name = self._opt['RESUME_FROM'].split('/')[-1].split('.')[0]
+ with open(os.path.join(save_folder,f'{model_name}_eval_results.json'), 'w') as f:
+ json.dump(scores, f, indent=4)
+ # todo
+ # hack to return only results/scores
+ for datatype in scores:
+ for evaltype in scores[datatype]:
+ if 'instance_results' in scores[datatype][evaltype]:
+ scores[datatype][evaltype]= scores[datatype][evaltype]['scores']
+ return scores
\ No newline at end of file
diff --git a/pipeline/__init__.py b/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/pipeline/utils/misc.py b/pipeline/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2999e17b2f862918d1be4e3e87c6d3e16f31d1a6
--- /dev/null
+++ b/pipeline/utils/misc.py
@@ -0,0 +1,51 @@
+import logging
+import torch
+
+logger = logging.getLogger(__name__)
+
+def hook_opt(opt):
+
+ try:
+ grounding_flag = opt['REF']['INPUT']['SPATIAL']
+ except:
+ grounding_flag = False
+
+ if grounding_flag:
+ opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['grounding'] = ['queries_grounding', 'tokens_grounding', 'tokens_spatial']
+
+ try:
+ spatial_flag = opt['STROKE_SAMPLER']['EVAL']['GROUNDING']
+ except:
+ spatial_flag = False
+
+ if spatial_flag:
+ opt['ATTENTION_ARCH']['SELF_ATTENTION']['queries']['spatial'] = ['queries_spatial', 'tokens_spatial', 'memories_spatial', 'tokens_grounding']
+
+ return opt
+
+# HACK for evalution
+def hook_metadata(metadata, name):
+ return metadata
+
+# HACK for evalution
+def hook_switcher(model, name):
+ mappings = {}
+ if name in ['cityscapes_fine_sem_seg_val', 'scannet_21_val_seg', 'scannet_38_val_seg', 'scannet_41_val_seg', 'sunrgbd_37_val_seg', 'context_59_val_seg', 'context_459_val_seg', 'voc_2012_val_seg', 'bdd10k_val_sem_seg', 'ade20k_full_sem_seg_val']:
+ mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': False}
+ elif name in ['cityscapes_fine_instance_seg_val'] or 'seginw' in name:
+ mappings = {'SEMANTIC_ON': False, 'INSTANCE_ON': True, 'PANOPTIC_ON': False}
+ elif name in ['cityscapes_fine_panoptic_val', 'scannet_21_panoptic_val', 'bdd10k_40_panoptic_val']:
+ mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': False, 'PANOPTIC_ON': True}
+ elif name in ['coco_2017_val_panoptic_with_sem_seg', 'ade20k_panoptic_val', 'coco_2017_test-dev']:
+ mappings = {'SEMANTIC_ON': True, 'INSTANCE_ON': True, 'PANOPTIC_ON': True}
+ else:
+ if 'biomed' not in name and name not in ["med_sam_train", "med_sam_test", "vlp_val", "vlp_captioning_val", "vlp_val2017", "vlp_captioning_val2017", "imagenet_val", "refcocog_val_google", "phrasecut_val", "phrasecut_test", "refcocop_val_unc", "refcoco_val_unc", "refcocog_val_umd", "pascalvoc_val_Point", "grounding_coco_entity_val", "vlp_coco_entity_val"]:
+ assert False, "dataset switcher is not defined"
+
+ for key, value in mappings.items():
+ if key == 'SEMANTIC_ON':
+ model.model.semantic_on = value
+ if key == 'INSTANCE_ON':
+ model.model.instance_on = value
+ if key == 'PANOPTIC_ON':
+ model.model.panoptic_on = value
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3e8582951d19f641482c86c66b7c7dc76c0730e8
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,54 @@
+# --no-cache-dir
+pillow==9.4.0
+opencv-python==4.8.1.78
+# torch==2.5.1
+pyyaml==6.0.1
+json_tricks==3.17.3
+yacs==0.1.8
+scikit-learn==1.3.1
+pandas==2.0.3
+timm==0.4.12
+numpy==1.26.4
+einops==0.8.0
+fvcore==0.1.5.post20221221
+transformers==4.34.0
+sentencepiece==0.1.99
+ftfy==6.1.1
+regex==2023.10.3
+nltk==3.8.1
+pydicom
+nibabel
+SimpleITK
+vision-datasets==0.2.2
+cython==3.0.2
+pycocotools==2.0.7
+diffdist==0.1
+scikit-image==0.21.0
+mup==1.0.0
+accelerate==0.23.0
+kornia==0.7.0
+infinibatch==0.1.1
+open-clip-torch==2.26.1
+
+git+https://github.com/MaureenZOU/detectron2-xyz.git
+
+antlr4-python3-runtime==4.9.3
+appdirs==1.4.4
+black==21.4b2
+cloudpickle==3.0.0
+hjson==3.1.0
+huggingface-hub==0.17.3
+hydra-core==1.3.2
+imageio==2.35.1
+iopath==0.1.9
+mypy-extensions==1.0.0
+ninja==1.11.1.1
+omegaconf==2.3.0
+pathspec==0.12.1
+portalocker==2.10.1
+py-cpuinfo==9.0.0
+pydantic==1.10.18
+pydot==3.0.1
+tabulate==0.9.0
+termcolor==2.4.0
+tokenizers==0.14.1
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaab4b4d6c77357e89c74147fd096507af8ca329
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1 @@
+from .xdecoder_trainer import *
\ No newline at end of file
diff --git a/trainer/default_trainer.py b/trainer/default_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd928d1f0949e3b99e3ab8845259ebaeb4a8637b
--- /dev/null
+++ b/trainer/default_trainer.py
@@ -0,0 +1,305 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+from datetime import datetime
+import time
+import os
+import sys
+import importlib
+import json
+import random
+#import wandb
+import logging
+import numpy as np
+import copy
+import contextlib
+import shutil
+from typing import Any, Callable, Union
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from mpi4py import MPI
+from infinibatch import iterators
+
+from .distributed_trainer import DistributedTrainer
+from .utils_trainer import UtilsTrainer
+from .utils.misc import *
+from .utils.serialization import JSONEncoder, filter_jsonable
+
+logger = logging.getLogger(__name__)
+
+
+class DefaultTrainer(UtilsTrainer, DistributedTrainer):
+
+ def __init__(self, opt):
+ """
+ Set up the task the model is being trained for.
+ """
+ super().__init__(opt)
+ base_name = 'base_dir'
+ base_path = os.path.join(self.opt['base_path'], '__init__.py')
+ spec = importlib.util.spec_from_file_location(base_name, base_path)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[base_name] = module
+ spec.loader.exec_module(module)
+ logger.info(f"Imported {base_name} at base_path {self.opt['base_path']}")
+
+ pipeline_module = importlib.import_module(f"base_dir.pipeline.{self.opt['PIPELINE']}")
+ pipeline_class = getattr(pipeline_module, self.opt['PIPELINE'])
+ logger.info(f"Pipeline for training: {self.opt['PIPELINE']}")
+ self.pipeline = pipeline_class(self.opt)
+
+ def eval(self, ):
+ logger.info('-----------------------------------------------')
+ logger.info("Evaluating model ... ")
+ self.mode = "eval"
+
+ # self.model_names, self.raw_models, self.criteria = self.pipeline.set_up_model()
+ self.raw_models = self.pipeline.initialize_model()
+ self.model_names = self.raw_models.keys()
+
+ # move models to the device
+ for module_name in self.model_names:
+ self.raw_models[module_name].to(self.opt['device'])
+
+ # load model during evaluation
+ if self.opt['WEIGHT'] and os.path.isfile(self.opt['RESUME_FROM']):
+ model_path = self.opt['RESUME_FROM']
+ self.load_model(model_path)
+ else:
+ raise ValueError(f"Model not found: {model_path}")
+
+ results = self._eval_on_set(self.save_folder)
+ return results
+
+ def _eval_on_set(self, save_folder):
+ logger.info(f"Evaluation start ...")
+ if self.opt['FP16']:
+ from torch.cuda.amp import autocast
+ with autocast():
+ results = self.pipeline.evaluate_model(self, save_folder)
+ else:
+ results = self.pipeline.evaluate_model(self, save_folder)
+ if self.opt['rank'] == 0:
+ logger.info(results)
+ return results
+
+ def compute_loss(self, forward_func, batch):
+
+ def forward(func, trainer, batch):
+ if self.opt['FP16']:
+ from torch.cuda.amp import autocast
+ with autocast():
+ loss = func(trainer, batch)
+ else:
+ loss = func(trainer, batch)
+ return loss
+
+ loss = forward(forward_func, self, batch)
+ return loss
+
+ def backward_loss(self, loss, model_names=['default']): # noqa: E252
+
+ def backward(loss_tensor):
+ if self.opt['FP16']:
+ self.grad_scaler.scale(loss_tensor).backward()
+ else:
+ loss_tensor.backward()
+
+ if self.grad_acc_steps > 1:
+ loss = loss / self.grad_acc_steps
+
+ backward(loss)
+ return loss
+
+ def update_model(self, model_name='default'):
+ if self.opt['FP16']:
+ self.grad_scaler.unscale_(self.optimizers[model_name])
+ self.grad_scaler.step(self.optimizers[model_name])
+ else:
+ self.optimizers[model_name].step()
+
+ self.optimizers[model_name].zero_grad()
+ self.train_params['optim_steps'][model_name] += 1
+ self.lr_schedulers[model_name].step()
+
+ def train_step(self, batch):
+ self.grad_acc_batches.append(batch) # support batch accumulation
+
+ if self.is_gradient_accumulation_boundary():
+ # set all modules and criteria into training mode
+ for model_name in self.model_names:
+ self.models[model_name].train()
+
+ assert len(self.grad_acc_batches) == self.grad_acc_steps
+
+ total_batch_sample = 0
+ for batch_index, batch in enumerate(self.grad_acc_batches):
+
+ loss_info, sample_size_info, extra_info = \
+ self.pipeline.forward_step(self,
+ batch,
+ self.grad_acc_batches,
+ batch_index,
+ is_distributed=(self.opt['world_size'] > 1))
+
+ self.train_loss.update_iter(loss_info)
+ total_batch_sample += sample_size_info['num_samples']
+
+ if self.opt['FP16']:
+ # Update GradScaler after an effective batch
+ self.grad_scaler.update()
+
+ # update losses and item counts of an effective batch to the AverageMeters
+ if self.opt['world_size'] > 1:
+ total_batch_sample = torch.tensor(total_batch_sample).to(self.opt['device'])
+ torch.distributed.all_reduce(total_batch_sample, torch.distributed.ReduceOp.SUM)
+ total_batch_sample = total_batch_sample.item()
+
+ self.train_params['total_batch_size'] += total_batch_sample
+ self.grad_acc_batches = []
+
+ self.train_params['num_updates'] += 1
+
+ def init_train(self):
+ self.mode = "train"
+ logger.info('-------------------------------------------------------')
+ logger.info("Training on rank: {}".format(self.opt['rank']))
+
+ self.raw_models = self.pipeline.initialize_model()
+ self.model_names = list(self.raw_models.keys())
+
+ # move models to the device
+ for module_name in self.model_names:
+ self.raw_models[module_name].to(self.opt['device'])
+
+ self.train_dataloaders = self.pipeline.get_dataloaders(self, 'train', is_evaluation=False)
+ self.train_params = {
+ "updates_per_epoch": len(self.train_dataloaders),
+ "total_batch_size": 0,
+ "num_updates": 0,
+ "optim_steps": {module_name: 0 for module_name in self.model_names},
+ "start_epoch_idx": 0,
+ "start_batch_idx": 0,
+ "current_epoch_idx": 0,
+ "current_batch_idx": 0,
+ "resume_epoch_idx": 0,
+ }
+
+ self.train_loss = LossMeter()
+ self.grad_acc_batches = []
+
+ if self.opt['CUDA']:
+ torch.cuda.empty_cache()
+
+ self.create_optimizer_and_scheduler()
+ self.models = {model_name: self.raw_models[model_name] for model_name in self.model_names}
+ self._initialize_ddp()
+
+ if self.opt.get('WEIGHT', False):
+ self.load_weight(self.opt['RESUME_FROM'], must_exist=True)
+ if self.opt.get('RESUME', False):
+ self.load_checkpoint(self.opt['RESUME_FROM'], must_exist=True)
+
+ ######################
+ # Start the main loop
+ ######################
+ if self.opt['rank'] == 0:
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(f" Num of GPUs = {self.opt['world_size']}")
+ logger.info(f" Num Epochs = {self.opt['SOLVER']['MAX_NUM_EPOCHS']}")
+ logger.info(f" Num of Mini Batches per Epoch = {self.train_params['updates_per_epoch']}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {self.opt['SOLVER']['MAX_NUM_EPOCHS'] * self.train_params['updates_per_epoch']}")
+ logger.info(f" Gradient Accumulation steps = {self.grad_acc_steps}")
+ logger.info(f" Total optimization steps = {self.opt['SOLVER']['MAX_NUM_EPOCHS'] * self.train_params['updates_per_epoch'] // self.grad_acc_steps}")
+
+ def train(self):
+ """
+ Training
+ """
+ self.init_train()
+ current_optim_steps = self._get_and_validate_current_optim_steps()
+ num_epochs = self.opt['SOLVER']['MAX_NUM_EPOCHS']
+
+ if self.opt.get('EVAL_AT_START', False):
+ results = self._eval_on_set(self.save_folder)
+ # if self.opt['rank'] == 0 and self.opt['WANDB']:
+ # wandb.log(results)
+
+ train_prev_logged_time = datetime.now()
+ for epoch in range(self.train_params['start_epoch_idx'], num_epochs):
+ self.train_params['current_epoch_idx'] = epoch
+ logger.info(f"Start epoch: {epoch} training.")
+
+ epoch_start_time = datetime.now()
+ for batch_idx, batch in enumerate(self.train_dataloaders):
+ if self.train_params['current_epoch_idx'] == self.train_params['start_epoch_idx']:
+ if batch_idx < self.train_params['start_batch_idx']: # skip the first few batches for resuming
+ continue
+
+ self.train_params['current_batch_idx'] = batch_idx
+ prev_optim_steps = current_optim_steps
+ prev_total_batch_size = self.train_params['total_batch_size']
+
+ # update
+ self.prev_optim_steps = prev_optim_steps
+ self.train_step(batch)
+
+ current_optim_steps = self._get_and_validate_current_optim_steps()
+
+ # logging
+ if prev_optim_steps != current_optim_steps: # an optimizer update was made
+ log_first = self.opt.get("LOG_FIRST", 10)
+ log_every = self.opt.get("LOG_EVERY", 100)
+ if (current_optim_steps % log_every == 0) or (epoch == 0 and current_optim_steps <= log_first): # print logging
+
+ last_lr = {}
+ for module_name in self.model_names:
+ last_lr[module_name] = self.lr_schedulers[module_name].get_last_lr()[0]
+
+ train_time_delta = (datetime.now() - train_prev_logged_time).total_seconds()
+ train_prev_logged_time = datetime.now()
+ MB = 1024.0 * 1024.0
+ memory = torch.cuda.max_memory_allocated() / MB
+
+ if self.opt['rank'] == 0:
+ # if self.opt['WANDB']:
+ # # log for wandb
+ # wb_loss_info = {key: obj.val for key, obj in self.train_loss.losses.items()}
+ # wandb.log(wb_loss_info, step=self.prev_optim_steps)
+
+ # log for terminal
+ logger.info(f"epochs[{epoch:6}] optim steps[{current_optim_steps:.0f}] "
+ f"learning rate[{', '.join([f'{key}: {val:.5e}' for key, val in last_lr.items()])}] "
+ f"train loss[{', '.join([f'{key}: {obj.val:.5f}/{obj.avg:.5f}' for key, obj in self.train_loss.losses.items()])}] "
+ # f"total_loss[{total_loss:.5f}/{total_loss_avg:.5f} "
+ f"items per batch[{self.train_params['total_batch_size'] - prev_total_batch_size}] "
+ f"items per second[{(self.train_params['total_batch_size'] - prev_total_batch_size) / train_time_delta:.2f}] "
+ f"total items[{self.train_params['total_batch_size']}] "
+ f"mini batches[{self.train_params['num_updates']:6}] "
+ f"memory[{memory:.0f}] "
+ f"epoch remaining[{str((datetime.now() - epoch_start_time) / (batch_idx + 1) * (self.train_params['updates_per_epoch'] - batch_idx - 1)).split('.')[0]}]")
+
+ # evaluate and save ckpt every epoch
+ if batch_idx + 1 == self.train_params['updates_per_epoch']:
+ if self.opt.get('SAVE_CHECKPOINT', True):
+ self.save_checkpoint(self.train_params['num_updates'])
+ results = self._eval_on_set(self.save_folder)
+ # if self.opt['rank'] == 0 and self.opt['WANDB']:
+ # wandb.log(results)
+ break
+
+ logger.info(f"This epoch takes {datetime.now() - epoch_start_time}")
+ logger.info(f"PROGRESS: {100.0 * (epoch + 1) / num_epochs:.2f}%")
+ logger.info(f"Config files are at {self.opt['conf_files']}")
+
+ # if not self.opt.get('SAVE_CHECKPOINT', True):
+ # self.save_checkpoint(self.train_params['num_updates'])
\ No newline at end of file
diff --git a/trainer/distributed_trainer.py b/trainer/distributed_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4fbd48d71e7e2ddabc3c7ccf767b152192becbd
--- /dev/null
+++ b/trainer/distributed_trainer.py
@@ -0,0 +1,124 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import os
+import logging
+from mpi4py import MPI
+
+import torch
+
+from .utils.hook import add_hook
+from .utils.mpi_adapter import MPIAdapter
+from .utils.misc import save_opt_to_yaml
+
+logger = logging.getLogger(__name__)
+
+
+class DistributedTrainer:
+ def __init__(self, opt):
+ self.opt = opt
+
+ # parse environment information for distributed training
+ adapter = MPIAdapter(self.opt['PORT'])
+ self.opt['world_size'] = adapter.world_size
+ self.opt['local_size'] = adapter.local_size
+ self.opt['rank'] = adapter.rank
+ self.opt['local_rank'] = adapter.local_rank
+
+ self.set_opt_hook()
+
+ # set up device
+ if not self.opt['CUDA']:
+ self.opt['device'] = torch.device("cpu")
+ logger.info("Using CPU")
+ else:
+ torch.cuda.set_device(self.opt['local_rank'])
+ self.opt['device'] = torch.device("cuda", self.opt['local_rank'])
+ logger.info("Using CUDA")
+
+ # init distributed training
+ adapter.log_info()
+ if torch.distributed.is_available() and self.opt['world_size'] > 1:
+ adapter.init_process_group(backend='nccl')
+
+ # save config file
+ self.save_folder = self.opt['SAVE_DIR']
+
+ if self.opt['world_size'] > 1:
+ torch.distributed.barrier()
+
+ if self.opt['rank'] == 0:
+ os.makedirs(self.save_folder, exist_ok=True)
+
+ logger.info(f"Save config file to {os.path.join(self.save_folder, 'conf_copy.yaml')}")
+ save_opt_to_yaml(self.opt, os.path.join(self.save_folder, 'conf_copy.yaml'))
+
+ # ddp: log stats and update learning rate
+ self.grad_acc_steps = self.opt['GRADIENT_ACCUMULATE_STEP']
+ logger.info(f"Base learning rate: {self.opt['SOLVER']['BASE_LR']}")
+ logger.info(f"Number of GPUs: {self.opt['world_size']}")
+ logger.info(f"Gradient accumulation steps: {self.grad_acc_steps}")
+
+ if self.opt['world_size'] > 1:
+ add_hook()
+
+ # prepare metadata for save folder
+ conf_file = self.opt['conf_files'][0]
+ if 'BASENAME' not in self.opt:
+ self.opt['BASENAME'] = os.path.basename(conf_file)
+
+ self.init_save_folder()
+
+ def set_opt_hook(self):
+ # Fill in the default values for required keywords
+ self.opt['CUDA'] = self.opt.get('CUDA', True) and torch.cuda.is_available()
+ self.opt['FP16'] = self.opt.get('FP16', False) and self.opt['CUDA']
+ self.opt['GRADIENT_ACCUMULATE_STEP'] = int(self.opt.get('GRADIENT_ACCUMULATE_STEP', 1))
+ self.opt['EVAL_PER_UPDATE_NUM'] = int(self.opt.get('EVAL_PER_UPDATE_NUM', 0))
+ self.opt['LR_SCHEDULER_PARAMS'] = self.opt.get('LR_SCHEDULER_PARAMS', {})
+
+ if 'SAVE_DIR' not in self.opt:
+ assert False, "Please initialize SAVE_DIR in your config file."
+ self.opt['SAVE_DIR'] = os.path.normpath(self.opt['SAVE_DIR'])
+ logger.info(f"Setting SAVE_DIR as {self.opt['SAVE_DIR']}")
+
+ def init_save_folder(self):
+ """
+ Initialize the save folder for logs, model, checkpoint, and evaluation.
+ """
+ runid = 1
+
+ if self.opt['world_size'] > 1:
+ torch.distributed.barrier()
+
+ if self.opt['rank'] == 0:
+ while True:
+ save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}")
+ try:
+ os.makedirs(save_folder, exist_ok=False)
+ break
+ except FileExistsError:
+ runid = runid + 1
+
+ if self.opt['world_size'] > 1:
+ torch.distributed.barrier()
+
+ if self.opt['world_size'] > 1:
+ runid = 1
+ while True:
+ save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}")
+ if not os.path.exists(save_folder):
+ break
+ else:
+ runid += 1
+
+ runid -= 1
+ save_folder = os.path.join(self.opt['SAVE_DIR'], f"{self.opt['BASENAME']}_conf~", f"run_{runid}")
+ # this second os.makedirs() call on all ranks is to force sync the save_folder creation between blobFuse and local fs
+ os.makedirs(save_folder, exist_ok=True)
+
+ self.save_folder = save_folder
\ No newline at end of file
diff --git a/trainer/utils/__init__.py b/trainer/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/trainer/utils/hook.py b/trainer/utils/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..53f46368c05a268bba6836165d82e12646426577
--- /dev/null
+++ b/trainer/utils/hook.py
@@ -0,0 +1,61 @@
+import sys
+import logging
+
+logger = logging.getLogger(__name__)
+
+_orig_except_hook = None
+
+
+def _global_except_hook(exctype, value, traceback):
+ """Catches an unhandled exception and call MPI_Abort()."""
+ try:
+ if _orig_except_hook:
+ _orig_except_hook(exctype, value, traceback)
+ else:
+ sys.__excepthook__(exctype, value, traceback)
+
+ finally:
+ import mpi4py.MPI
+ rank = mpi4py.MPI.COMM_WORLD.Get_rank()
+ logger.warning("******************************************")
+ logger.warning("DefaultTrainer:")
+ logger.warning(f" Uncaught exception on rank {rank}.")
+ logger.warning(" Calling MPI_Abort() to shut down MPI...")
+ logger.warning("******************************************")
+ logging.shutdown()
+
+ try:
+ import mpi4py.MPI
+ mpi4py.MPI.COMM_WORLD.Abort(1)
+ except Exception as e:
+ # Something is completely broken...
+ # There's nothing we can do any more
+ sys.stderr.write("Sorry, failed to stop MPI and the process may hang.\n")
+ sys.stderr.flush()
+ raise e
+
+
+def add_hook():
+ """
+ Add a global hook function that captures all unhandled exceptions.
+ The function calls MPI_Abort() to force all processes abort.
+
+ An MPI runtime is expected to kill all of its child processes
+ if one of them exits abnormally or without calling `MPI_Finalize()`.
+ However, when a Python program run on `mpi4py`, the MPI runtime
+ often fails to detect a process failure, and the rest of the processes
+ hang infinitely.
+
+ See https://github.com/chainer/chainermn/issues/236 and
+ https://mpi4py.readthedocs.io/en/stable/mpi4py.run.html for more
+ information.
+ """
+ global _orig_except_hook
+
+ if _orig_except_hook is not None:
+ logger.warning("GlobalExceptHook.add_hook() seems to be called multiple times. Ignoring.")
+ return
+
+ logger.info("Adding global except hook for the distributed job to shutdown MPI if unhandled exception is raised on some of the ranks.")
+ _orig_except_hook = sys.excepthook
+ sys.excepthook = _global_except_hook
diff --git a/trainer/utils/misc.py b/trainer/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f46374ce9e3df241d0dc67100c0961367c6598e
--- /dev/null
+++ b/trainer/utils/misc.py
@@ -0,0 +1,162 @@
+import math
+import yaml
+import logging
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+logger = logging.getLogger(__name__)
+
+
+class ObjectView(object):
+ def __init__(self, d):
+ self.__dict__ = d
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value."""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1, decay=0):
+ self.val = val
+ if decay:
+ alpha = math.exp(-n / decay) # exponential decay over 100 updates
+ self.sum = alpha * self.sum + (1 - alpha) * val * n
+ self.count = alpha * self.count + (1 - alpha) * n
+ else:
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def move_batch_to_device(batch, device):
+ """
+ Move the batch to the device.
+ It should be called before feeding the batch to the model.
+
+ Args:
+ batch (torch.tensor or container of torch.tensor): input batch
+ device (torch.device): device to move the batch to
+ Returns:
+ return_batch: same type as the input batch with internal tensors moved to device
+ """
+ if torch.is_tensor(batch):
+ return_batch = batch.to(device)
+ elif isinstance(batch, list):
+ return_batch = [move_batch_to_device(t, device) for t in batch]
+ elif isinstance(batch, tuple):
+ return_batch = tuple(move_batch_to_device(t, device) for t in batch)
+ elif isinstance(batch, dict):
+ return_batch = {}
+ for k in batch:
+ return_batch[k] = move_batch_to_device(batch[k], device)
+ else:
+ logger.debug(f"Can not move type {type(batch)} to device. Skipping it in the batch.")
+ return_batch = batch
+
+ return return_batch
+
+
+def cast_batch_to_half(batch):
+ """
+ Cast the float32 tensors in a batch to float16.
+ It should be called before feeding the batch to the FP16 DeepSpeed model.
+
+ Args:
+ batch (torch.tensor or container of torch.tensor): input batch
+ Returns:
+ return_batch: same type as the input batch with internal float32 tensors casted to float16
+ """
+ if torch.is_tensor(batch):
+ if torch.is_floating_point(batch):
+ return_batch = batch.to(torch.float16)
+ else:
+ return_batch = batch
+ elif isinstance(batch, list):
+ return_batch = [cast_batch_to_half(t) for t in batch]
+ elif isinstance(batch, tuple):
+ return_batch = tuple(cast_batch_to_half(t) for t in batch)
+ elif isinstance(batch, dict):
+ return_batch = {}
+ for k in batch:
+ return_batch[k] = cast_batch_to_half(batch[k])
+ else:
+ logger.debug(f"Can not cast type {type(batch)} to float16. Skipping it in the batch.")
+ return_batch = batch
+
+ return return_batch
+
+# Adapted from https://github.com/marian-nmt/marian-dev/blob/master/src/training/exponential_smoothing.h
+def apply_exponential_smoothing(avg_params: Tensor,
+ updated_params: Tensor,
+ steps: int,
+ beta: float=0.9999, # noqa: E252
+ ref_target_words: Optional[int]=None, # noqa: E252
+ actual_target_words: Optional[int]=None): # noqa: E252
+ r'''
+ Applies exponential smoothing on a model's parameters, updating them in place.
+ Can provide improved performance compared to inference using a single checkpoint.
+
+ .. math::
+ s_{t+1} = \beta \cdot s_t + (1-\beta) \cdot p_{t+1}
+ where :math:`s_t` are the smoothed params (`avg_params`) at time :math:`t` and :math:`p_{t+1}` are the incoming
+ updated_parameters from the most recent step (time :math:`t+1`).
+
+ Args:
+ avg_params List[Tensor]:
+ Model parameters derived using the repeated average for all t < steps. Updated in-place.
+ updated_params List[Tensor]:
+ Model parameters from the latest update.
+ steps int:
+ Number of optimizer steps taken.
+ beta float:
+ Parameter that controls the decay speed. Default = 0.9999
+ ref_target_words Optional[int]:
+ Reference number of target labels expected in a batch.
+ actual_target_words Optional[int]:
+ The actual number of target labels in this batch.
+ '''
+
+ if ref_target_words is not None and actual_target_words is not None:
+ beta = beta ** (actual_target_words / ref_target_words)
+ steps = max(steps, steps * (actual_target_words / ref_target_words)) # BUG: does not account for changing batch size
+
+ # Decay parameters more quickly at the beginning to avoid retaining the random initialization
+ decay_by = min(beta, (steps + 1.) / (steps + 10))
+
+ # Equivalent to: decay_by * avg_params + (1.0 - decay_by) * updated_params
+ updated_params = updated_params.to(avg_params.dtype)
+ avg_params.copy_(decay_by * (avg_params - updated_params) + updated_params)
+
+def save_opt_to_yaml(opt, conf_file):
+ with open(conf_file, 'w', encoding='utf-8') as f:
+ yaml.dump(opt, f)
+
+class LossMeter(object):
+ def __init__(self):
+ self.reset()
+
+ def reset(self,):
+ self.losses = {}
+
+ def update_iter(self, losses):
+ for key, value in losses.items():
+ self.add(key, value)
+
+ def add(self, name, loss):
+ if name not in self.losses:
+ self.losses[name] = AverageMeter()
+ self.losses[name].update(loss)
+
+ def get(self, name):
+ if name not in self.losses:
+ return 0
+ return self.losses[name]
\ No newline at end of file
diff --git a/trainer/utils/mpi_adapter.py b/trainer/utils/mpi_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ea661282e7bb83dc431e3fba81e6bc8eb7d2e3
--- /dev/null
+++ b/trainer/utils/mpi_adapter.py
@@ -0,0 +1,141 @@
+import logging
+from mpi4py import MPI
+import os
+import re
+import subprocess
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class MPIAdapter:
+ """
+ MPIAdapter automatically detects and analyzes the training environment for distributed training
+ and offers methods to set up distributed training jobs.
+
+ For example, it determines whether training happens on AML, Philly, or locally.
+ It also determines variables such as the world size and the rank of each GPU.
+ """
+
+ def __init__(self, port='55551', set_env_vars=True):
+ local_address = '127.0.0.1'
+ default_torch_distributed_port = port # chosen arbitrarily
+
+ if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
+ # application was started without MPI
+ # default to single node with single process
+ self.env_info = 'no MPI'
+ self.world_size = 1
+ self.local_size = 1
+ self.rank = 0
+ self.local_rank = 0
+ self.master_address = local_address
+ self.master_port = default_torch_distributed_port
+ else:
+ # application was started with MPI
+ # get MPI parameters
+ self.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ self.local_size = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
+ self.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ self.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+
+ if 'PHILLY_CONTAINER_IP' in os.environ:
+ # application is running on Philly
+ # read environment variables on master node and broadcast via MPI
+ self.env_info = 'philly'
+ if self.rank == 0:
+ self.master_address = os.environ['PHILLY_CONTAINER_IP']
+ self.master_port = os.environ['PHILLY_CONTAINER_PORT_RANGE_START']
+ else:
+ self.master_address = None
+ self.master_port = None
+ self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
+ self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
+ elif "AMLK8S_NUM_WORKER" in os.environ or "AZ_CMK8S_JOB_WORK_DIR" in os.environ:
+ # application is running on AMLK8S (ITP)
+ # read master address from a specific file.
+ self.env_info = 'AMLK8S (ITP)'
+ # from: https://k8s-wiki.azureml.com/faq.html
+ regexp = r"[\s\S]*export[\s]*DLTS_SD_worker0_IP=([0-9.]+)[\s|s]*"
+ with open("/dlts-runtime/env/init.env", 'r') as f:
+ line = f.read()
+ match = re.match(regexp, line)
+ if match:
+ self.master_address = str(match.group(1))
+ else:
+ # Did not find master node ip in file. It must be a single-node
+ # debugging job with custom "mpirun" command
+ assert self.world_size == self.local_size, \
+ "It's not a single-node debugging job on AMLK8S (ITP), but no master ip is found in file."
+ self.env_info = 'single-node AMLK8S (ITP) debugging job'
+ self.master_address = local_address
+ self.master_port = default_torch_distributed_port
+ elif 'AZ_BATCH_MASTER_NODE' in os.environ:
+ # application is running on multiple nodes on AML
+ self.env_info = 'multi-node AML'
+ master_node_params = os.environ['AZ_BATCH_MASTER_NODE'].split(':')
+ self.master_address = master_node_params[0]
+ self.master_port = default_torch_distributed_port
+ elif self.world_size == self.local_size:
+ # application is running with MPI on single node
+ self.env_info = 'single-node AML or other MPI environment'
+ self.master_address = local_address
+ self.master_port = default_torch_distributed_port
+ else:
+ # multi-node MPI environment, but not Philly or AML
+ # we use "hostname -I" command on rank 0 to get the master address
+ self.env_info = 'multi-node other MPI environment'
+ if self.rank == 0:
+ hostname_cmd = ["hostname -I"]
+ result = subprocess.check_output(hostname_cmd, shell=True)
+ self.master_address = result.decode('utf-8').split()[0]
+ self.master_port = default_torch_distributed_port
+ else:
+ self.master_address = None
+ self.master_port = None
+ self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
+ self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
+
+ self.init_method_url = f'tcp://{self.master_address}:{self.master_port}'
+ if set_env_vars:
+ self._set_env_vars()
+
+ def log_info(self):
+ """
+ Logs information about distributed training environment.
+ """
+ # of not printing logger.info messages on processes with rank > 0
+ logger.warning('----------------')
+ logger.warning('MPI Adapter data')
+ logger.warning('----------------')
+ logger.warning(f'environment info: {self.env_info}')
+ logger.warning(f'init method url: {self.init_method_url}')
+ logger.warning(f'world size: {self.world_size}')
+ logger.warning(f'local size: {self.local_size}')
+ logger.warning(f'rank: {self.rank}')
+ logger.warning(f'local rank: {self.local_rank}')
+ logger.warning(f'master address: {self.master_address}')
+ logger.warning(f'master port: {self.master_port}')
+ logger.warning('----------------')
+
+ def init_process_group(self, backend):
+ """
+ Initializes the default PyTorch distributed process group.
+ """
+ # of not printing logger.info messages on processes with rank > 0
+ logger.warning('trying to initialize process group ...')
+ torch.distributed.init_process_group(backend=backend,
+ init_method=self.init_method_url,
+ world_size=self.world_size,
+ rank=self.rank)
+ logger.warning('process group initialized')
+
+ def _set_env_vars(self):
+ """
+ Sets environment variables for world size, rank, local rank, master addr, and master port.
+ """
+ os.environ['WORLD_SIZE'] = str(self.world_size)
+ os.environ['RANK'] = str(self.rank)
+ os.environ["LOCAL_RANK"] = str(self.local_rank)
+ os.environ['MASTER_ADDR'] = self.master_address
+ os.environ['MASTER_PORT'] = self.master_port
diff --git a/trainer/utils/serialization.py b/trainer/utils/serialization.py
new file mode 100644
index 0000000000000000000000000000000000000000..439a095f6449e75496a8644bb097166e78967c1f
--- /dev/null
+++ b/trainer/utils/serialization.py
@@ -0,0 +1,27 @@
+import json
+import numpy as np
+from typing import Dict
+
+
+class JSONEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, np.floating):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ else:
+ return super(JSONEncoder, self).default(obj)
+
+
+def is_jsonable(x, json_encoder=None):
+ try:
+ json.dumps(x, cls=json_encoder)
+ return True
+ except Exception:
+ return False
+
+
+def filter_jsonable(data: Dict, json_encoder=None) -> Dict:
+ return {k: v for k, v in data.items() if is_jsonable(k, json_encoder=json_encoder) and is_jsonable(v, json_encoder=json_encoder)}
\ No newline at end of file
diff --git a/trainer/utils_trainer.py b/trainer/utils_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1f7863328c12922cd3fd38f710517b1d159c203
--- /dev/null
+++ b/trainer/utils_trainer.py
@@ -0,0 +1,194 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+from datetime import datetime
+import time
+import os
+import sys
+import importlib
+import json
+import random
+import logging
+import numpy as np
+import copy
+import contextlib
+import shutil
+from typing import Any, Callable, Union
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from mpi4py import MPI
+from infinibatch import iterators
+
+from .distributed_trainer import DistributedTrainer
+from .utils.misc import *
+from .utils.serialization import JSONEncoder, filter_jsonable
+from utilities.distributed import get_world_size
+
+logger = logging.getLogger(__name__)
+
+
+class UtilsTrainer(DistributedTrainer):
+
+ def __init__(self, opt):
+ super().__init__(opt)
+
+ def is_gradient_accumulation_boundary(self):
+ return (self.train_params['num_updates'] + 1) % self.grad_acc_steps == 0
+
+ def get_batch_size(self, batch, module_name='default'):
+ if hasattr(self.raw_models[module_name], 'get_batch_size'):
+ if callable(self.raw_models[module_name].get_batch_size):
+ return self.raw_models[module_name].get_batch_size(batch)
+ return {}
+
+ def _initialize_ddp(self):
+ if self.opt['FP16']:
+ from torch.cuda.amp import GradScaler
+ self.grad_scaler = GradScaler()
+ logger.warning("PyTorch AMP GradScaler initialized.")
+
+ for module_name in self.model_names:
+ if self.opt['world_size'] > 1:
+ # ddp: wrap modules for distributed data parallel training
+ self.models[module_name] = nn.parallel.DistributedDataParallel(self.models[module_name],
+ device_ids=[self.opt['local_rank']],
+ output_device=self.opt['local_rank'],
+ find_unused_parameters=self.opt.get('FIND_UNUSED_PARAMETERS', True))
+
+ def _get_and_validate_current_optim_steps(self):
+ current_optim_steps = set([self.train_params['optim_steps'][module_name] for module_name in self.model_names])
+ assert len(current_optim_steps) == 1, f"All modules should be at the same optim step: {self.train_params['optim_steps']}"
+ return next(iter(current_optim_steps))
+
+ def load_model(self, load_path):
+ for module_name in self.model_names:
+ self.raw_models[module_name] = self.raw_models[module_name].from_pretrained(load_path)
+ self.raw_models[module_name].to(self.opt['device'])
+
+ def save_checkpoint(self, tag):
+ tag = str(tag).zfill(8)
+ logger.warning('Saving checkpoint...')
+
+ resume_epoch_idx = self.train_params['current_epoch_idx']
+ resume_batch_idx = self.train_params['current_batch_idx'] + 1
+
+ if resume_batch_idx == self.train_params['updates_per_epoch']:
+ self.train_params['start_batch_idx'] = 0
+ self.train_params['start_epoch_idx'] = resume_epoch_idx + 1
+ else:
+ self.train_params['start_batch_idx'] = resume_batch_idx
+ self.train_params['start_epoch_idx'] = resume_epoch_idx
+
+ save_dir = os.path.join(self.save_folder, tag)
+
+ if self.opt['world_size'] > 1:
+ torch.distributed.barrier()
+
+ if self.opt['rank'] == 0:
+ os.makedirs(self.save_folder, exist_ok=True)
+
+ if self.opt['world_size'] > 1:
+ torch.distributed.barrier()
+
+ if self.opt['rank'] == 0:
+ os.makedirs(save_dir, exist_ok=True)
+
+ if self.opt['rank'] == 0:
+ if self.opt['FP16']:
+ amp_state = self.grad_scaler.state_dict()
+ else:
+ amp_state = None
+ for module_name in self.model_names:
+ module_save_dir = os.path.join(save_dir, module_name)
+ os.makedirs(module_save_dir, exist_ok=True)
+ save_path = os.path.join(module_save_dir, 'module_training_states.pt')
+ state = {'module': self.models[module_name].state_dict(),
+ 'optimizer': self.optimizers[module_name].state_dict(),
+ 'lr_scheduler': self.lr_schedulers[module_name].state_dict(),
+ 'amp_state': amp_state,}
+ torch.save(state, save_path)
+
+ if self.opt['rank'] == 0:
+ save_path = os.path.join(save_dir, 'trainer_states.pt')
+ trainer_state = {'train_loss': self.train_loss,
+ 'train_params': self.train_params,}
+ torch.save(trainer_state, save_path)
+
+ num_retries = 0
+ while num_retries < 3:
+ try:
+ random_state_path = os.path.join(save_dir, f"random_state_rank_{self.opt['rank']:04d}")
+ random_state = {'random': random.getstate(),
+ 'numpy_random': np.random.get_state(),
+ 'torch_random': torch.get_rng_state(),
+ 'torch_cuda_random': torch.cuda.get_rng_state(device=self.opt['device']) if self.opt['CUDA'] else None
+ }
+ torch.save(random_state, random_state_path)
+ num_retries = 3
+ except Exception as err:
+ num_retries += 1
+ logger.warning(err)
+ logger.warning("Failed to save checkpoint at retry {}, waiting for 30s to retry.".format(num_retries))
+ time.sleep(30)
+
+ if self.opt['rank'] == 0:
+ for module_name in self.model_names:
+ module_save_dir = os.path.join(save_dir, module_name)
+ self.raw_models[module_name].save_pretrained(module_save_dir)
+
+ if self.opt['rank'] == 0:
+ # save the latest checkpoint location to json file
+ checkpoint_location = {'checkpoint_tag': tag,
+ 'checkpoint_path': os.path.relpath(self.save_folder, start=self.opt['SAVE_DIR'])}
+ with open(os.path.join(self.opt['SAVE_DIR'], f"resume_checkpoint.json"), 'w', encoding='utf-8') as f:
+ json.dump(checkpoint_location, f, cls=JSONEncoder)
+
+ logger.warning(f'Finished saving checkpoint and model to {save_dir}.')
+
+ def load_weight(self, checkpoint_path=None, must_exist=False):
+ self.load_model(checkpoint_path)
+ logger.warning(f'Load weights from {checkpoint_path}...')
+
+ def load_checkpoint(self, checkpoint_path=None, must_exist=False):
+ logger.warning(f'Resuming checkpoint from {checkpoint_path}...')
+
+ for model_name in self.model_names:
+ model_load_path = os.path.join(checkpoint_path, model_name, 'module_training_states.pt')
+ state = torch.load(model_load_path, map_location=self.opt['device'])
+
+ logger.warning(f'HACK to strip module from model state dict on single gpu debugging!')
+ ckpt = state['module']
+ if get_world_size() <= 1:
+ ckpt = {key.replace('module.',''):ckpt[key] for key in ckpt.keys()}
+
+ self.models[model_name].load_state_dict(ckpt)
+ self.optimizers[model_name].load_state_dict(state['optimizer'])
+ self.lr_schedulers[model_name].load_state_dict(state['lr_scheduler'])
+ if self.opt['FP16']:
+ self.grad_scaler.load_state_dict(state['amp_state'])
+
+ load_path = os.path.join(checkpoint_path, 'trainer_states.pt')
+ trainer_state = torch.load(load_path, map_location='cpu')
+ self.train_loss = trainer_state['train_loss']
+ self.train_params = trainer_state['train_params']
+
+ random_state_path = os.path.join(checkpoint_path, f"random_state_rank_{self.opt['rank']:04d}")
+ if os.path.exists(random_state_path):
+ random_state = torch.load(random_state_path, map_location='cpu')
+ random.setstate(random_state['random'])
+ np.random.set_state(random_state['numpy_random'])
+ torch.set_rng_state(random_state['torch_random'])
+ if self.opt['CUDA']:
+ torch.cuda.set_rng_state(random_state['torch_cuda_random'], device=self.opt['device'])
+ else:
+ logging.warning("Could not find random state for rank {}".format(self.opt['rank']))
+
+ logger.warning(f'Finished loading checkpoint from {checkpoint_path}.')
\ No newline at end of file
diff --git a/trainer/xdecoder_trainer.py b/trainer/xdecoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..18283c4c96ed91986186d906f14822c76892fc11
--- /dev/null
+++ b/trainer/xdecoder_trainer.py
@@ -0,0 +1,191 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+
+import logging
+import os
+import json
+import random
+import copy
+import itertools
+from typing import Any, Dict, List, Set, Union
+from datetime import datetime
+from mpi4py import MPI
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from detectron2.projects.deeplab import build_lr_scheduler
+from fvcore.common.config import CfgNode
+from infinibatch import iterators
+
+from utilities.distributed import is_main_process, get_world_size
+from .default_trainer import DefaultTrainer
+from .utils.serialization import JSONEncoder, filter_jsonable
+
+logger = logging.getLogger(__name__)
+
+
+class XDecoder_Trainer(DefaultTrainer):
+ """
+ Construct Mask2Former_Trainer for optimizer and lr_scheduler
+ """
+ def create_optimizer_and_scheduler(self):
+ """
+ Set up self.optimizers and self.lr_schedulers
+
+ This method initializes self.optimizers and self.lr_schedulers as dictionaries of
+ instances of the classes that OPTIMIZER and LR_SCHEDULER in the config file points to.
+ One optimizer and lr scheduler for each model in self.raw_models. They have the same keys
+ as self.raw_models.
+ """
+ self.opt['init_optimizer_in_deepspeed'] = False
+ self.opt['init_lr_scheduler_in_deepspeed'] = False
+
+ self.optimizers = {module_name: None for module_name in self.model_names}
+ self.lr_schedulers = {module_name: None for module_name in self.model_names}
+
+ cfg_solver = self.opt['SOLVER']
+ weight_decay_norm = cfg_solver['WEIGHT_DECAY_NORM']
+ weight_decay_embed = cfg_solver['WEIGHT_DECAY_EMBED']
+ weight_decay_bias = cfg_solver.get('WEIGHT_DECAY_BIAS', 0.0)
+
+ defaults = {}
+ defaults["lr"] = cfg_solver['BASE_LR']
+ defaults["weight_decay"] = cfg_solver['WEIGHT_DECAY']
+
+ norm_module_types = (
+ torch.nn.BatchNorm1d,
+ torch.nn.BatchNorm2d,
+ torch.nn.BatchNorm3d,
+ torch.nn.SyncBatchNorm,
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
+ torch.nn.GroupNorm,
+ torch.nn.InstanceNorm1d,
+ torch.nn.InstanceNorm2d,
+ torch.nn.InstanceNorm3d,
+ torch.nn.LayerNorm,
+ torch.nn.LocalResponseNorm,
+ )
+
+ fix_param = self.opt['SOLVER'].get('FIX_PARAM',{})
+ ignore_fix = self.opt['SOLVER'].get('IGNORE_FIX',[])
+ for _module_name in self.model_names:
+
+ flag_continue = False
+ module_params = {}
+ for name, param in self.raw_models[_module_name].named_parameters():
+ for ig in ignore_fix:
+ if ig in name:
+ flag_continue = True
+ break
+
+ if flag_continue:
+ flag_continue = False
+ continue
+
+ for key, value in fix_param.items():
+ if key in name and value == True:
+ param.requires_grad = False
+
+ if key in name:
+ if key not in module_params:
+ module_params[key] = 0
+ module_params[key] += param.numel()
+
+ logger.info(f"Module {_module_name} has parameters: {module_params}")
+ #raise NotImplementedError("Please check the fix_param and ignore_fix in the config file")
+
+ lr_multiplier = self.opt['SOLVER']['LR_MULTIPLIER']
+
+ for _module_name in self.model_names:
+ # parameters = self.raw_models[module_name].get_training_parameters()
+ # self.optimizers[module_name] = optimizer_class(parameters, **optimizer_parameters)
+ # params = []
+ # for module_param_name, value in self.raw_models[module_name].named_parameters(recurse=True):
+ params: List[Dict[str, Any]] = []
+ memo: Set[torch.nn.parameter.Parameter] = set()
+ for module_name, module in self.raw_models[_module_name].named_modules():
+ for module_param_name, value in module.named_parameters(recurse=False):
+ if not value.requires_grad:
+ continue
+ # Avoid duplicating parameters
+ if value in memo:
+ continue
+ memo.add(value)
+
+ hyperparams = copy.copy(defaults)
+
+ for key, lr_mul in lr_multiplier.items():
+ if key in "{}.{}".format(module_name, module_param_name):
+ hyperparams["lr"] = hyperparams["lr"] * lr_mul
+ if is_main_process():
+ logger.info("Modify Learning rate of {}: {}".format("{}.{}".format(module_name, module_param_name), lr_mul))
+
+ if (
+ "relative_position_bias_table" in module_param_name
+ or "absolute_pos_embed" in module_param_name
+ ):
+ hyperparams["weight_decay"] = 0.0
+ if isinstance(module, norm_module_types):
+ hyperparams["weight_decay"] = weight_decay_norm
+ if isinstance(module, torch.nn.Embedding):
+ hyperparams["weight_decay"] = weight_decay_embed
+ if "bias" in module_name:
+ hyperparams["weight_decay"] = weight_decay_bias
+ params.append({"params": [value], **hyperparams})
+
+ def maybe_add_full_model_gradient_clipping(optim):
+ # detectron2 doesn't have full model gradient clipping now
+ clip_norm_val = cfg_solver['CLIP_GRADIENTS']['CLIP_VALUE']
+ enable = (
+ cfg_solver['CLIP_GRADIENTS']['ENABLED']
+ and cfg_solver['CLIP_GRADIENTS']['CLIP_TYPE'] == "full_model"
+ and clip_norm_val > 0.0
+ )
+
+ class FullModelGradientClippingOptimizer(optim):
+ def step(self, closure=None):
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
+ super().step(closure=closure)
+
+ return FullModelGradientClippingOptimizer if enable else optim
+
+ optimizer_type = cfg_solver['OPTIMIZER']
+ if optimizer_type == "SGD":
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
+ params, cfg_solver['BASE_LR'], momentum=cfg_solver['MOMENTUM']
+ )
+ elif optimizer_type == "ADAMW":
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
+ params, cfg_solver['BASE_LR']
+ )
+ else:
+ raise NotImplementedError(f"no optimizer type {optimizer_type}")
+
+ self.optimizers[_module_name] = optimizer
+ self.optimizers[_module_name].zero_grad()
+
+ num_epoch = self.opt['SOLVER']['MAX_NUM_EPOCHS']
+ cfg_solver['MAX_ITER'] = num_epoch * self.train_params['updates_per_epoch']
+ cfg_solver['STEPS'] = [int(x*cfg_solver['MAX_ITER']) for x in cfg_solver['STEPS']]
+ logger.info(f"Calculate MAX_ITER @ {cfg_solver['MAX_ITER']} and STEPS @ {cfg_solver['STEPS']}")
+
+ for module_name in self.model_names:
+ scheduler_cfg = CfgNode({'SOLVER': cfg_solver})
+ self.lr_schedulers[module_name] = build_lr_scheduler(scheduler_cfg, self.optimizers[module_name])
+
+ for module_name in self.model_names:
+ num_params = 0
+ num_trainable_params = 0
+ for name, param in self.raw_models[module_name].named_parameters():
+ num_params += param.numel()
+ if param.requires_grad:
+ num_trainable_params += param.numel()
+ logger.info(f"Total number of parameters in {module_name} module (on each GPU): {num_params}")
+ logger.info(f"Number of trainable parameters in {module_name} module (on each GPU): {num_trainable_params}")
\ No newline at end of file
diff --git a/utilities/Config.py b/utilities/Config.py
new file mode 100644
index 0000000000000000000000000000000000000000..86f5dc6086bd0427c6ef0aa9c2d2ab1caf57d4d3
--- /dev/null
+++ b/utilities/Config.py
@@ -0,0 +1,27 @@
+from fvcore.common.config import CfgNode as _CfgNode
+
+
+class CfgNode(_CfgNode):
+ """
+ The same as `fvcore.common.config.CfgNode`, but different in:
+
+ 1. Use unsafe yaml loading by default.
+ Note that this may lead to arbitrary code execution: you must not
+ load a config file from untrusted sources before manually inspecting
+ the content of the file.
+ 2. Support config versioning.
+ When attempting to merge an old config, it will convert the old config automatically.
+
+ .. automethod:: clone
+ .. automethod:: freeze
+ .. automethod:: defrost
+ .. automethod:: is_frozen
+ .. automethod:: load_yaml_with_base
+ .. automethod:: merge_from_list
+ .. automethod:: merge_from_other_cfg
+ """
+
+ def merge_from_dict(self, dict):
+ pass
+
+node = CfgNode()
\ No newline at end of file
diff --git a/utilities/__init__.py b/utilities/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecb3c97f07feb8a9ba130de89f4fdd4354d8f906
--- /dev/null
+++ b/utilities/__init__.py
@@ -0,0 +1,2 @@
+from .prompt_engineering import *
+from .dataset import *
\ No newline at end of file
diff --git a/utilities/arguments.py b/utilities/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f2693268a2318cb843ee81f52738a7e9b287acc
--- /dev/null
+++ b/utilities/arguments.py
@@ -0,0 +1,90 @@
+import yaml
+import json
+import argparse
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def load_config_dict_to_opt(opt, config_dict):
+ """
+ Load the key, value pairs from config_dict to opt, overriding existing values in opt
+ if there is any.
+ """
+ if not isinstance(config_dict, dict):
+ raise TypeError("Config must be a Python dictionary")
+ for k, v in config_dict.items():
+ k_parts = k.split('.')
+ pointer = opt
+ for k_part in k_parts[:-1]:
+ if k_part not in pointer:
+ pointer[k_part] = {}
+ pointer = pointer[k_part]
+ assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
+ ori_value = pointer.get(k_parts[-1])
+ pointer[k_parts[-1]] = v
+ if ori_value:
+ logger.warning(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}")
+
+
+def load_opt_from_config_files(conf_files):
+ """
+ Load opt from the config files, settings in later files can override those in previous files.
+
+ Args:
+ conf_files (list): a list of config file paths
+
+ Returns:
+ dict: a dictionary of opt settings
+ """
+ opt = {}
+ for conf_file in conf_files:
+ with open(conf_file, encoding='utf-8') as f:
+ config_dict = yaml.safe_load(f)
+
+ load_config_dict_to_opt(opt, config_dict)
+
+ return opt
+
+
+def load_opt_command(args):
+ parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.')
+ parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
+ parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).')
+ parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.')
+ parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"": , "..": }. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
+ parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER)
+
+ cmdline_args = parser.parse_args() if not args else parser.parse_args(args)
+
+ opt = load_opt_from_config_files(cmdline_args.conf_files)
+
+ if cmdline_args.config_overrides:
+ config_overrides_string = ' '.join(cmdline_args.config_overrides)
+ logger.warning(f"Command line config overrides: {config_overrides_string}")
+ config_dict = json.loads(config_overrides_string)
+ load_config_dict_to_opt(opt, config_dict)
+
+ if cmdline_args.overrides:
+ assert len(cmdline_args.overrides) % 2 == 0, "overrides arguments is not paired, required: key value"
+ keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)]
+ vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)]
+ vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals]
+
+ types = []
+ for key in keys:
+ key = key.split('.')
+ ele = opt.copy()
+ while len(key) > 0:
+ ele = ele[key.pop(0)]
+ types.append(type(ele))
+
+ config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)}
+ load_config_dict_to_opt(opt, config_dict)
+
+ # combine cmdline_args into opt dictionary
+ for key, val in cmdline_args.__dict__.items():
+ if val is not None:
+ opt[key] = val
+
+ return opt, cmdline_args
\ No newline at end of file
diff --git a/utilities/constants.py b/utilities/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..2856a91b65bca17c84948bbbd585eb7a219a4cc5
--- /dev/null
+++ b/utilities/constants.py
@@ -0,0 +1,135 @@
+BIOMED_CLASSES = ['liver', 'lung', 'kidney', 'pancreas', 'heart anatomies', 'brain anatomies',
+ 'eye anatomies', 'vessel', 'other organ', 'tumor', 'infection', 'other lesion',
+ 'fluid disturbance', 'other abnormality', 'histology structure', 'other']
+
+
+BIOMED_HIERARCHY = {'CT': {'abdomen': ['liver', 'left kidney', 'right kidney', 'panreas', 'spleen', 'stomach', ],
+ 'liver': ['liver', 'tumor', 'vessel'],
+ 'pancreas': ['pancreas', 'tumor'],
+ 'kidney': ['kidney', 'tumor', 'kidney cyst'],
+ 'lung': ['tumor', 'nodule', 'COVID-19 infection'],
+ 'colon': ['tumor']},
+ 'MRI': {'abdomen': ['liver', 'left kidney', 'right kidney', 'panreas', 'spleen', 'stomach', ],
+ 'prostate': ['prostate transitional zone', 'prostate peripheral zone'],
+ 'cardiac': ['left heart ventricle', 'right heart ventricle', 'myocardium']},
+ 'X-Ray': {'chest': ['lung', 'left lung', 'right lung', 'COVID-19 infection']},
+ 'ultrasound': {'cardiac': ['left heart ventricle', 'left heart atrium'],
+ 'transperineal': ['public symphysis', 'fetal head']},
+ 'fundus': {'retinal': ['optic disc', 'optic cup', 'retinal vessel']}}
+
+
+COCO_PANOPTIC_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield', 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow', 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'window-blind', 'window-other', 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', 'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged', 'paper-merged', 'food-other-merged', 'building-other-merged', 'rock-merged', 'wall-other-merged', 'rug-merged']
+
+ADE_PANOPTIC_CLASSES = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'window', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'tub', 'rail', 'cushion', 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'street lamp', 'booth', 'tv', 'airplane', 'dirt track', 'clothes', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'pool', 'stool', 'barrel', 'basket', 'falls', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag']
+
+ADE20K_847 = ['wall', 'building', 'sky', 'tree', 'road', 'floor', 'ceiling', 'bed', 'sidewalk', 'earth', 'cabinet', 'person', 'grass', 'windowpane', 'car', 'mountain', 'plant', 'table', 'chair', 'curtain', 'door', 'sofa', 'sea', 'painting', 'water', 'mirror', 'house', 'rug', 'shelf', 'armchair', 'fence', 'field', 'lamp', 'rock', 'seat', 'river', 'desk', 'bathtub', 'railing', 'signboard', 'cushion', 'path', 'work surface', 'stairs', 'column', 'sink', 'wardrobe', 'snow', 'refrigerator', 'base', 'bridge', 'blind', 'runway', 'cliff', 'sand', 'fireplace', 'pillow', 'screen door', 'toilet', 'skyscraper', 'grandstand', 'box', 'pool table', 'palm', 'double door', 'coffee table', 'counter', 'countertop', 'chest of drawers', 'kitchen island', 'boat', 'waterfall', 'stove', 'flower', 'bookcase', 'controls', 'book', 'stairway', 'streetlight', 'computer', 'bus', 'swivel chair', 'light', 'bench', 'case', 'towel', 'fountain', 'embankment', 'television receiver', 'van', 'hill', 'awning', 'poster', 'truck', 'airplane', 'pole', 'tower', 'court', 'ball', 'aircraft carrier', 'buffet', 'hovel', 'apparel', 'minibike', 'animal', 'chandelier', 'step', 'booth', 'bicycle', 'doorframe', 'sconce', 'pond', 'trade name', 'bannister', 'bag', 'traffic light', 'gazebo', 'escalator', 'land', 'board', 'arcade machine', 'eiderdown', 'bar', 'stall', 'playground', 'ship', 'ottoman', 'ashcan', 'bottle', 'cradle', 'pot', 'conveyer belt', 'train', 'stool', 'lake', 'tank', 'ice', 'basket', 'manhole', 'tent', 'canopy', 'microwave', 'barrel', 'dirt track', 'beam', 'dishwasher', 'plate', 'screen', 'ruins', 'washer', 'blanket', 'plaything', 'food', 'screen', 'oven', 'stage', 'beacon', 'umbrella', 'sculpture', 'aqueduct', 'container', 'scaffolding', 'hood', 'curb', 'roller coaster', 'horse', 'catwalk', 'glass', 'vase', 'central reservation', 'carousel', 'radiator', 'closet', 'machine', 'pier', 'fan', 'inflatable bounce game', 'pitch', 'paper', 'arcade', 'hot tub', 'helicopter', 'tray', 'partition', 'vineyard', 'bowl', 'bullring', 'flag', 'pot', 'footbridge', 'shower', 'bag', 'bulletin board', 'confessional booth', 'trunk', 'forest', 'elevator door', 'laptop', 'instrument panel', 'bucket', 'tapestry', 'platform', 'jacket', 'gate', 'monitor', 'telephone booth', 'spotlight', 'ring', 'control panel', 'blackboard', 'air conditioner', 'chest', 'clock', 'sand dune', 'pipe', 'vault', 'table football', 'cannon', 'swimming pool', 'fluorescent', 'statue', 'loudspeaker', 'exhibitor', 'ladder', 'carport', 'dam', 'pulpit', 'skylight', 'water tower', 'grill', 'display board', 'pane', 'rubbish', 'ice rink', 'fruit', 'patio', 'vending machine', 'telephone', 'net', 'backpack', 'jar', 'track', 'magazine', 'shutter', 'roof', 'banner', 'landfill', 'post', 'altarpiece', 'hat', 'arch', 'table game', 'bag', 'document', 'dome', 'pier', 'shanties', 'forecourt', 'crane', 'dog', 'piano', 'drawing', 'cabin', 'ad', 'amphitheater', 'monument', 'henhouse', 'cockpit', 'heater', 'windmill', 'pool', 'elevator', 'decoration', 'labyrinth', 'text', 'printer', 'mezzanine', 'mattress', 'straw', 'stalls', 'patio', 'billboard', 'bus stop', 'trouser', 'console table', 'rack', 'notebook', 'shrine', 'pantry', 'cart', 'steam shovel', 'porch', 'postbox', 'figurine', 'recycling bin', 'folding screen', 'telescope', 'deck chair', 'kennel', 'coffee maker', 'altar', 'fish', 'easel', 'artificial golf green', 'iceberg', 'candlestick', 'shower stall', 'television stand', 'wall socket', 'skeleton', 'grand piano', 'candy', 'grille door', 'pedestal', 'jersey', 'shoe', 'gravestone', 'shanty', 'structure', 'rocking chair', 'bird', 'place mat', 'tomb', 'big top', 'gas pump', 'lockers', 'cage', 'finger', 'bleachers', 'ferris wheel', 'hairdresser chair', 'mat', 'stands', 'aquarium', 'streetcar', 'napkin', 'dummy', 'booklet', 'sand trap', 'shop', 'table cloth', 'service station', 'coffin', 'drawer', 'cages', 'slot machine', 'balcony', 'volleyball court', 'table tennis', 'control table', 'shirt', 'merchandise', 'railway', 'parterre', 'chimney', 'can', 'tanks', 'fabric', 'alga', 'system', 'map', 'greenhouse', 'mug', 'barbecue', 'trailer', 'toilet tissue', 'organ', 'dishrag', 'island', 'keyboard', 'trench', 'basket', 'steering wheel', 'pitcher', 'goal', 'bread', 'beds', 'wood', 'file cabinet', 'newspaper', 'motorboat', 'rope', 'guitar', 'rubble', 'scarf', 'barrels', 'cap', 'leaves', 'control tower', 'dashboard', 'bandstand', 'lectern', 'switch', 'baseboard', 'shower room', 'smoke', 'faucet', 'bulldozer', 'saucepan', 'shops', 'meter', 'crevasse', 'gear', 'candelabrum', 'sofa bed', 'tunnel', 'pallet', 'wire', 'kettle', 'bidet', 'baby buggy', 'music stand', 'pipe', 'cup', 'parking meter', 'ice hockey rink', 'shelter', 'weeds', 'temple', 'patty', 'ski slope', 'panel', 'wallet', 'wheel', 'towel rack', 'roundabout', 'canister', 'rod', 'soap dispenser', 'bell', 'canvas', 'box office', 'teacup', 'trellis', 'workbench', 'valley', 'toaster', 'knife', 'podium', 'ramp', 'tumble dryer', 'fireplug', 'gym shoe', 'lab bench', 'equipment', 'rocky formation', 'plastic', 'calendar', 'caravan', 'check-in-desk', 'ticket counter', 'brush', 'mill', 'covered bridge', 'bowling alley', 'hanger', 'excavator', 'trestle', 'revolving door', 'blast furnace', 'scale', 'projector', 'soap', 'locker', 'tractor', 'stretcher', 'frame', 'grating', 'alembic', 'candle', 'barrier', 'cardboard', 'cave', 'puddle', 'tarp', 'price tag', 'watchtower', 'meters', 'light bulb', 'tracks', 'hair dryer', 'skirt', 'viaduct', 'paper towel', 'coat', 'sheet', 'fire extinguisher', 'water wheel', 'pottery', 'magazine rack', 'teapot', 'microphone', 'support', 'forklift', 'canyon', 'cash register', 'leaf', 'remote control', 'soap dish', 'windshield', 'cat', 'cue', 'vent', 'videos', 'shovel', 'eaves', 'antenna', 'shipyard', 'hen', 'traffic cone', 'washing machines', 'truck crane', 'cds', 'niche', 'scoreboard', 'briefcase', 'boot', 'sweater', 'hay', 'pack', 'bottle rack', 'glacier', 'pergola', 'building materials', 'television camera', 'first floor', 'rifle', 'tennis table', 'stadium', 'safety belt', 'cover', 'dish rack', 'synthesizer', 'pumpkin', 'gutter', 'fruit stand', 'ice floe', 'handle', 'wheelchair', 'mousepad', 'diploma', 'fairground ride', 'radio', 'hotplate', 'junk', 'wheelbarrow', 'stream', 'toll plaza', 'punching bag', 'trough', 'throne', 'chair desk', 'weighbridge', 'extractor fan', 'hanging clothes', 'dish', 'alarm clock', 'ski lift', 'chain', 'garage', 'mechanical shovel', 'wine rack', 'tramway', 'treadmill', 'menu', 'block', 'well', 'witness stand', 'branch', 'duck', 'casserole', 'frying pan', 'desk organizer', 'mast', 'spectacles', 'service elevator', 'dollhouse', 'hammock', 'clothes hanging', 'photocopier', 'notepad', 'golf cart', 'footpath', 'cross', 'baptismal font', 'boiler', 'skip', 'rotisserie', 'tables', 'water mill', 'helmet', 'cover curtain', 'brick', 'table runner', 'ashtray', 'street box', 'stick', 'hangers', 'cells', 'urinal', 'centerpiece', 'portable fridge', 'dvds', 'golf club', 'skirting board', 'water cooler', 'clipboard', 'camera', 'pigeonhole', 'chips', 'food processor', 'post box', 'lid', 'drum', 'blender', 'cave entrance', 'dental chair', 'obelisk', 'canoe', 'mobile', 'monitors', 'pool ball', 'cue rack', 'baggage carts', 'shore', 'fork', 'paper filer', 'bicycle rack', 'coat rack', 'garland', 'sports bag', 'fish tank', 'towel dispenser', 'carriage', 'brochure', 'plaque', 'stringer', 'iron', 'spoon', 'flag pole', 'toilet brush', 'book stand', 'water faucet', 'ticket office', 'broom', 'dvd', 'ice bucket', 'carapace', 'tureen', 'folders', 'chess', 'root', 'sewing machine', 'model', 'pen', 'violin', 'sweatshirt', 'recycling materials', 'mitten', 'chopping board', 'mask', 'log', 'mouse', 'grill', 'hole', 'target', 'trash bag', 'chalk', 'sticks', 'balloon', 'score', 'hair spray', 'roll', 'runner', 'engine', 'inflatable glove', 'games', 'pallets', 'baskets', 'coop', 'dvd player', 'rocking horse', 'buckets', 'bread rolls', 'shawl', 'watering can', 'spotlights', 'post-it', 'bowls', 'security camera', 'runner cloth', 'lock', 'alarm', 'side', 'roulette', 'bone', 'cutlery', 'pool balls', 'wheels', 'spice rack', 'plant pots', 'towel ring', 'bread box', 'video', 'funfair', 'breads', 'tripod', 'ironing board', 'skimmer', 'hollow', 'scratching post', 'tricycle', 'file box', 'mountain pass', 'tombstones', 'cooker', 'card game', 'golf bag', 'towel paper', 'chaise lounge', 'sun', 'toilet paper holder', 'rake', 'key', 'umbrella stand', 'dartboard', 'transformer', 'fireplace utensils', 'sweatshirts', 'cellular telephone', 'tallboy', 'stapler', 'sauna', 'test tube', 'palette', 'shopping carts', 'tools', 'push button', 'star', 'roof rack', 'barbed wire', 'spray', 'ear', 'sponge', 'racket', 'tins', 'eyeglasses', 'file', 'scarfs', 'sugar bowl', 'flip flop', 'headstones', 'laptop bag', 'leash', 'climbing frame', 'suit hanger', 'floor spotlight', 'plate rack', 'sewer', 'hard drive', 'sprinkler', 'tools box', 'necklace', 'bulbs', 'steel industry', 'club', 'jack', 'door bars', 'control panel', 'hairbrush', 'napkin holder', 'office', 'smoke detector', 'utensils', 'apron', 'scissors', 'terminal', 'grinder', 'entry phone', 'newspaper stand', 'pepper shaker', 'onions', 'central processing unit', 'tape', 'bat', 'coaster', 'calculator', 'potatoes', 'luggage rack', 'salt', 'street number', 'viewpoint', 'sword', 'cd', 'rowing machine', 'plug', 'andiron', 'pepper', 'tongs', 'bonfire', 'dog dish', 'belt', 'dumbbells', 'videocassette recorder', 'hook', 'envelopes', 'shower faucet', 'watch', 'padlock', 'swimming pool ladder', 'spanners', 'gravy boat', 'notice board', 'trash bags', 'fire alarm', 'ladle', 'stethoscope', 'rocket', 'funnel', 'bowling pins', 'valve', 'thermometer', 'cups', 'spice jar', 'night light', 'soaps', 'games table', 'slotted spoon', 'reel', 'scourer', 'sleeping robe', 'desk mat', 'dumbbell', 'hammer', 'tie', 'typewriter', 'shaker', 'cheese dish', 'sea star', 'racquet', 'butane gas cylinder', 'paper weight', 'shaving brush', 'sunglasses', 'gear shift', 'towel rail', 'adding machine']
+
+SUN_RGBD_37 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag']
+
+SCAN_37 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag']
+SCAN_40 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 'curtain', 'dresser', 'pillow', 'mirror', 'floor mat', 'clothes', 'ceiling', 'books', 'refridgerator', 'television', 'paper', 'towel', 'shower curtain', 'box', 'whiteboard', 'person', 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'bag', 'otherstructure', 'otherfurniture', 'otherprop']
+SCAN_20 = ["wall", "floor", "cabinet", "bed", "chair", "sofa", "table", "door", "window", "bookshelf", "picture", "counter", "desk", "curtain", "refrigerator", "shower curtain", "toilet", "sink", "bathtub", "otherfurniture"]
+
+CITYSCAPES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
+CITYSCAPES_THING = ["person", "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle"]
+
+BDD_SEM = ["road", "sidewalk", "building", "wall", "fence", "pole", "traffic light", "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle"]
+BDD_PANO = ['dynamic', 'ego vehicle', 'ground', 'static', 'parking', 'rail track', 'road', 'sidewalk', 'bridge', 'building', 'fence', 'garage', 'guard rail', 'tunnel', 'wall', 'banner', 'billboard', 'lane divider', 'parking sign', 'pole', 'polegroup', 'street light', 'traffic cone', 'traffic device', 'traffic light', 'traffic sign', 'traffic sign frame', 'terrain', 'vegetation', 'sky', 'person', 'rider', 'bicycle', 'bus', 'car', 'caravan', 'motorcycle', 'trailer', 'train', 'truck']
+
+IMAGENET_CLASSES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "dark glasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
+
+IMAGENET_FOLDER_NAMES = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141']
+
+IMAGENET_DEFAULT_TEMPLATES = [
+ '{}.',
+ 'a bad photo of a {}.',
+ 'a photo of many {}.',
+ 'a sculpture of a {}.',
+ 'a photo of the hard to see {}.',
+ 'a low resolution photo of the {}.',
+ 'a rendering of a {}.',
+ 'graffiti of a {}.',
+ 'a bad photo of the {}.',
+ 'a cropped photo of the {}.',
+ 'a tattoo of a {}.',
+ 'the embroidered {}.',
+ 'a photo of a hard to see {}.',
+ 'a bright photo of a {}.',
+ 'a photo of a clean {}.',
+ 'a photo of a dirty {}.',
+ 'a dark photo of the {}.',
+ 'a drawing of a {}.',
+ 'a photo of my {}.',
+ 'the plastic {}.',
+ 'a photo of the cool {}.',
+ 'a close-up photo of a {}.',
+ 'a black and white photo of the {}.',
+ 'a painting of the {}.',
+ 'a painting of a {}.',
+ 'a pixelated photo of the {}.',
+ 'a sculpture of the {}.',
+ 'a bright photo of the {}.',
+ 'a cropped photo of a {}.',
+ 'a plastic {}.',
+ 'a photo of the dirty {}.',
+ 'a jpeg corrupted photo of a {}.',
+ 'a blurry photo of the {}.',
+ 'a photo of the {}.',
+ 'a good photo of the {}.',
+ 'a rendering of the {}.',
+ 'a {} in a video game.',
+ 'a photo of one {}.',
+ 'a doodle of a {}.',
+ 'a close-up photo of the {}.',
+ 'a photo of a {}.',
+ 'the origami {}.',
+ 'the {} in a video game.',
+ 'a sketch of a {}.',
+ 'a doodle of the {}.',
+ 'a origami {}.',
+ 'a low resolution photo of a {}.',
+ 'the toy {}.',
+ 'a rendition of the {}.',
+ 'a photo of the clean {}.',
+ 'a photo of a large {}.',
+ 'a rendition of a {}.',
+ 'a photo of a nice {}.',
+ 'a photo of a weird {}.',
+ 'a blurry photo of a {}.',
+ 'a cartoon {}.',
+ 'art of a {}.',
+ 'a sketch of the {}.',
+ 'a embroidered {}.',
+ 'a pixelated photo of a {}.',
+ 'itap of the {}.',
+ 'a jpeg corrupted photo of the {}.',
+ 'a good photo of a {}.',
+ 'a plushie {}.',
+ 'a photo of the nice {}.',
+ 'a photo of the small {}.',
+ 'a photo of the weird {}.',
+ 'the cartoon {}.',
+ 'art of the {}.',
+ 'a drawing of the {}.',
+ 'a photo of the large {}.',
+ 'a black and white photo of a {}.',
+ 'the plushie {}.',
+ 'a dark photo of a {}.',
+ 'itap of a {}.',
+ 'graffiti of the {}.',
+ 'a toy {}.',
+ 'itap of my {}.',
+ 'a photo of a cool {}.',
+ 'a photo of a small {}.',
+ 'a tattoo of the {}.',
+]
+
+IMAGENET_SIMPLE_TEMPLATES = [
+ 'a photo of {}.',
+]
+
+PASCAL_CLASSES = [
+ "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
+ "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
+ "pottedplant", "sheep", "sofa", "train", "tvmonitor"
+]
\ No newline at end of file
diff --git a/utilities/dataset.py b/utilities/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dbb757b416ba09077c4ec8552ee3027df05006a
--- /dev/null
+++ b/utilities/dataset.py
@@ -0,0 +1,41 @@
+
+class Entity(object):
+ def __init__(self, _id, _text, _mask, _interactive, _type, _start_idx, _end_idx, _image=None):
+ self.id = _id
+ self.text = _text
+ self.mask = _mask
+ self.interactive = _interactive
+ self.type = _type
+ self.start_idx = _start_idx
+ self.end_idx = _end_idx
+
+ self.image = _image
+
+def split_by_ordered_substrings(sentence, substrings):
+ results = []
+ substring_indices = []
+
+ start_index = 0
+ for i, substring in enumerate(substrings):
+ # Find the start of the substring in the remaining part of the sentence
+ index = sentence[start_index:].find(substring)
+
+ if index == -1:
+ continue
+
+ # Append any text before the substring to the results, including spaces
+ if index > 0:
+ results.append(sentence[start_index:start_index+index])
+ substring_indices.append(None) # No match in the `substrings` list for this segment
+
+ # Append the substring to the results
+ results.append(substring)
+ substring_indices.append(i) # Append the index from the `substrings` list
+ start_index += index + len(substring)
+
+ # If there's any remaining part of the sentence after all substrings, append it to the results
+ if start_index < len(sentence):
+ results.append(sentence[start_index:])
+ substring_indices.append(None) # No match in the `substrings` list for this segment
+
+ return results, substring_indices
diff --git a/utilities/distributed.py b/utilities/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7d346dc4cb465aa501cda356c99e180ed97f7fe
--- /dev/null
+++ b/utilities/distributed.py
@@ -0,0 +1,112 @@
+import os
+import time
+import torch
+import pickle
+import subprocess
+
+# from mpi4py import MPI
+import torch.distributed as dist
+
+
+def apply_distributed(opt):
+ if opt['rank'] == 0:
+ hostname_cmd = ["hostname -I"]
+ result = subprocess.check_output(hostname_cmd, shell=True)
+ master_address = result.decode('utf-8').split()[0]
+ master_port = opt['PORT']
+ else:
+ master_address = None
+ master_port = None
+
+ master_address = MPI.COMM_WORLD.bcast(master_address, root=0)
+ master_port = MPI.COMM_WORLD.bcast(master_port, root=0)
+
+ if torch.distributed.is_available() and opt['world_size'] > 1:
+ init_method_url = 'tcp://{}:{}'.format(master_address, master_port)
+ backend = 'nccl'
+ world_size = opt['world_size']
+ rank = opt['rank']
+ torch.distributed.init_process_group(backend=backend,
+ init_method=init_method_url,
+ world_size=world_size,
+ rank=rank)
+
+def init_distributed(opt):
+ opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available()
+ if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
+ # application was started without MPI
+ # default to single node with single process
+ opt['env_info'] = 'no MPI'
+ opt['world_size'] = 1
+ opt['local_size'] = 1
+ opt['rank'] = 0
+ opt['local_rank'] = 0
+ opt['master_address'] = '127.0.0.1'
+ opt['master_port'] = '8673'
+ else:
+ # application was started with MPI
+ # get MPI parameters
+ opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
+ opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+
+ # set up device
+ if not opt['CUDA']:
+ assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend'
+ opt['device'] = torch.device("cpu")
+ else:
+ torch.cuda.set_device(opt['local_rank'])
+ opt['device'] = torch.device("cuda", opt['local_rank'])
+
+ apply_distributed(opt)
+ return opt
+
+def is_main_process():
+ rank = 0
+ if 'OMPI_COMM_WORLD_SIZE' in os.environ:
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+
+ return rank == 0
+
+def get_world_size():
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ if world_size == 1:
+ return
+
+ def _send_and_wait(r):
+ if rank == r:
+ tensor = torch.tensor(0, device="cuda")
+ else:
+ tensor = torch.tensor(1, device="cuda")
+ dist.broadcast(tensor, r)
+ while tensor.item() == 1:
+ time.sleep(1)
+
+ _send_and_wait(0)
+ # now sync on the main process
+ _send_and_wait(1)
\ No newline at end of file
diff --git a/utilities/misc.py b/utilities/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9e55caf20915b6bd80e416b449ef23c255043c3
--- /dev/null
+++ b/utilities/misc.py
@@ -0,0 +1,31 @@
+# --------------------------------------------------------
+# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Xueyan Zou (xueyan@cs.wisc.edu)
+# --------------------------------------------------------
+import math
+
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value."""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1, decay=0):
+ self.val = val
+ if decay:
+ alpha = math.exp(-n / decay) # exponential decay over 100 updates
+ self.sum = alpha * self.sum + (1 - alpha) * val * n
+ self.count = alpha * self.count + (1 - alpha) * n
+ else:
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
diff --git a/utilities/model.py b/utilities/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f87e38b390cba21c55662e8435ed2cec25744b15
--- /dev/null
+++ b/utilities/model.py
@@ -0,0 +1,60 @@
+import logging
+import os
+import time
+import pickle
+import torch
+import torch.nn as nn
+
+from utilities.distributed import is_main_process
+
+logger = logging.getLogger(__name__)
+
+
+NORM_MODULES = [
+ torch.nn.BatchNorm1d,
+ torch.nn.BatchNorm2d,
+ torch.nn.BatchNorm3d,
+ torch.nn.SyncBatchNorm,
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
+ torch.nn.GroupNorm,
+ torch.nn.InstanceNorm1d,
+ torch.nn.InstanceNorm2d,
+ torch.nn.InstanceNorm3d,
+ torch.nn.LayerNorm,
+ torch.nn.LocalResponseNorm,
+]
+
+def register_norm_module(cls):
+ NORM_MODULES.append(cls)
+ return cls
+
+def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
+ model_keys = sorted(model_state_dict.keys())
+ ckpt_keys = sorted(ckpt_state_dict.keys())
+ result_dicts = {}
+ matched_log = []
+ unmatched_log = []
+ unloaded_log = []
+ for model_key in model_keys:
+ model_weight = model_state_dict[model_key]
+ if model_key in ckpt_keys:
+ ckpt_weight = ckpt_state_dict[model_key]
+ if model_weight.shape == ckpt_weight.shape:
+ result_dicts[model_key] = ckpt_weight
+ ckpt_keys.pop(ckpt_keys.index(model_key))
+ matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
+ else:
+ unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
+ else:
+ unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
+
+ if is_main_process():
+ for info in matched_log:
+ logger.info(info)
+ for info in unloaded_log:
+ logger.warning(info)
+ for key in ckpt_keys:
+ logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
+ for info in unmatched_log:
+ logger.warning(info)
+ return result_dicts
\ No newline at end of file
diff --git a/utilities/prompt_engineering.py b/utilities/prompt_engineering.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a84ca628eedd32e2c560e99005ef4c260957e73
--- /dev/null
+++ b/utilities/prompt_engineering.py
@@ -0,0 +1,99 @@
+import numpy as np
+
+
+def get_prompt_templates():
+ prompt_templates = [
+ '{}.',
+ 'a photo of a {}.',
+ 'a bad photo of a {}.',
+ 'a photo of many {}.',
+ 'a sculpture of a {}.',
+ 'a photo of the hard to see {}.',
+ 'a low resolution photo of the {}.',
+ 'a rendering of a {}.',
+ 'graffiti of a {}.',
+ 'a bad photo of the {}.',
+ 'a cropped photo of the {}.',
+ 'a tattoo of a {}.',
+ 'the embroidered {}.',
+ 'a photo of a hard to see {}.',
+ 'a bright photo of a {}.',
+ 'a photo of a clean {}.',
+ 'a photo of a dirty {}.',
+ 'a dark photo of the {}.',
+ 'a drawing of a {}.',
+ 'a photo of my {}.',
+ 'the plastic {}.',
+ 'a photo of the cool {}.',
+ 'a close-up photo of a {}.',
+ 'a black and white photo of the {}.',
+ 'a painting of the {}.',
+ 'a painting of a {}.',
+ 'a pixelated photo of the {}.',
+ 'a sculpture of the {}.',
+ 'a bright photo of the {}.',
+ 'a cropped photo of a {}.',
+ 'a plastic {}.',
+ 'a photo of the dirty {}.',
+ 'a jpeg corrupted photo of a {}.',
+ 'a blurry photo of the {}.',
+ 'a photo of the {}.',
+ 'a good photo of the {}.',
+ 'a rendering of the {}.',
+ 'a {} in a video game.',
+ 'a photo of one {}.',
+ 'a doodle of a {}.',
+ 'a close-up photo of the {}.',
+ 'the origami {}.',
+ 'the {} in a video game.',
+ 'a sketch of a {}.',
+ 'a doodle of the {}.',
+ 'a origami {}.',
+ 'a low resolution photo of a {}.',
+ 'the toy {}.',
+ 'a rendition of the {}.',
+ 'a photo of the clean {}.',
+ 'a photo of a large {}.',
+ 'a rendition of a {}.',
+ 'a photo of a nice {}.',
+ 'a photo of a weird {}.',
+ 'a blurry photo of a {}.',
+ 'a cartoon {}.',
+ 'art of a {}.',
+ 'a sketch of the {}.',
+ 'a embroidered {}.',
+ 'a pixelated photo of a {}.',
+ 'itap of the {}.',
+ 'a jpeg corrupted photo of the {}.',
+ 'a good photo of a {}.',
+ 'a plushie {}.',
+ 'a photo of the nice {}.',
+ 'a photo of the small {}.',
+ 'a photo of the weird {}.',
+ 'the cartoon {}.',
+ 'art of the {}.',
+ 'a drawing of the {}.',
+ 'a photo of the large {}.',
+ 'a black and white photo of a {}.',
+ 'the plushie {}.',
+ 'a dark photo of a {}.',
+ 'itap of a {}.',
+ 'graffiti of the {}.',
+ 'a toy {}.',
+ 'itap of my {}.',
+ 'a photo of a cool {}.',
+ 'a photo of a small {}.',
+ 'a tattoo of the {}.',
+ ]
+ return prompt_templates
+
+def prompt_engineering(classnames, topk=1, suffix='.'):
+ prompt_templates = get_prompt_templates()
+ temp_idx = np.random.randint(min(len(prompt_templates), topk))
+
+ if isinstance(classnames, list):
+ classname = random.choice(classnames)
+ else:
+ classname = classnames
+
+ return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' '))
\ No newline at end of file
diff --git a/utilities/visualizer.py b/utilities/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c32a2f9bdccd0d3679a5f23ee4852e73a8ec681
--- /dev/null
+++ b/utilities/visualizer.py
@@ -0,0 +1,1279 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import colorsys
+import logging
+import math
+import numpy as np
+from enum import Enum, unique
+import cv2
+import matplotlib as mpl
+import matplotlib.colors as mplc
+import matplotlib.figure as mplfigure
+import pycocotools.mask as mask_util
+import torch
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+from PIL import Image
+
+from detectron2.data import MetadataCatalog
+from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
+from detectron2.utils.file_io import PathManager
+
+from detectron2.utils.colormap import random_color
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["ColorMode", "VisImage", "Visualizer"]
+
+
+_SMALL_OBJECT_AREA_THRESH = 1000
+_LARGE_MASK_AREA_THRESH = 120000
+_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
+_BLACK = (0, 0, 0)
+_RED = (1.0, 0, 0)
+
+_KEYPOINT_THRESHOLD = 0.05
+
+
+@unique
+class ColorMode(Enum):
+ """
+ Enum of different color modes to use for instance visualizations.
+ """
+
+ IMAGE = 0
+ """
+ Picks a random color for every instance and overlay segmentations with low opacity.
+ """
+ SEGMENTATION = 1
+ """
+ Let instances of the same category have similar colors
+ (from metadata.thing_colors), and overlay them with
+ high opacity. This provides more attention on the quality of segmentation.
+ """
+ IMAGE_BW = 2
+ """
+ Same as IMAGE, but convert all areas without masks to gray-scale.
+ Only available for drawing per-instance mask predictions.
+ """
+
+
+class GenericMask:
+ """
+ Attribute:
+ polygons (list[ndarray]): list[ndarray]: polygons for this mask.
+ Each ndarray has format [x, y, x, y, ...]
+ mask (ndarray): a binary mask
+ """
+
+ def __init__(self, mask_or_polygons, height, width):
+ self._mask = self._polygons = self._has_holes = None
+ self.height = height
+ self.width = width
+
+ m = mask_or_polygons
+ if isinstance(m, dict):
+ # RLEs
+ assert "counts" in m and "size" in m
+ if isinstance(m["counts"], list): # uncompressed RLEs
+ h, w = m["size"]
+ assert h == height and w == width
+ m = mask_util.frPyObjects(m, h, w)
+ self._mask = mask_util.decode(m)[:, :]
+ return
+
+ if isinstance(m, list): # list[ndarray]
+ self._polygons = [np.asarray(x).reshape(-1) for x in m]
+ return
+
+ if isinstance(m, np.ndarray): # assumed to be a binary mask
+ assert m.shape[1] != 2, m.shape
+ assert m.shape == (
+ height,
+ width,
+ ), f"mask shape: {m.shape}, target dims: {height}, {width}"
+ self._mask = m.astype("uint8")
+ return
+
+ raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))
+
+ @property
+ def mask(self):
+ if self._mask is None:
+ self._mask = self.polygons_to_mask(self._polygons)
+ return self._mask
+
+ @property
+ def polygons(self):
+ if self._polygons is None:
+ self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
+ return self._polygons
+
+ @property
+ def has_holes(self):
+ if self._has_holes is None:
+ if self._mask is not None:
+ self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
+ else:
+ self._has_holes = False # if original format is polygon, does not have holes
+ return self._has_holes
+
+ def mask_to_polygons(self, mask):
+ # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
+ # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
+ # Internal contours (holes) are placed in hierarchy-2.
+ # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
+ mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr
+ res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+ hierarchy = res[-1]
+ if hierarchy is None: # empty mask
+ return [], False
+ has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
+ res = res[-2]
+ res = [x.flatten() for x in res]
+ # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
+ # We add 0.5 to turn them into real-value coordinate space. A better solution
+ # would be to first +0.5 and then dilate the returned polygon by 0.5.
+ res = [x + 0.5 for x in res if len(x) >= 6]
+ return res, has_holes
+
+ def polygons_to_mask(self, polygons):
+ rle = mask_util.frPyObjects(polygons, self.height, self.width)
+ rle = mask_util.merge(rle)
+ return mask_util.decode(rle)[:, :]
+
+ def area(self):
+ return self.mask.sum()
+
+ def bbox(self):
+ p = mask_util.frPyObjects(self.polygons, self.height, self.width)
+ p = mask_util.merge(p)
+ bbox = mask_util.toBbox(p)
+ bbox[2] += bbox[0]
+ bbox[3] += bbox[1]
+ return bbox
+
+
+class _PanopticPrediction:
+ """
+ Unify different panoptic annotation/prediction formats
+ """
+
+ def __init__(self, panoptic_seg, segments_info, metadata=None):
+ if segments_info is None:
+ assert metadata is not None
+ # If "segments_info" is None, we assume "panoptic_img" is a
+ # H*W int32 image storing the panoptic_id in the format of
+ # category_id * label_divisor + instance_id. We reserve -1 for
+ # VOID label.
+ label_divisor = metadata.label_divisor
+ segments_info = []
+ for panoptic_label in np.unique(panoptic_seg.numpy()):
+ if panoptic_label == -1:
+ # VOID region.
+ continue
+ pred_class = panoptic_label // label_divisor
+ isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
+ segments_info.append(
+ {
+ "id": int(panoptic_label),
+ "category_id": int(pred_class),
+ "isthing": bool(isthing),
+ }
+ )
+ del metadata
+
+ self._seg = panoptic_seg
+
+ self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
+ segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
+ areas = areas.numpy()
+ sorted_idxs = np.argsort(-areas)
+ self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
+ self._seg_ids = self._seg_ids.tolist()
+ for sid, area in zip(self._seg_ids, self._seg_areas):
+ if sid in self._sinfo:
+ self._sinfo[sid]["area"] = float(area)
+
+ def non_empty_mask(self):
+ """
+ Returns:
+ (H, W) array, a mask for all pixels that have a prediction
+ """
+ empty_ids = []
+ for id in self._seg_ids:
+ if id not in self._sinfo:
+ empty_ids.append(id)
+ if len(empty_ids) == 0:
+ return np.zeros(self._seg.shape, dtype=np.uint8)
+ assert (
+ len(empty_ids) == 1
+ ), ">1 ids corresponds to no labels. This is currently not supported"
+ return (self._seg != empty_ids[0]).numpy().astype(np.bool)
+
+ def semantic_masks(self):
+ for sid in self._seg_ids:
+ sinfo = self._sinfo.get(sid)
+ if sinfo is None or sinfo["isthing"]:
+ # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
+ continue
+ yield (self._seg == sid).numpy().astype(np.bool), sinfo
+
+ def instance_masks(self):
+ for sid in self._seg_ids:
+ sinfo = self._sinfo.get(sid)
+ if sinfo is None or not sinfo["isthing"]:
+ continue
+ mask = (self._seg == sid).numpy().astype(np.bool)
+ if mask.sum() > 0:
+ yield mask, sinfo
+
+
+def _create_text_labels(classes, scores, class_names, is_crowd=None):
+ """
+ Args:
+ classes (list[int] or None):
+ scores (list[float] or None):
+ class_names (list[str] or None):
+ is_crowd (list[bool] or None):
+
+ Returns:
+ list[str] or None
+ """
+ labels = None
+ if classes is not None:
+ if class_names is not None and len(class_names) > 0:
+ labels = [class_names[i] for i in classes]
+ else:
+ labels = [str(i) for i in classes]
+ if scores is not None:
+ if labels is None:
+ labels = ["{:.0f}%".format(s * 100) for s in scores]
+ else:
+ labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
+ if labels is not None and is_crowd is not None:
+ labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
+ return labels
+
+
+class VisImage:
+ def __init__(self, img, scale=1.0):
+ """
+ Args:
+ img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
+ scale (float): scale the input image
+ """
+ self.img = img
+ self.scale = scale
+ self.width, self.height = img.shape[1], img.shape[0]
+ self._setup_figure(img)
+
+ def _setup_figure(self, img):
+ """
+ Args:
+ Same as in :meth:`__init__()`.
+
+ Returns:
+ fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
+ ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
+ """
+ fig = mplfigure.Figure(frameon=False)
+ self.dpi = fig.get_dpi()
+ # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
+ fig.set_size_inches(
+ (self.width * self.scale + 1e-2) / self.dpi,
+ (self.height * self.scale + 1e-2) / self.dpi,
+ )
+ self.canvas = FigureCanvasAgg(fig)
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
+ ax.axis("off")
+ self.fig = fig
+ self.ax = ax
+ self.reset_image(img)
+
+ def reset_image(self, img):
+ """
+ Args:
+ img: same as in __init__
+ """
+ img = img.astype("uint8")
+ self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
+
+ def save(self, filepath):
+ """
+ Args:
+ filepath (str): a string that contains the absolute path, including the file name, where
+ the visualized image will be saved.
+ """
+ self.fig.savefig(filepath)
+
+ def get_image(self):
+ """
+ Returns:
+ ndarray:
+ the visualized image of shape (H, W, 3) (RGB) in uint8 type.
+ The shape is scaled w.r.t the input image using the given `scale` argument.
+ """
+ canvas = self.canvas
+ s, (width, height) = canvas.print_to_buffer()
+ # buf = io.BytesIO() # works for cairo backend
+ # canvas.print_rgba(buf)
+ # width, height = self.width, self.height
+ # s = buf.getvalue()
+
+ buffer = np.frombuffer(s, dtype="uint8")
+
+ img_rgba = buffer.reshape(height, width, 4)
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
+ return rgb.astype("uint8")
+
+
+class Visualizer:
+ """
+ Visualizer that draws data about detection/segmentation on images.
+
+ It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
+ that draw primitive objects to images, as well as high-level wrappers like
+ `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
+ that draw composite data in some pre-defined style.
+
+ Note that the exact visualization style for the high-level wrappers are subject to change.
+ Style such as color, opacity, label contents, visibility of labels, or even the visibility
+ of objects themselves (e.g. when the object is too small) may change according
+ to different heuristics, as long as the results still look visually reasonable.
+
+ To obtain a consistent style, you can implement custom drawing functions with the
+ abovementioned primitive methods instead. If you need more customized visualization
+ styles, you can process the data yourself following their format documented in
+ tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
+ intend to satisfy everyone's preference on drawing styles.
+
+ This visualizer focuses on high rendering quality rather than performance. It is not
+ designed to be used for real-time applications.
+ """
+
+ # TODO implement a fast, rasterized version using OpenCV
+
+ def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
+ """
+ Args:
+ img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
+ the height and width of the image respectively. C is the number of
+ color channels. The image is required to be in RGB format since that
+ is a requirement of the Matplotlib library. The image is also expected
+ to be in the range [0, 255].
+ metadata (Metadata): dataset metadata (e.g. class names and colors)
+ instance_mode (ColorMode): defines one of the pre-defined style for drawing
+ instances on an image.
+ """
+ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
+ if metadata is None:
+ metadata = MetadataCatalog.get("__nonexist__")
+ self.metadata = metadata
+ self.output = VisImage(self.img, scale=scale)
+ self.cpu_device = torch.device("cpu")
+
+ # too small texts are useless, therefore clamp to 9
+ self._default_font_size = max(
+ np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
+ )
+ self._default_font_size = 18
+ self._instance_mode = instance_mode
+ self.keypoint_threshold = _KEYPOINT_THRESHOLD
+
+ def draw_instance_predictions(self, predictions):
+ """
+ Draw instance-level prediction results on an image.
+
+ Args:
+ predictions (Instances): the output of an instance detection/segmentation
+ model. Following fields will be used to draw:
+ "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
+
+ Returns:
+ output (VisImage): image object with visualizations.
+ """
+ boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
+ scores = predictions.scores if predictions.has("scores") else None
+ classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
+ labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
+ keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
+
+ keep = (scores > 0.5).cpu()
+ boxes = boxes[keep]
+ scores = scores[keep]
+ classes = np.array(classes)
+ classes = classes[np.array(keep)]
+ labels = np.array(labels)
+ labels = labels[np.array(keep)]
+
+ if predictions.has("pred_masks"):
+ masks = np.asarray(predictions.pred_masks)
+ masks = masks[np.array(keep)]
+ masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
+ else:
+ masks = None
+
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
+ # if self.metadata.get("thing_colors"):
+ colors = [
+ self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
+ ]
+ alpha = 0.4
+ else:
+ colors = None
+ alpha = 0.4
+
+ if self._instance_mode == ColorMode.IMAGE_BW:
+ self.output.reset_image(
+ self._create_grayscale_image(
+ (predictions.pred_masks.any(dim=0) > 0).numpy()
+ if predictions.has("pred_masks")
+ else None
+ )
+ )
+ alpha = 0.3
+
+ self.overlay_instances(
+ masks=masks,
+ boxes=boxes,
+ labels=labels,
+ keypoints=keypoints,
+ assigned_colors=colors,
+ alpha=alpha,
+ )
+ return self.output
+
+ def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.7):
+ """
+ Draw semantic segmentation predictions/labels.
+
+ Args:
+ sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
+ Each value is the integer label of the pixel.
+ area_threshold (int): segments with less than `area_threshold` are not drawn.
+ alpha (float): the larger it is, the more opaque the segmentations are.
+
+ Returns:
+ output (VisImage): image object with visualizations.
+ """
+ if isinstance(sem_seg, torch.Tensor):
+ sem_seg = sem_seg.numpy()
+ labels, areas = np.unique(sem_seg, return_counts=True)
+ sorted_idxs = np.argsort(-areas).tolist()
+ labels = labels[sorted_idxs]
+ for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
+ try:
+ mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
+ except (AttributeError, IndexError):
+ mask_color = None
+
+ binary_mask = (sem_seg == label).astype(np.uint8)
+ text = self.metadata.stuff_classes[label]
+ self.draw_binary_mask(
+ binary_mask,
+ color=mask_color,
+ edge_color=_OFF_WHITE,
+ text=text,
+ alpha=alpha,
+ area_threshold=area_threshold,
+ )
+ return self.output
+
+ def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):
+ """
+ Draw panoptic prediction annotations or results.
+
+ Args:
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
+ segment.
+ segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
+ If it is a ``list[dict]``, each dict contains keys "id", "category_id".
+ If None, category id of each pixel is computed by
+ ``pixel // metadata.label_divisor``.
+ area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
+
+ Returns:
+ output (VisImage): image object with visualizations.
+ """
+ pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
+
+ if self._instance_mode == ColorMode.IMAGE_BW:
+ self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
+
+ # draw mask for all semantic segments first i.e. "stuff"
+ for mask, sinfo in pred.semantic_masks():
+ category_idx = sinfo["category_id"]
+ try:
+ mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
+ except AttributeError:
+ mask_color = None
+
+ text = self.metadata.stuff_classes[category_idx]
+ self.draw_binary_mask(
+ mask,
+ color=mask_color,
+ edge_color=_OFF_WHITE,
+ text=text,
+ alpha=alpha,
+ area_threshold=area_threshold,
+ )
+
+ # draw mask for all instances second
+ all_instances = list(pred.instance_masks())
+ if len(all_instances) == 0:
+ return self.output
+ masks, sinfo = list(zip(*all_instances))
+ category_ids = [x["category_id"] for x in sinfo]
+
+ try:
+ scores = [x["score"] for x in sinfo]
+ except KeyError:
+ scores = None
+ labels = _create_text_labels(
+ category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo]
+ )
+
+ try:
+ colors = [
+ self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids
+ ]
+ except AttributeError:
+ colors = None
+ self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)
+
+ return self.output
+
+ draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility
+
+ def draw_dataset_dict(self, dic):
+ """
+ Draw annotations/segmentaions in Detectron2 Dataset format.
+
+ Args:
+ dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
+
+ Returns:
+ output (VisImage): image object with visualizations.
+ """
+ annos = dic.get("annotations", None)
+ if annos:
+ if "segmentation" in annos[0]:
+ masks = [x["segmentation"] for x in annos]
+ else:
+ masks = None
+ if "keypoints" in annos[0]:
+ keypts = [x["keypoints"] for x in annos]
+ keypts = np.array(keypts).reshape(len(annos), -1, 3)
+ else:
+ keypts = None
+
+ boxes = [
+ BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
+ if len(x["bbox"]) == 4
+ else x["bbox"]
+ for x in annos
+ ]
+
+ colors = None
+ category_ids = [x["category_id"] for x in annos]
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
+ colors = [
+ self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
+ for c in category_ids
+ ]
+ names = self.metadata.get("thing_classes", None)
+ labels = _create_text_labels(
+ category_ids,
+ scores=None,
+ class_names=names,
+ is_crowd=[x.get("iscrowd", 0) for x in annos],
+ )
+ self.overlay_instances(
+ labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
+ )
+
+ sem_seg = dic.get("sem_seg", None)
+ if sem_seg is None and "sem_seg_file_name" in dic:
+ with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
+ sem_seg = Image.open(f)
+ sem_seg = np.asarray(sem_seg, dtype="uint8")
+ if sem_seg is not None:
+ self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.4)
+
+ pan_seg = dic.get("pan_seg", None)
+ if pan_seg is None and "pan_seg_file_name" in dic:
+ with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
+ pan_seg = Image.open(f)
+ pan_seg = np.asarray(pan_seg)
+ from panopticapi.utils import rgb2id
+
+ pan_seg = rgb2id(pan_seg)
+ if pan_seg is not None:
+ segments_info = dic["segments_info"]
+ pan_seg = torch.tensor(pan_seg)
+ self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.7)
+ return self.output
+
+ def overlay_instances(
+ self,
+ *,
+ boxes=None,
+ labels=None,
+ masks=None,
+ keypoints=None,
+ assigned_colors=None,
+ alpha=0.5,
+ ):
+ """
+ Args:
+ boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
+ or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
+ or a :class:`RotatedBoxes`,
+ or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
+ for the N objects in a single image,
+ labels (list[str]): the text to be displayed for each instance.
+ masks (masks-like object): Supported types are:
+
+ * :class:`detectron2.structures.PolygonMasks`,
+ :class:`detectron2.structures.BitMasks`.
+ * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
+ The first level of the list corresponds to individual instances. The second
+ level to all the polygon that compose the instance, and the third level
+ to the polygon coordinates. The third level should have the format of
+ [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
+ * list[ndarray]: each ndarray is a binary mask of shape (H, W).
+ * list[dict]: each dict is a COCO-style RLE.
+ keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
+ where the N is the number of instances and K is the number of keypoints.
+ The last dimension corresponds to (x, y, visibility or score).
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
+ for full list of formats that the colors are accepted in.
+ Returns:
+ output (VisImage): image object with visualizations.
+ """
+ num_instances = 0
+ if boxes is not None:
+ boxes = self._convert_boxes(boxes)
+ num_instances = len(boxes)
+ if masks is not None:
+ masks = self._convert_masks(masks)
+ if num_instances:
+ assert len(masks) == num_instances
+ else:
+ num_instances = len(masks)
+ if keypoints is not None:
+ if num_instances:
+ assert len(keypoints) == num_instances
+ else:
+ num_instances = len(keypoints)
+ keypoints = self._convert_keypoints(keypoints)
+ if labels is not None:
+ assert len(labels) == num_instances
+ if assigned_colors is None:
+ assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
+ if num_instances == 0:
+ return self.output
+ if boxes is not None and boxes.shape[1] == 5:
+ return self.overlay_rotated_instances(
+ boxes=boxes, labels=labels, assigned_colors=assigned_colors
+ )
+
+ # Display in largest to smallest order to reduce occlusion.
+ areas = None
+ if boxes is not None:
+ areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
+ elif masks is not None:
+ areas = np.asarray([x.area() for x in masks])
+
+ if areas is not None:
+ sorted_idxs = np.argsort(-areas).tolist()
+ # Re-order overlapped instances in descending order.
+ boxes = boxes[sorted_idxs] if boxes is not None else None
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
+ masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
+ assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
+ keypoints = keypoints[sorted_idxs] if keypoints is not None else None
+
+ for i in range(num_instances):
+ color = assigned_colors[i]
+ if boxes is not None:
+ self.draw_box(boxes[i], edge_color=color)
+
+ if masks is not None:
+ for segment in masks[i].polygons:
+ self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)
+
+ if labels is not None:
+ # first get a box
+ if boxes is not None:
+ x0, y0, x1, y1 = boxes[i]
+ text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
+ horiz_align = "left"
+ elif masks is not None:
+ # skip small mask without polygon
+ if len(masks[i].polygons) == 0:
+ continue
+
+ x0, y0, x1, y1 = masks[i].bbox()
+
+ # draw text in the center (defined by median) when box is not drawn
+ # median is less sensitive to outliers.
+ text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
+ horiz_align = "center"
+ else:
+ continue # drawing the box confidence for keypoints isn't very useful.
+ # for small objects, draw text at the side to avoid occlusion
+ instance_area = (y1 - y0) * (x1 - x0)
+ if (
+ instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
+ or y1 - y0 < 40 * self.output.scale
+ ):
+ if y1 >= self.output.height - 5:
+ text_pos = (x1, y0)
+ else:
+ text_pos = (x0, y1)
+
+ height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
+ font_size = (
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
+ * 0.5
+ * self._default_font_size
+ )
+ self.draw_text(
+ labels[i],
+ text_pos,
+ color=lighter_color,
+ horizontal_alignment=horiz_align,
+ font_size=font_size,
+ )
+
+ # draw keypoints
+ if keypoints is not None:
+ for keypoints_per_instance in keypoints:
+ self.draw_and_connect_keypoints(keypoints_per_instance)
+
+ return self.output
+
+ def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
+ """
+ Args:
+ boxes (ndarray): an Nx5 numpy array of
+ (x_center, y_center, width, height, angle_degrees) format
+ for the N objects in a single image.
+ labels (list[str]): the text to be displayed for each instance.
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
+ for full list of formats that the colors are accepted in.
+
+ Returns:
+ output (VisImage): image object with visualizations.
+ """
+ num_instances = len(boxes)
+
+ if assigned_colors is None:
+ assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
+ if num_instances == 0:
+ return self.output
+
+ # Display in largest to smallest order to reduce occlusion.
+ if boxes is not None:
+ areas = boxes[:, 2] * boxes[:, 3]
+
+ sorted_idxs = np.argsort(-areas).tolist()
+ # Re-order overlapped instances in descending order.
+ boxes = boxes[sorted_idxs]
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
+ colors = [assigned_colors[idx] for idx in sorted_idxs]
+
+ for i in range(num_instances):
+ self.draw_rotated_box_with_label(
+ boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
+ )
+
+ return self.output
+
+ def draw_and_connect_keypoints(self, keypoints):
+ """
+ Draws keypoints of an instance and follows the rules for keypoint connections
+ to draw lines between appropriate keypoints. This follows color heuristics for
+ line color.
+
+ Args:
+ keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
+ and the last dimension corresponds to (x, y, probability).
+
+ Returns:
+ output (VisImage): image object with visualizations.
+ """
+ visible = {}
+ keypoint_names = self.metadata.get("keypoint_names")
+ for idx, keypoint in enumerate(keypoints):
+
+ # draw keypoint
+ x, y, prob = keypoint
+ if prob > self.keypoint_threshold:
+ self.draw_circle((x, y), color=_RED)
+ if keypoint_names:
+ keypoint_name = keypoint_names[idx]
+ visible[keypoint_name] = (x, y)
+
+ if self.metadata.get("keypoint_connection_rules"):
+ for kp0, kp1, color in self.metadata.keypoint_connection_rules:
+ if kp0 in visible and kp1 in visible:
+ x0, y0 = visible[kp0]
+ x1, y1 = visible[kp1]
+ color = tuple(x / 255.0 for x in color)
+ self.draw_line([x0, x1], [y0, y1], color=color)
+
+ # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
+ # Note that this strategy is specific to person keypoints.
+ # For other keypoints, it should just do nothing
+ try:
+ ls_x, ls_y = visible["left_shoulder"]
+ rs_x, rs_y = visible["right_shoulder"]
+ mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
+ except KeyError:
+ pass
+ else:
+ # draw line from nose to mid-shoulder
+ nose_x, nose_y = visible.get("nose", (None, None))
+ if nose_x is not None:
+ self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)
+
+ try:
+ # draw line from mid-shoulder to mid-hip
+ lh_x, lh_y = visible["left_hip"]
+ rh_x, rh_y = visible["right_hip"]
+ except KeyError:
+ pass
+ else:
+ mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
+ self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
+ return self.output
+
+ """
+ Primitive drawing functions:
+ """
+
+ def draw_text(
+ self,
+ text,
+ position,
+ *,
+ font_size=None,
+ color="g",
+ horizontal_alignment="center",
+ rotation=0,
+ ):
+ """
+ Args:
+ text (str): class label
+ position (tuple): a tuple of the x and y coordinates to place text on image.
+ font_size (int, optional): font of the text. If not provided, a font size
+ proportional to the image width is calculated and used.
+ color: color of the text. Refer to `matplotlib.colors` for full list
+ of formats that are accepted.
+ horizontal_alignment (str): see `matplotlib.text.Text`
+ rotation: rotation angle in degrees CCW
+
+ Returns:
+ output (VisImage): image object with text drawn.
+ """
+ if not font_size:
+ font_size = self._default_font_size
+
+ # since the text background is dark, we don't want the text to be dark
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
+ color[np.argmax(color)] = max(0.8, np.max(color))
+
+ x, y = position
+ self.output.ax.text(
+ x,
+ y,
+ text,
+ size=font_size * self.output.scale,
+ family="sans-serif",
+ bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
+ verticalalignment="top",
+ horizontalalignment=horizontal_alignment,
+ color=color,
+ zorder=10,
+ rotation=rotation,
+ )
+ return self.output
+
+ def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
+ """
+ Args:
+ box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
+ are the coordinates of the image's top left corner. x1 and y1 are the
+ coordinates of the image's bottom right corner.
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
+ for full list of formats that are accepted.
+ line_style (string): the string to use to create the outline of the boxes.
+
+ Returns:
+ output (VisImage): image object with box drawn.
+ """
+ x0, y0, x1, y1 = box_coord
+ width = x1 - x0
+ height = y1 - y0
+
+ linewidth = max(self._default_font_size / 4, 1)
+
+ self.output.ax.add_patch(
+ mpl.patches.Rectangle(
+ (x0, y0),
+ width,
+ height,
+ fill=False,
+ edgecolor=edge_color,
+ linewidth=linewidth * self.output.scale,
+ alpha=alpha,
+ linestyle=line_style,
+ )
+ )
+ return self.output
+
+ def draw_rotated_box_with_label(
+ self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
+ ):
+ """
+ Draw a rotated box with label on its top-left corner.
+
+ Args:
+ rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
+ where cnt_x and cnt_y are the center coordinates of the box.
+ w and h are the width and height of the box. angle represents how
+ many degrees the box is rotated CCW with regard to the 0-degree box.
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
+ for full list of formats that are accepted.
+ line_style (string): the string to use to create the outline of the boxes.
+ label (string): label for rotated box. It will not be rendered when set to None.
+
+ Returns:
+ output (VisImage): image object with box drawn.
+ """
+ cnt_x, cnt_y, w, h, angle = rotated_box
+ area = w * h
+ # use thinner lines when the box is small
+ linewidth = self._default_font_size / (
+ 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
+ )
+
+ theta = angle * math.pi / 180.0
+ c = math.cos(theta)
+ s = math.sin(theta)
+ rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
+ # x: left->right ; y: top->down
+ rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
+ for k in range(4):
+ j = (k + 1) % 4
+ self.draw_line(
+ [rotated_rect[k][0], rotated_rect[j][0]],
+ [rotated_rect[k][1], rotated_rect[j][1]],
+ color=edge_color,
+ linestyle="--" if k == 1 else line_style,
+ linewidth=linewidth,
+ )
+
+ if label is not None:
+ text_pos = rotated_rect[1] # topleft corner
+
+ height_ratio = h / np.sqrt(self.output.height * self.output.width)
+ label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
+ font_size = (
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
+ )
+ self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)
+
+ return self.output
+
+ def draw_circle(self, circle_coord, color, radius=3):
+ """
+ Args:
+ circle_coord (list(int) or tuple(int)): contains the x and y coordinates
+ of the center of the circle.
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
+ formats that are accepted.
+ radius (int): radius of the circle.
+
+ Returns:
+ output (VisImage): image object with box drawn.
+ """
+ x, y = circle_coord
+ self.output.ax.add_patch(
+ mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
+ )
+ return self.output
+
+ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
+ """
+ Args:
+ x_data (list[int]): a list containing x values of all the points being drawn.
+ Length of list should match the length of y_data.
+ y_data (list[int]): a list containing y values of all the points being drawn.
+ Length of list should match the length of x_data.
+ color: color of the line. Refer to `matplotlib.colors` for a full list of
+ formats that are accepted.
+ linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
+ for a full list of formats that are accepted.
+ linewidth (float or None): width of the line. When it's None,
+ a default value will be computed and used.
+
+ Returns:
+ output (VisImage): image object with line drawn.
+ """
+ if linewidth is None:
+ linewidth = self._default_font_size / 3
+ linewidth = max(linewidth, 1)
+ self.output.ax.add_line(
+ mpl.lines.Line2D(
+ x_data,
+ y_data,
+ linewidth=linewidth * self.output.scale,
+ color=color,
+ linestyle=linestyle,
+ )
+ )
+ return self.output
+
+ def draw_binary_mask(
+ self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.7, area_threshold=10
+ ):
+ """
+ Args:
+ binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
+ W is the image width. Each value in the array is either a 0 or 1 value of uint8
+ type.
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
+ formats that are accepted. If None, will pick a random color.
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
+ full list of formats that are accepted.
+ text (str): if None, will be drawn on the object
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
+ area_threshold (float): a connected component smaller than this area will not be shown.
+
+ Returns:
+ output (VisImage): image object with mask drawn.
+ """
+ if color is None:
+ color = random_color(rgb=True, maximum=1)
+ color = mplc.to_rgb(color)
+
+ has_valid_segment = False
+ binary_mask = binary_mask.astype("uint8") # opencv needs uint8
+ mask = GenericMask(binary_mask, self.output.height, self.output.width)
+ shape2d = (binary_mask.shape[0], binary_mask.shape[1])
+
+ if not mask.has_holes:
+ # draw polygons for regular masks
+ for segment in mask.polygons:
+ area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
+ if area < (area_threshold or 0):
+ continue
+ has_valid_segment = True
+ segment = segment.reshape(-1, 2)
+ self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
+ else:
+ # TODO: Use Path/PathPatch to draw vector graphics:
+ # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
+ rgba[:, :, :3] = color
+ rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
+ has_valid_segment = True
+ self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
+
+ if text is not None and has_valid_segment:
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
+ self._draw_text_in_mask(binary_mask, text, lighter_color)
+ return self.output
+
+ def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
+ """
+ Args:
+ soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
+ formats that are accepted. If None, will pick a random color.
+ text (str): if None, will be drawn on the object
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
+
+ Returns:
+ output (VisImage): image object with mask drawn.
+ """
+ if color is None:
+ color = random_color(rgb=True, maximum=1)
+ color = mplc.to_rgb(color)
+
+ shape2d = (soft_mask.shape[0], soft_mask.shape[1])
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
+ rgba[:, :, :3] = color
+ rgba[:, :, 3] = soft_mask * alpha
+ self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
+
+ if text is not None:
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
+ binary_mask = (soft_mask > 0.5).astype("uint8")
+ self._draw_text_in_mask(binary_mask, text, lighter_color)
+ return self.output
+
+ def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
+ """
+ Args:
+ segment: numpy array of shape Nx2, containing all the points in the polygon.
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
+ formats that are accepted.
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
+ full list of formats that are accepted. If not provided, a darker shade
+ of the polygon color will be used instead.
+ alpha (float): blending efficient. Smaller values lead to more transparent masks.
+
+ Returns:
+ output (VisImage): image object with polygon drawn.
+ """
+ if edge_color is None:
+ # make edge color darker than the polygon color
+ if alpha > 0.8:
+ edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
+ else:
+ edge_color = color
+ edge_color = mplc.to_rgb(edge_color) + (1,)
+
+ polygon = mpl.patches.Polygon(
+ segment,
+ fill=True,
+ facecolor=mplc.to_rgb(color) + (alpha,),
+ edgecolor=edge_color,
+ linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
+ )
+ self.output.ax.add_patch(polygon)
+ return self.output
+
+ """
+ Internal methods:
+ """
+
+ def _jitter(self, color):
+ """
+ Randomly modifies given color to produce a slightly different color than the color given.
+
+ Args:
+ color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
+ picked. The values in the list are in the [0.0, 1.0] range.
+
+ Returns:
+ jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
+ color after being jittered. The values in the list are in the [0.0, 1.0] range.
+ """
+ color = mplc.to_rgb(color)
+ # np.random.seed(0)
+ vec = np.random.rand(3)
+ # better to do it in another color space
+ vec = vec / np.linalg.norm(vec) * 0.5
+ res = np.clip(vec + color, 0, 1)
+ return tuple(res)
+
+ def _create_grayscale_image(self, mask=None):
+ """
+ Create a grayscale version of the original image.
+ The colors in masked area, if given, will be kept.
+ """
+ img_bw = self.img.astype("f4").mean(axis=2)
+ img_bw = np.stack([img_bw] * 3, axis=2)
+ if mask is not None:
+ img_bw[mask] = self.img[mask]
+ return img_bw
+
+ def _change_color_brightness(self, color, brightness_factor):
+ """
+ Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
+ less or more saturation than the original color.
+
+ Args:
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
+ formats that are accepted.
+ brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
+ 0 will correspond to no change, a factor in [-1.0, 0) range will result in
+ a darker color and a factor in (0, 1.0] range will result in a lighter color.
+
+ Returns:
+ modified_color (tuple[double]): a tuple containing the RGB values of the
+ modified color. Each value in the tuple is in the [0.0, 1.0] range.
+ """
+ assert brightness_factor >= -1.0 and brightness_factor <= 1.0
+ color = mplc.to_rgb(color)
+ polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
+ modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
+ modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
+ modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
+ modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
+ return modified_color
+
+ def _convert_boxes(self, boxes):
+ """
+ Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
+ """
+ if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
+ return boxes.tensor.detach().numpy()
+ else:
+ return np.asarray(boxes)
+
+ def _convert_masks(self, masks_or_polygons):
+ """
+ Convert different format of masks or polygons to a tuple of masks and polygons.
+
+ Returns:
+ list[GenericMask]:
+ """
+
+ m = masks_or_polygons
+ if isinstance(m, PolygonMasks):
+ m = m.polygons
+ if isinstance(m, BitMasks):
+ m = m.tensor.numpy()
+ if isinstance(m, torch.Tensor):
+ m = m.numpy()
+ ret = []
+ for x in m:
+ if isinstance(x, GenericMask):
+ ret.append(x)
+ else:
+ ret.append(GenericMask(x, self.output.height, self.output.width))
+ return ret
+
+ def _draw_text_in_mask(self, binary_mask, text, color):
+ """
+ Find proper places to draw text given a binary mask.
+ """
+ # TODO sometimes drawn on wrong objects. the heuristics here can improve.
+ _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
+ if stats[1:, -1].size == 0:
+ return
+ largest_component_id = np.argmax(stats[1:, -1]) + 1
+
+ # draw text on the largest component, as well as other very large components.
+ for cid in range(1, _num_cc):
+ if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
+ # median is more stable than centroid
+ # center = centroids[largest_component_id]
+ center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
+ self.draw_text(text, center, color=color)
+
+ def _convert_keypoints(self, keypoints):
+ if isinstance(keypoints, Keypoints):
+ keypoints = keypoints.tensor
+ keypoints = np.asarray(keypoints)
+ return keypoints
+
+ def get_output(self):
+ """
+ Returns:
+ output (VisImage): the image output containing the visualizations added
+ to the image.
+ """
+ return self.output
\ No newline at end of file