explore-label-concepts / src /sample_interface.py
Xmaster6y's picture
vote interface
405b0bd unverified
raw
history blame
5.47 kB
"""Interface for labeling concepts in images.
"""
from typing import Optional
import gradio as gr
from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME
def get_image(
    step: int,
    split: str,
    index: str,
    filtered_indices: dict,
    profile: gr.OAuthProfile
):
    """Return the display payload for the sample `step` positions from `index`.

    Args:
        step: Offset added to the parsed index (the handlers in this file
            use 1, -1 and 0).
        split: Key into `global_variables.all_metadata`.
        index: Index as typed in the UI textbox; falls back to 0 (with a
            warning) when it cannot be parsed as an integer.
        filtered_indices: Returned unchanged.
        profile: Logged-in user's profile; `profile.username` selects which
            stored votes are pre-checked.

    Returns:
        A 7-tuple matching the listeners' output list: image path, concepts
        this user voted for, the ``"<split>:<idx>"`` token, the sample's
        class, the sample's aggregated concepts, the index as a string, and
        `filtered_indices`. When the split holds no samples, the five
        sample-related entries are `None` and `index` is returned as given.
    """
    username = profile.username
    try:
        int_index = int(index)
    except (TypeError, ValueError):
        # Only conversion failures are expected here; a bare `except` would
        # also swallow KeyboardInterrupt/SystemExit.
        gr.Warning("Error parsing index using 0")
        int_index = 0

    split_metadata = global_variables.all_metadata[split]
    num_samples = len(split_metadata)
    if num_samples == 0:
        # Clamping below would yield index -1 and raise IndexError on an
        # empty split; return the same "nothing to show" placeholder tuple
        # that submit_label uses when no image is selected.
        gr.Warning("No image available for this split.")
        return None, None, None, None, None, index, filtered_indices

    # Clamp to the valid range, telling the user when an edge was hit.
    sample_idx = int_index + step
    if sample_idx < 0:
        gr.Warning("No previous image.")
        sample_idx = 0
    if sample_idx >= num_samples:
        gr.Warning("No next image.")
        sample_idx = num_samples - 1

    sample = split_metadata[sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"

    # A sample without a "votes" entry, or without an entry for this user,
    # simply has nothing pre-checked.
    try:
        username_votes = sample["votes"][username]
    except KeyError:
        voted_concepts = []
    else:
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]

    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        sample["concepts"],
        str(sample_idx),
        filtered_indices,
    )
def make_get_image(step: int):
    """Build an event handler that calls `get_image` with a fixed `step`.

    The returned function accepts `split`, `index`, `filtered_indices` and
    `profile`, and forwards them to `get_image` together with the captured
    `step`, returning whatever `get_image` returns.
    """
    # The listeners in this file pass only split, index and filtered_indices
    # as inputs; `profile` is presumably injected by Gradio based on the
    # `gr.OAuthProfile` type hint — confirm against the Gradio OAuth docs
    # before changing this signature or its annotations.
    def f(
        split: str,
        index: str,
        filtered_indices: dict,
        profile: gr.OAuthProfile
    ):
        return get_image(step, split, index, filtered_indices, profile)
    return f
# Handlers with a fixed step: +1 moves forward, -1 moves back, and 0 reloads
# the sample at the index currently typed in the textbox.
get_next_image = make_get_image(1)
get_prev_image = make_get_image(-1)
get_current_image = make_get_image(0)
def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    index,
    filtered_indices,
    profile: gr.OAuthProfile
):
    """Store the user's concept votes for the current image and move on.

    The votes of `profile.username` for the sample named by `current_image`
    (a ``"<split>:<idx>"`` token) are overwritten with `voted_concepts`, the
    per-concept majority over all stored votes is recomputed and written to
    the sample's "concepts" entry, and the split's metadata is saved.

    Returns:
        The result of `get_next_image(split, index, filtered_indices,
        profile)`, or a placeholder tuple (five `None` values followed by
        `index` and `filtered_indices`) when no image is selected.
    """
    username = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        return None, None, None, None, None, index, filtered_indices

    current_split, position = current_image.split(":")
    position = int(position)

    # Fetch the record only after get_metadata has run for this split.
    global_variables.get_metadata(current_split)
    record = global_variables.all_metadata[current_split][position]
    if "votes" not in record:
        record["votes"] = {}
    ballots = record["votes"]
    ballots[username] = {concept: concept in voted_concepts for concept in CONCEPTS}

    # Each ballot that mentions a concept contributes +1 (True) or -1 (False);
    # a positive total means True, a negative total False, and a tie None.
    verdicts = {}
    for concept in CONCEPTS:
        score = sum(
            2 * ballot[concept] - 1
            for ballot in ballots.values()
            if concept in ballot
        )
        if score == 0:
            verdicts[concept] = None
        else:
            verdicts[concept] = score > 0
    record["concepts"] = verdicts

    global_variables.save_metadata(current_split)
    gr.Info("Submit success")
    return get_next_image(split, index, filtered_indices, profile)
# Labeling UI: controls and image info in one column, the image in the other,
# followed by the per-session state and the event wiring.
# NOTE(review): the original indentation of this block was lost; the nesting
# of the layout containers below is reconstructed — confirm against the
# upstream file before relying on the exact layout.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            # Which sample to show: dataset split plus a free-text index.
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                split = gr.Radio(
                    label="Split",
                    choices=["train", "validation", "test"],
                    value="train",
                )
                index = gr.Textbox(
                    value="0",
                    label="Index",
                    max_lines=1,
                )
            # The current user's votes and the navigation/submit controls.
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                with gr.Row():
                    prev_button = gr.Button(
                        value="Prev",
                    )
                    next_button = gr.Button(
                        value="Next",
                    )
                gr.LoginButton()
                submit_button = gr.Button(
                    value="Submit",
                )
            # Read-only details of the displayed sample.
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )
    # Token of the displayed sample; get_image fills it with "<split>:<idx>",
    # and submit_label treats None as "no image selected".
    current_image = gr.State(None)
    # Every index of every split. The loop variable `split` is local to the
    # comprehension and does not rebind the `split` Radio defined above.
    filtered_indices = gr.State({
        split: list(range(len(global_variables.all_metadata[split])))
        for split in global_variables.all_metadata
    })
    # Order must match the 7-tuple returned by get_image / submit_label.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        index,
        filtered_indices,
    ]
    # Positional inputs of the handlers built by make_get_image; their
    # `profile` parameter is not listed here.
    common_input = [split, index, filtered_indices]
    prev_button.click(
        get_prev_image,
        inputs=common_input,
        outputs=common_output
    )
    next_button.click(
        get_next_image,
        inputs=common_input,
        outputs=common_output
    )
    # submit_label additionally receives the checked concepts and the token
    # of the image the votes belong to.
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, index, filtered_indices],
        outputs=common_output
    )
    # Submitting the index textbox shows the sample at the typed index.
    index.submit(
        get_current_image,
        inputs=common_input,
        outputs=common_output,
    )
    # Show the sample at the initial index when the interface loads.
    interface.load(
        get_current_image,
        inputs=common_input,
        outputs=common_output,
    )