|
from datasets import load_dataset |
|
import gradio as gr |
|
import os |
|
import random |
|
|
|
# Only the 'test' split of the dataset is loaded and browsed by this app.
wmtis = load_dataset("nlphuji/wmtis-identify")['test']

print("Loaded WMTIS identify, first example:")

print(wmtis[0])

# NOTE: despite the name, this is the largest valid example index
# (len(wmtis) - 1), not the number of examples; it is used as the slider maximum.
dataset_size = len(wmtis) - 1



# Image column names. Not referenced by the functions below, which build the
# column keys with f-strings instead.
NATURAL_IMAGE = 'natural_image'

NORMAL_IMAGE = 'normal_image'

STRANGE_IMAGE = 'strange_image'
|
|
|
def func(index):
    """Build the values for every output component for one dataset example.

    Used as the slider's change handler.

    Args:
        index: Position of the example in ``wmtis``. The slider may deliver
            this as a float (e.g. ``3.0``), so it is converted to ``int``
            before indexing.

    Returns:
        A flat list of 15 values: for each of the natural, normal and strange
        variants (in that order), the five values produced by
        ``add_outputs_for_key`` (image, caption, rating, comments, hash).
    """
    # Dataset rows cannot be indexed with a float, so normalize the slider value.
    example = wmtis[int(index)]

    outputs = []

    # Every variant's image is resized to the normal image's size so that all
    # three are displayed at the same dimensions.
    target_size = example['normal_image'].size

    for item in ('natural', 'normal', 'strange'):
        add_outputs_for_key(example, outputs, target_size, item)

    return outputs
|
|
|
|
|
def add_outputs_for_key(example, outputs, target_size, item):
    """Append the five display values of one image variant to ``outputs``.

    Values are appended in this order: resized image, image caption, rating,
    comments (normalized by ``get_empty_comment_if_needed``) and image hash.

    Args:
        example: A single row of the dataset.
        outputs: List that is extended in place.
        target_size: Size handed to the image's ``resize`` method.
        item: Variant prefix used to build the column names,
            e.g. ``'natural'``, ``'normal'`` or ``'strange'``.
    """
    outputs.append(example[f'{item}_image'].resize(target_size))
    outputs.append(example[f'{item}_image_caption'])
    outputs.append(example[f'rating_{item}'])
    outputs.append(get_empty_comment_if_needed(example[f'comments_{item}']))
    outputs.append(example[f'{item}_hash'])
|
|
|
|
|
# The Gradio Blocks app: its layout is built in the `with demo:` section and it
# is started with demo.launch() at the end of the file.
demo = gr.Blocks()
|
|
|
def get_empty_comment_if_needed(item):
    """Replace a missing comment with a ``'-'`` placeholder for display.

    Args:
        item: Raw comment value taken from a dataset row.

    Returns:
        ``'-'`` when the comment is missing, i.e. the literal string ``'nan'``,
        ``None`` or a float NaN; otherwise ``item`` unchanged.
    """
    # NOTE(review): the 'nan' string presumably comes from a NaN that was
    # stringified upstream (e.g. a pandas round-trip); confirm against the data.
    if item is None or item == 'nan':
        return '-'
    # A float NaN is the only value that compares unequal to itself; without
    # this check it would be displayed as "nan".
    if isinstance(item, float) and item != item:
        return '-'
    return item
|
|
|
|
|
def add_column_by_key(item, target_size, row_index=None):
    """Create a Gradio column showing one image variant of a dataset example.

    Args:
        item: Variant prefix used to build the column names,
            e.g. ``'natural'``, ``'normal'`` or ``'strange'``.
        target_size: Size the image is resized to before display.
        row_index: Index of the example to show. When omitted, the
            module-level ``index`` (set while building the layout) is used,
            which matches the original behavior.

    Returns:
        The five created components ``[image, caption, rating, comments,
        image id]``, in the same order in which ``add_outputs_for_key``
        produces their values.
    """
    if row_index is None:
        row_index = index

    # Fetch the row once instead of re-indexing for every field: each
    # ``wmtis[...]`` access builds a fresh row (with `datasets` this presumably
    # decodes the image columns again).
    example = wmtis[row_index]

    with gr.Column():

        image = gr.Image(
            value=example[f"{item}_image"].resize(target_size),
            label=f'{item.capitalize()} Image',
        )

        caption = gr.Textbox(value=example[f"{item}_image_caption"], label='BLIP2 Predicted Caption')

        rating = gr.Textbox(value=example[f"rating_{item}"], label='Rating')

        comments = gr.Textbox(
            value=get_empty_comment_if_needed(example[f"comments_{item}"]),
            label='Comments',
        )

        image_id = gr.Textbox(value=example[f"{item}_hash"], label='Image ID')

        item_outputs = [image, caption, rating, comments, image_id]

    return item_outputs
|
|
|
|
|
with demo:

    gr.Markdown("# Main Challenge: Weirdness, not Synthesis")



    with gr.Column():

        # step=1 keeps the slider on whole numbers so its value can be used
        # directly as a dataset index (without an explicit step, Gradio 3
        # derives one from the range, which can be fractional).
        slider = gr.Slider(minimum=0, maximum=dataset_size, step=1, value=0)

        with gr.Row():

            index = int(slider.value)

            # dataset_size is the last valid index (len(wmtis) - 1), so it is
            # itself in range; only values outside [0, dataset_size] are reset.
            if not 0 <= index <= dataset_size:

                index = 0

            # All three images are shown at the normal image's size.
            target_size = wmtis[index]['normal_image'].size

            natural_outputs = add_column_by_key('natural', target_size)

            normal_outputs = add_column_by_key('normal', target_size)

            strange_outputs = add_column_by_key('strange', target_size)



    # The order of these output components must match the order of the
    # values returned by func: natural, normal, strange, five fields each.
    slider.change(func, inputs=[slider], outputs=natural_outputs + normal_outputs + strange_outputs)



demo.launch()
|
|