patrickvonplaten's picture
up
88fd33d
raw
history blame
7.42 kB
from datasets import load_dataset
from collections import Counter
from random import sample, shuffle
import datasets
from pandas import DataFrame
from huggingface_hub import list_datasets
import os
import gradio as gr
import secrets
parti_prompt_results = []
ORG = "diffusers-parti-prompts"
# Display label -> "train" split of the matching image dataset under ORG.
# The label (dict key) is what write_result records as the user's choice, so
# each key must load its own model's dataset.
# Hub repo ids always use "/", so they are built with an f-string instead of
# os.path.join (which would insert backslashes on Windows).
# NOTE(review): "if-v1-0" and "karlo" used to point at each other's datasets;
# results pushed before this fix carry the two labels swapped.
SUBMISSIONS = {
    "sd-v1-5": load_dataset(f"{ORG}/sd-v1-5")["train"],
    "sd-v2-1": load_dataset(f"{ORG}/sd-v2.1")["train"],
    "if-v1-0": load_dataset(f"{ORG}/if-v-1.0")["train"],
    "karlo": load_dataset(f"{ORG}/karlo-v1")["train"],
    # "Kadinsky":
}
NUM_QUESTIONS = 10
MODEL_KEYS = "-".join(SUBMISSIONS.keys())
# Hub namespace that collects every finished game's answers.
SUBMISSION_ORG = f"results-{MODEL_KEYS}"
submission_names = list(SUBMISSIONS.keys())
# The first submission's length is taken as the number of prompts/images.
num_images = len(SUBMISSIONS[submission_names[0]])
def generate_random_hash(length=8):
    """
    Generates a random hexadecimal string of the specified length.

    Args:
        length (int): Number of hex characters to return. Must be even,
            because each random byte is rendered as two hex digits.

    Returns:
        str: A random lowercase hex string of exactly ``length`` characters.

    Raises:
        ValueError: If ``length`` is odd.
    """
    if length % 2 != 0:
        raise ValueError("Length should be an even number.")
    # token_hex(n) draws n random bytes and returns 2 * n hex characters.
    return secrets.token_hex(length // 2)
def _count_submitted_ids():
    """Count how often each image id appears across all result datasets under SUBMISSION_ORG."""
    submitted = []
    for repo in list_datasets(author=SUBMISSION_ORG):
        submitted += load_dataset(repo.id)["train"]["id"]
    return Counter(submitted)


def start():
    """Draw a fresh set of NUM_QUESTIONS questions, favouring the least-answered images.

    Returns:
        pandas.DataFrame: one row per question (index 0..NUM_QUESTIONS-1) with
        columns "prompt", "result" (empty until answered), "id" (image index),
        "Challenge", "Category", "Note", and "choice_{n}" for each gallery
        slot n, holding an index into ``submission_names`` (randomly shuffled
        per question).
    """
    submitted_counts = _count_submitted_ids()
    # Rank only ids that exist in the image datasets; a stray id in a pushed
    # result would otherwise be sampled and then fail when indexed below.
    counts = {image_id: submitted_counts[image_id] for image_id in range(num_images)}
    # sorted() is stable, so ids with equal counts keep ascending order.
    least_answered = sorted(counts, key=counts.__getitem__)
    # Sample from a pool larger than the quiz, presumably so consecutive
    # games don't all show the same images.
    id_candidates = least_answered[: 10 * NUM_QUESTIONS]
    image_ids = sample(id_candidates, k=NUM_QUESTIONS)

    reference = SUBMISSIONS[submission_names[0]]
    questions = {}
    for question_idx, image_id in enumerate(image_ids):
        order = list(range(len(SUBMISSIONS)))
        shuffle(order)
        row = reference[image_id]
        questions[question_idx] = {
            "prompt": row["Prompt"],
            "result": "",
            "id": image_id,
            "Challenge": row["Challenge"],
            "Category": row["Category"],
            "Note": row["Note"],
        }
        for slot, model_idx in enumerate(order):
            questions[question_idx][f"choice_{slot}"] = model_idx
    return DataFrame.from_dict(questions, orient="index")
def process(dataframe, row_number=0):
    """Load the gallery images and prompt caption for question ``row_number``.

    Args:
        dataframe: the question table produced by ``start``.
        row_number: zero-based question index; NUM_QUESTIONS means the game is over.

    Returns:
        tuple: (list of each shown submission's "images" entry in gallery-slot
        order, caption string), or ``(None, "")`` once all questions are done.
    """
    if row_number == NUM_QUESTIONS:
        return None, ""

    question = dataframe.iloc[row_number]
    image_idx = int(question["id"])
    # Gallery slot n shows the submission whose index is stored in choice_n.
    shown_models = [
        submission_names[question[f"choice_{slot}"]]
        for slot in range(len(SUBMISSIONS))
    ]
    gallery_images = [SUBMISSIONS[name][image_idx]["images"] for name in shown_models]
    prompt_text = SUBMISSIONS[shown_models[0]][image_idx]["Prompt"]
    return gallery_images, f"Prompt {row_number + 1}/{NUM_QUESTIONS}: '{prompt_text}'"
def write_result(user_choice, row_number, dataframe, prompt):
    """Record which model the user picked for the current question and advance.

    Args:
        user_choice: gallery index of the clicked image (the ``selected_image``
            value). It starts at -1 and may be None before any click.
        row_number: zero-based index of the current question.
        dataframe: the question table produced by ``start``; updated in place.
        prompt: unused; kept because the UI passes it as an input.

    Returns:
        tuple: (next row number, dataframe). Both are unchanged when the game is
        already over or no valid image has been selected.
    """
    if row_number == NUM_QUESTIONS:
        return row_number, dataframe
    # Pressing "Select" before clicking an image sends -1 (or None). Looking up
    # "choice_-1" would raise KeyError, so ignore the click and stay on this question.
    if user_choice is None or not 0 <= int(user_choice) < len(SUBMISSIONS):
        return row_number, dataframe
    user_choice = int(user_choice)
    model_idx = int(dataframe.iloc[row_number][f"choice_{user_choice}"])
    dataframe.loc[row_number, "result"] = submission_names[model_idx]
    return row_number + 1, dataframe
def get_index(evt: gr.SelectData) -> int:
    """Return the index carried by a gallery select event.

    Wired to ``gallery.select`` so the clicked position ends up in ``selected_image``.
    """
    clicked_position = evt.index
    return clicked_position
def change_view(row_number, dataframe):
    """Switch between the question gallery and the final result view.

    When ``row_number`` reaches NUM_QUESTIONS, the game is finished. In that case
    this function:
    - computes the most frequently chosen model;
    - uploads the ("id", "result") columns as a new dataset under
      SUBMISSION_ORG, with a random name, using the HF_TOKEN environment variable;
    - shows the result view.

    Otherwise it shows the gallery and clears the result text.

    Returns:
        dict: Gradio component -> new value/update for ``result``,
        ``result_view`` and ``gallery_view``.
    """
    if row_number != NUM_QUESTIONS:
        return {
            result: "",
            result_view: gr.update(visible=False),
            gallery_view: gr.update(visible=True),
        }

    favorite_model = dataframe["result"].value_counts().idxmax()
    dataset = datasets.Dataset.from_pandas(dataframe)
    # Upload only the image id and the chosen model.
    extra_columns = [c for c in dataset.column_names if c not in ("id", "result")]
    dataset = dataset.remove_columns(extra_columns)
    submission_id = generate_random_hash()
    # Hub repo ids always use "/"; os.path.join would emit "\" on Windows.
    repo_id = f"{SUBMISSION_ORG}/{submission_id}"
    dataset.push_to_hub(repo_id, token=os.getenv("HF_TOKEN"))
    return {
        result: f"You are of type: {favorite_model}!",
        result_view: gr.update(visible=True),
        gallery_view: gr.update(visible=False),
    }
if True:
    TITLE = "Open-Source Parti Prompts"
    DESCRIPTION = "An interactive 'Which Generative AI' game to evaluate open-source generative AI models"
    GALLERY_COLUMN_NUM = len(SUBMISSIONS)

    def _reset_after_last(current_row):
        """Map the past-the-end row number back to 0; leave any other row unchanged."""
        return 0 if current_row == NUM_QUESTIONS else current_row

    with gr.Blocks(css="style.css") as demo:
        gr.Markdown(TITLE)
        gr.Markdown(DESCRIPTION)
        start_button = gr.Button("Start").style(full_width=False)

        # Column layout of the question table returned by start().
        headers = ["prompt", "result", "id", "Challenge", "Category", "Note"] + [
            f"choice_{i}" for i in range(len(SUBMISSIONS))
        ]
        datatype = ["str", "str", "number", "str", "str", "str"] + len(SUBMISSIONS) * [
            "number"
        ]

        # Hidden state: index of the question currently shown.
        with gr.Column(visible=False):
            row_number = gr.Number(
                label="Current row selection index",
                value=0,
                precision=0,
                interactive=False,
            )

        # Create Data Frame
        with gr.Column(visible=False) as result_view:
            result = gr.Markdown("")
            dataframe = gr.Dataframe(
                headers=headers,
                datatype=datatype,
                row_count=NUM_QUESTIONS,
                col_count=(6 + len(SUBMISSIONS), "fixed"),
                interactive=False,
            )
            gr.Markdown("Click on start to play again!")

        with gr.Column(visible=True) as gallery_view:
            gr.Markdown("Pick the photo that best corresponds to the prompt.")
            prompt = gr.Markdown(f"Prompt 1/{NUM_QUESTIONS}: ")
            gallery = gr.Gallery(
                label="All images", show_label=False, elem_id="gallery"
            ).style(columns=GALLERY_COLUMN_NUM, object_fit="contain")
            next_button = gr.Button("Select").style(full_width=False)

        # Hidden state: gallery index last clicked (-1 until the first click).
        with gr.Column(visible=False):
            selected_image = gr.Number(label="Selected index", value=-1, precision=0)

        # Start: draw questions, rewind a finished game, toggle views, show question 1.
        start_button.click(
            fn=start,
            inputs=[],
            outputs=dataframe,
        ).then(
            fn=_reset_after_last,
            inputs=[row_number],
            outputs=[row_number],
        ).then(
            fn=change_view,
            inputs=[row_number, dataframe],
            outputs=[result_view, gallery_view, result],
        ).then(
            fn=process,
            inputs=[dataframe],
            outputs=[gallery, prompt],
        )

        gallery.select(
            fn=get_index,
            outputs=selected_image,
            queue=False,
        )

        # Select: record the answer, show the next question (or the result
        # view after the last one), then rewind the counter for a new game.
        next_button.click(
            fn=write_result,
            inputs=[selected_image, row_number, dataframe, prompt],
            outputs=[row_number, dataframe],
        ).then(
            fn=process,
            inputs=[dataframe, row_number],
            outputs=[gallery, prompt],
        ).then(
            fn=change_view,
            inputs=[row_number, dataframe],
            outputs=[result_view, gallery_view, result],
        ).then(
            fn=_reset_after_last,
            inputs=[row_number],
            outputs=[row_number],
        )

    demo.launch()