![lambdaofgod's picture](https://cdn-avatars.huggingface.co/v1/production/uploads/1640433294356-617a5afbff3db6021d06977b.jpeg)
feat: Add separate sliders for all and selected repositories in the PapersWithCode tasks tab
15420a6
import gradio as gr | |
import pandas as pd | |
import logging | |
import re | |
from task_visualizations import TaskVisualizations | |
import plotly.graph_objects as go | |
# Configure the root logger at INFO level so the `logging.info` calls in this
# module (data-loading summary, per-request repository dumps) are emitted.
logging.basicConfig(level=logging.INFO)
class AppConfig:
    """Relative paths of the data files the app reads at startup."""

    # JSON Lines file read by `load_repo_df`.
    repo_representations_path: str = "data/repo_representations.jsonl"

    # The three CSV paths below are handed to `TaskVisualizations`;
    # presumably task counts for all repos, for the selected repos, and the
    # PapersWithCode task list -- confirm against `task_visualizations`.
    task_counts_path: str = "data/repos_task_counts.csv"
    selected_task_counts_path: str = "data/selected_repos_task_counts.csv"
    tasks_path: str = "data/paperswithcode_tasks.csv"
def load_repo_df(repo_representations_path):
    """Load repository representations and clean their text for display.

    Args:
        repo_representations_path: Location of a JSON Lines file (anything
            ``pandas.read_json`` accepts). Every record must have a ``text``
            field.

    Returns:
        A DataFrame with all loaded columns, where ``text`` has self-closing
        ``<img ... />`` tags removed and the "│" and "⋮" characters replaced
        by newlines.
    """
    data = pd.read_json(repo_representations_path, lines=True, orient="records")
    cleaned_text = (
        data["text"]
        # Non-greedy match: a greedy `.*` would also delete any text sitting
        # between two <img .../> tags that share a line.
        .str.replace(r"<img.*?/>", "", regex=True)
        # Literal replacements; `regex=False` is explicit because the default
        # of `regex` differs between pandas versions.
        .str.replace("│", "\n", regex=False)
        .str.replace("⋮", "\n", regex=False)
    )
    return data.assign(text=cleaned_text)
def display_representations(repo, representation1, representation2):
    """Return the texts of two representations of one repository.

    Reads the module-level ``repos_df`` DataFrame and uses its ``repo_name``,
    ``representation`` and ``text`` columns.

    Args:
        repo: Repository name to look up.
        representation1: Name of the first representation.
        representation2: Name of the second representation.

    Returns:
        A ``(text1, text2)`` tuple. Each element is the text of the first
        matching row, or ``"No data available"`` when no row matches.
    """
    repo_data = repos_df[repos_df["repo_name"] == repo]
    logging.info(f"repo_data: {repo_data}")

    def _text_for(representation):
        # First stored text for this representation, or the placeholder.
        matches = repo_data[repo_data["representation"] == representation]
        if matches.empty:
            return "No data available"
        return matches["text"].iloc[0]

    return _text_for(representation1), _text_for(representation2)
def setup_repository_representations_tab(repos, representation_types):
    """Build the UI that compares two representations of one repository.

    Creates a repository dropdown, two representation dropdowns and two
    Markdown panels, fills the panels for the default selection, and
    re-renders both panels whenever any dropdown changes. The main script
    calls this inside a ``gr.Blocks`` / ``gr.Tab`` context.

    Args:
        repos: Repository names offered in the dropdown; ``repos[0]`` is the
            initial selection.
        representation_types: Names offered in both representation dropdowns.
    """
    initial_first, initial_second = "readme", "generated_readme"
    # Both text columns share the same layout options; only the id differs.
    panel_options = {"variant": "panel", "scale": 1, "min_width": 300}

    gr.Markdown("Select a repository and two representation types to compare them.")

    with gr.Row():
        repo = gr.Dropdown(choices=repos, label="Repository", value=repos[0])
        representation1 = gr.Dropdown(
            choices=representation_types,
            label="Representation 1",
            value=initial_first,
        )
        representation2 = gr.Dropdown(
            choices=representation_types,
            label="Representation 2",
            value=initial_second,
        )

    with gr.Row():
        with gr.Column(elem_id="column1", **panel_options):
            text1 = gr.Markdown()
        with gr.Column(elem_id="column2", **panel_options):
            text2 = gr.Markdown()

    def update_representations(repo_name, first_type, second_type):
        # Look up both texts and prefix each with a heading naming its type.
        first_text, second_text = display_representations(
            repo_name, first_type, second_type
        )
        first_markdown = f"### Representation 1: {first_type}\n\n{first_text}"
        second_markdown = f"### Representation 2: {second_type}\n\n{second_text}"
        return first_markdown, second_markdown

    # Populate the panels up front so they are not blank before the first
    # dropdown change.
    text1.value, text2.value = update_representations(
        repos[0], initial_first, initial_second
    )

    selectors = [repo, representation1, representation2]
    for selector in selectors:
        selector.change(
            fn=update_representations,
            inputs=[repo, representation1, representation2],
            outputs=[text1, text2],
        )
## main

# Load the data once at startup; `display_representations` reads `repos_df`
# as a module-level global.
repos_df = load_repo_df(AppConfig.repo_representations_path)
repos = list(repos_df["repo_name"].unique())
representation_types = list(repos_df["representation"].unique())
logging.info(f"found {len(repos)} repositories")
logging.info(f"representation types: {representation_types}")

task_visualizations = TaskVisualizations(
    AppConfig.task_counts_path,
    AppConfig.selected_task_counts_path,
    AppConfig.tasks_path,
)

with gr.Blocks() as demo:
    with gr.Tab("Explore Repository Representations"):
        setup_repository_representations_tab(repos, representation_types)
    with gr.Tab("Explore PapersWithCode Tasks"):
        task_counts_description = """
        ## PapersWithCode Tasks Visualization
        PapersWithCode tasks are grouped by area.
        """.strip()
        gr.Markdown(task_counts_description)
        # One threshold slider per plot so the two views can be filtered
        # independently.
        with gr.Row():
            min_task_counts_slider_all = gr.Slider(
                minimum=10,
                maximum=1000,
                value=100,
                step=10,
                label="Minimum Task Count (All Repositories)",
            )
            min_task_counts_slider_selected = gr.Slider(
                minimum=10,
                maximum=1000,
                value=100,
                step=10,
                label="Minimum Task Count (Selected Repositories)",
            )
        update_button = gr.Button("Update Plots")
        # `gr.Row` has no label parameter and its options are keyword-only,
        # so the former positional "Task Counts" string was never displayed;
        # it is dropped here.
        with gr.Row():
            all_repos_tasks_plot = gr.Plot(label="All Repositories")
            selected_repos_tasks_plot = gr.Plot(label="Selected Repositories")
        # The plots are only (re)drawn when the button is clicked; the slider
        # values are passed in the order (all, selected).
        update_button.click(
            fn=task_visualizations.get_tasks_sunbursts,
            inputs=[min_task_counts_slider_all, min_task_counts_slider_selected],
            outputs=[all_repos_tasks_plot, selected_repos_tasks_plot],
        )

demo.launch()