"""Interface for labeling concepts in images.

Users browse the samples of a dataset split, tick the concepts they see in
the displayed image and submit their votes. Each user's votes are stored
under the sample's ``"votes"`` key and aggregated into ``"concepts"``.
"""

from typing import Optional

import gradio as gr

from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME


def get_image(
    step: int,
    split: str,
    index: str,
    filtered_indices: dict,
    profile: gr.OAuthProfile,
):
    """Load the sample located ``step`` positions away from ``index``.

    Args:
        step: Offset added to the parsed index (-1 previous, 0 current,
            1 next).
        split: Name of the dataset split to read from.
        index: Index as typed in the UI; falls back to 0 when it cannot be
            parsed as an integer.
        filtered_indices: Passed through unchanged to the output state.
        profile: OAuth profile of the current user; its ``username``
            selects which stored votes are pre-checked.

    Returns:
        A 7-tuple matching the interface outputs: image path, this user's
        voted concepts, the ``"split:index"`` identifier, the sample's
        class, its aggregated concepts, the clamped index as a string and
        ``filtered_indices``.
    """
    username = profile.username
    try:
        int_index = int(index)
    except (TypeError, ValueError):
        # Only parsing failures fall back to 0; a bare ``except`` would also
        # swallow KeyboardInterrupt/SystemExit and unrelated bugs.
        gr.Warning("Error parsing index using 0")
        int_index = 0

    samples = global_variables.all_metadata[split]
    sample_idx = int_index + step
    # Clamp into range so a negative index does not silently wrap around.
    if sample_idx < 0:
        gr.Warning("No previous image.")
        sample_idx = 0
    if sample_idx >= len(samples):
        gr.Warning("No next image.")
        sample_idx = len(samples) - 1
    # NOTE(review): an empty split still raises IndexError on the next line.
    sample = samples[sample_idx]

    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"
    try:
        username_votes = sample["votes"][username]
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
    except KeyError:
        # The sample has no votes yet, or none from this user.
        voted_concepts = []
    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        sample["concepts"],
        str(sample_idx),
        filtered_indices,
    )


def make_get_image(step):
    """Build a Gradio callback that loads the sample ``step`` away from ``index``.

    The wrapper keeps an explicit signature with the ``gr.OAuthProfile``
    annotation, which Gradio uses to inject the logged-in user's profile.
    """

    def f(
        split: str, index: str, filtered_indices: dict, profile: gr.OAuthProfile
    ):
        return get_image(step, split, index, filtered_indices, profile)

    return f


get_next_image = make_get_image(1)
get_prev_image = make_get_image(-1)
get_current_image = make_get_image(0)


def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    index,
    filtered_indices,
    profile: gr.OAuthProfile,
):
    """Record the user's concept votes for the displayed image, then advance.

    Args:
        voted_concepts: Concepts the user ticked; every other entry of
            ``CONCEPTS`` is stored as a negative vote.
        current_image: ``"split:index"`` identifier of the displayed sample,
            or ``None`` when nothing is displayed.
        split, index, filtered_indices: Current navigation state, used to
            load the next image once the vote is saved.
        profile: OAuth profile of the voting user.

    Returns:
        The same 7-tuple as ``get_image`` for the next sample, or ``None``
        placeholders (with ``index`` and ``filtered_indices`` unchanged)
        when no image is selected.
    """
    username = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        return None, None, None, None, None, index, filtered_indices

    current_split, idx = current_image.split(":")
    idx = int(idx)
    # NOTE(review): the return value is ignored; presumably this refreshes
    # ``all_metadata[current_split]`` before it is modified — confirm.
    global_variables.get_metadata(current_split)

    # Bind the sample only after the refresh above so we mutate current data.
    sample = global_variables.all_metadata[current_split][idx]
    votes = sample.setdefault("votes", {})
    votes[username] = {c: c in voted_concepts for c in CONCEPTS}

    # Majority vote per concept: each user's vote counts +1 when True and -1
    # when False; users without an entry for the concept are skipped, and a
    # zero sum (tie or no votes) leaves the concept undecided (None).
    concepts = {}
    for c in CONCEPTS:
        vote_sum = sum(2 * vote[c] - 1 for vote in votes.values() if c in vote)
        concepts[c] = vote_sum > 0 if vote_sum != 0 else None
    sample["concepts"] = concepts
    global_variables.save_metadata(current_split)

    gr.Info("Submit success")
    return get_next_image(split, index, filtered_indices, profile)


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                split = gr.Radio(
                    label="Split",
                    choices=["train", "validation", "test"],
                    value="train",
                )
                index = gr.Textbox(
                    value="0",
                    label="Index",
                    max_lines=1,
                )
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
            with gr.Row():
                prev_button = gr.Button(
                    value="Prev",
                )
                next_button = gr.Button(
                    value="Next",
                )
            gr.LoginButton()
            submit_button = gr.Button(
                value="Submit",
            )
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )

    current_image = gr.State(None)
    # One list of sample positions per split (``split_name`` avoids visually
    # shadowing the ``split`` Radio component defined above).
    filtered_indices = gr.State({
        split_name: list(range(len(global_variables.all_metadata[split_name])))
        for split_name in global_variables.all_metadata
    })

    # Every navigation callback returns values for these components, in order.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        index,
        filtered_indices,
    ]
    common_input = [split, index, filtered_indices]

    prev_button.click(
        get_prev_image,
        inputs=common_input,
        outputs=common_output,
    )
    next_button.click(
        get_next_image,
        inputs=common_input,
        outputs=common_output,
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, index, filtered_indices],
        outputs=common_output,
    )
    index.submit(
        get_current_image,
        inputs=common_input,
        outputs=common_output,
    )
    interface.load(
        get_current_image,
        inputs=common_input,
        outputs=common_output,
    )