lambdaofgod's picture
feat: Add separate sliders for all and selected repositories in the PapersWithCode tasks tab
15420a6
raw
history blame
4.94 kB
import gradio as gr
import pandas as pd
import logging
import re
from task_visualizations import TaskVisualizations
import plotly.graph_objects as go
# Configure the root logger at INFO level so the logging.info(...) calls made
# later in this file are actually emitted.
logging.basicConfig(level=logging.INFO)
class AppConfig:
    """Locations of the data files the app reads at startup.

    The values are plain class attributes, so they are accessed directly on
    the class (``AppConfig.tasks_path``) without creating an instance.
    """

    # JSON Lines file passed to ``load_repo_df``.
    repo_representations_path: str = "data/repo_representations.jsonl"

    # The three CSV paths below are handed, in this order, to the
    # ``TaskVisualizations`` constructor.
    task_counts_path: str = "data/repos_task_counts.csv"
    selected_task_counts_path: str = "data/selected_repos_task_counts.csv"
    tasks_path: str = "data/paperswithcode_tasks.csv"
def load_repo_df(repo_representations_path):
    """Load repository representations from a JSON Lines file and clean the text.

    Args:
        repo_representations_path: Path to a JSON Lines file (one JSON record
            per line). Every record is expected to carry a ``text`` field.

    Returns:
        A new DataFrame with the columns found in the file, in which ``text``
        has had self-closing ``<img ... />`` tags removed and every "│" and
        "⋮" character replaced by a newline.

    Raises:
        KeyError: If the loaded data has no ``text`` column.
    """
    data = pd.read_json(repo_representations_path, lines=True, orient="records")
    cleaned_text = (
        data["text"]
        # ``[^>]*`` keeps each match inside one tag. A greedy ``.*`` would run
        # to the last "/>" on the line and also delete any text located
        # between two images.
        .str.replace(r"<img[^>]*/>", "", regex=True)
        # Literal character swaps; regex=False makes that explicit instead of
        # relying on the pandas default, which differs between versions.
        .str.replace("│", "\n", regex=False)
        .str.replace("⋮", "\n", regex=False)
    )
    return data.assign(text=cleaned_text)
def display_representations(repo, representation1, representation2):
    """Look up the text of two representations of a single repository.

    Reads the module-level ``repos_df`` DataFrame and uses its ``repo_name``,
    ``representation`` and ``text`` columns.

    Args:
        repo: Value to match against the ``repo_name`` column.
        representation1: First value to match against ``representation``.
        representation2: Second value to match against ``representation``.

    Returns:
        A ``(text1, text2)`` tuple. Each element is the ``text`` of the first
        matching row, or ``"No data available"`` when nothing matches.
    """
    repo_data = repos_df[repos_df["repo_name"] == repo]
    logging.info(f"repo_data: {repo_data}")

    def first_text_for(kind):
        # Only the first matching row is used, even if several exist.
        matches = repo_data[repo_data["representation"] == kind]
        if matches.empty:
            return "No data available"
        return matches["text"].iloc[0]

    return first_text_for(representation1), first_text_for(representation2)
def setup_repository_representations_tab(repos, representation_types):
    """Build the UI that compares two representations of one repository.

    Creates three dropdowns (repository, first and second representation) and
    two side-by-side Markdown panels, fills the panels for the default
    selection, and re-renders both panels whenever any dropdown changes.

    Args:
        repos: Repository names offered in the "Repository" dropdown. The
            first entry is used as the initial selection, so the sequence
            must not be empty.
        representation_types: Choices offered in both representation
            dropdowns.
    """
    # Initial selections of the two representation dropdowns; also used for
    # the first render of the panels further down.
    first_default = "readme"
    second_default = "generated_readme"

    gr.Markdown("Select a repository and two representation types to compare them.")

    with gr.Row():
        repo = gr.Dropdown(choices=repos, label="Repository", value=repos[0])
        representation1 = gr.Dropdown(
            choices=representation_types,
            label="Representation 1",
            value=first_default,
        )
        representation2 = gr.Dropdown(
            choices=representation_types,
            label="Representation 2",
            value=second_default,
        )

    def make_panel(elem_id):
        # One panel-styled column that holds a single, initially empty
        # Markdown view.
        with gr.Column(elem_id=elem_id, variant="panel", scale=1, min_width=300):
            return gr.Markdown()

    with gr.Row():
        text1 = make_panel("column1")
        text2 = make_panel("column2")

    def update_representations(repo, representation1, representation2):
        # Fetch both texts and put a heading naming the representation above
        # each of them.
        first_text, second_text = display_representations(
            repo, representation1, representation2
        )
        first_panel = f"### Representation 1: {representation1}\n\n{first_text}"
        second_panel = f"### Representation 2: {representation2}\n\n{second_text}"
        return first_panel, second_panel

    # Initial call to populate the panels with the default selection.
    text1.value, text2.value = update_representations(
        repos[0], first_default, second_default
    )

    # Any of the three dropdowns changing re-renders both panels.
    for selector in (repo, representation1, representation2):
        selector.change(
            fn=update_representations,
            inputs=[repo, representation1, representation2],
            outputs=[text1, text2],
        )
## main
# Load the repository representations once at import time. ``repos_df`` is a
# module-level name that display_representations reads on every lookup.
repos_df = load_repo_df(AppConfig.repo_representations_path)
repos = list(repos_df["repo_name"].unique())
representation_types = list(repos_df["representation"].unique())
logging.info(f"found {len(repos)} repositories")
logging.info(f"representation types: {representation_types}")

task_visualizations = TaskVisualizations(
    AppConfig.task_counts_path,
    AppConfig.selected_task_counts_path,
    AppConfig.tasks_path,
)

with gr.Blocks() as demo:
    with gr.Tab("Explore Repository Representations"):
        setup_repository_representations_tab(repos, representation_types)

    with gr.Tab("Explore PapersWithCode Tasks"):
        task_counts_description = """
## PapersWithCode Tasks Visualization
PapersWithCode tasks are grouped by area.
""".strip()
        gr.Markdown(task_counts_description)

        # One threshold slider per plot, so the two sunbursts can be filtered
        # independently.
        with gr.Row():
            min_task_counts_slider_all = gr.Slider(
                minimum=10,
                maximum=1000,
                value=100,
                step=10,
                label="Minimum Task Count (All Repositories)",
            )
            min_task_counts_slider_selected = gr.Slider(
                minimum=10,
                maximum=1000,
                value=100,
                step=10,
                label="Minimum Task Count (Selected Repositories)",
            )
        update_button = gr.Button("Update Plots")

        # gr.Row only takes keyword layout options and has no label/title
        # parameter, so it is created without positional arguments. A
        # positional string here is ignored or rejected depending on the
        # Gradio version.
        with gr.Row():
            all_repos_tasks_plot = gr.Plot(label="All Repositories")
            selected_repos_tasks_plot = gr.Plot(label="Selected Repositories")

        # The plots are only (re)drawn on click: both slider values go in,
        # and the two returned figures fill the two plots in the same order.
        update_button.click(
            fn=task_visualizations.get_tasks_sunbursts,
            inputs=[min_task_counts_slider_all, min_task_counts_slider_selected],
            outputs=[all_repos_tasks_plot, selected_repos_tasks_plot],
        )

demo.launch()