# Scraped page header (not code): author lambdaofgod, commit 89d0cf9
# ("better looking gallery"), file size 6.01 kB.
import gradio as gr
import pandas as pd
import logging
import re
from task_visualizations import TaskVisualizations
import plotly.graph_objects as go
from functools import partial
from text_visualization import WordCloudExtractor
# Enable INFO level on the root logger so the module's logging.info calls
# (repository counts, representation types, per-selection repo data) are shown.
logging.basicConfig(level=logging.INFO)
class AppConfig:
    """Static locations of the data files the app reads at startup.

    The paths are relative, so they resolve against the working directory the
    app is launched from.
    """

    # JSON-lines file read by load_repo_df; its records carry at least the
    # repo_name, representation and text fields used by the app.
    repo_representations_path: str = "data/repo_representations.jsonl"
    # The three CSVs below are handed, in this order, to TaskVisualizations.
    task_counts_path: str = "data/repos_task_counts.csv"
    selected_task_counts_path: str = "data/selected_repos_task_counts.csv"
    tasks_path: str = "data/paperswithcode_tasks.csv"
def load_repo_df(repo_representations_path):
    """Load repository representations and clean their text for Markdown display.

    Args:
        repo_representations_path: path to a JSON-lines file with one record
            per line; records must include a ``text`` field.

    Returns:
        A DataFrame with the file's columns, where ``text`` has every HTML
        ``<img ...>`` tag removed and each ``│`` / ``⋮`` character replaced by
        a newline.
    """
    data = pd.read_json(repo_representations_path, lines=True, orient="records")
    cleaned_text = (
        data["text"]
        # Remove each <img ...> tag on its own. ``[^>]*`` stops at the tag's
        # closing ``>``, unlike a greedy ``.*``, which would also delete the
        # text between two images on one line. It also spans line breaks
        # inside the tag and matches tags that are not self-closing (``<img>``).
        .str.replace(r"<img\b[^>]*>", "", regex=True, case=False)
        # Plain character substitutions: no regex semantics needed.
        .str.replace("│", "\n", regex=False)
        .str.replace("⋮", "\n", regex=False)
    )
    return data.assign(text=cleaned_text)
def display_representations(repo, representation1, representation2, df=None):
    """Return the texts of two representations of one repository.

    Args:
        repo: value matched against the ``repo_name`` column.
        representation1: value matched against the ``representation`` column
            to pick the first text.
        representation2: same, for the second text.
        df: DataFrame with ``repo_name``, ``representation`` and ``text``
            columns. Defaults to the module-level ``repos_df``.

    Returns:
        ``(text1, text2)``. Each is the ``text`` of the first matching row, or
        ``"No data available"`` when no row matches or that text is missing.
    """
    if df is None:
        df = repos_df
    repo_data = df[df["repo_name"] == repo]
    logging.info("repo_data: %s", repo_data)

    def first_text(representation):
        # Build the representation mask once (the original evaluated it twice).
        texts = repo_data.loc[repo_data["representation"] == representation, "text"]
        # A missing text would otherwise be rendered as "nan" in the UI.
        if texts.empty or pd.isna(texts.iloc[0]):
            return "No data available"
        return texts.iloc[0]

    return first_text(representation1), first_text(representation2)
def get_representation_wordclouds(representations, repos_df):
    """Render one word cloud per representation type.

    Args:
        representations: representation names to render, in order.
        repos_df: DataFrame with ``representation`` and ``text`` columns.

    Returns:
        A dict mapping each representation name to the value returned by
        ``WordCloudExtractor().extract_wordcloud_image`` for the list of that
        representation's texts.
    """
    # A fresh extractor is created for every representation, as before.
    return {
        representation: WordCloudExtractor().extract_wordcloud_image(
            list(repos_df.loc[repos_df["representation"] == representation, "text"])
        )
        for representation in representations
    }
def _default_choice(choices, preferred, fallback_index=0):
    """Pick the initial value for a dropdown over *choices*.

    Returns *preferred* when it is one of *choices*. Otherwise returns the choice
    at *fallback_index* (clamped to the last choice), or ``None`` when *choices*
    is empty.
    """
    if preferred in choices:
        return preferred
    if not choices:
        return None
    return choices[min(fallback_index, len(choices) - 1)]


def setup_repository_representations_tab(repos, representation_types):
    """Build the "Explore Repository Representations" tab contents.

    Intended to be called within an open ``gr.Blocks`` layout context. It creates:

    - a gallery of per-representation word clouds;
    - a repository dropdown and two representation-type dropdowns;
    - two side-by-side Markdown panels with the selected texts, refreshed
      whenever any dropdown changes.

    Args:
        repos: repository names offered in the repository dropdown.
        representation_types: representation names offered in both
            representation dropdowns and rendered as word clouds.
    """
    # Word clouds are computed from the module-level ``repos_df``.
    wordcloud_dict = get_representation_wordclouds(representation_types, repos_df)
    gr.Markdown("## Wordclouds")
    gr.Gallery(
        [
            (wordcloud, representation_type)
            for representation_type, wordcloud in wordcloud_dict.items()
        ],
        columns=[3],
        rows=[4],
        height=300,
    )
    gr.Markdown("Select a repository and two representation types to compare them.")

    # Resolve the defaults once so the dropdowns and the initial panel contents
    # agree. The preferred types are used only when they are actually available,
    # and an empty repository list no longer raises IndexError.
    default_repo = repos[0] if repos else None
    default_representation1 = _default_choice(representation_types, "readme", 0)
    default_representation2 = _default_choice(
        representation_types, "generated_readme", 1
    )

    def update_representations(repo, representation1, representation2):
        """Return Markdown for both panels: a heading plus the representation text."""
        text1_content, text2_content = display_representations(
            repo, representation1, representation2
        )
        return (
            f"### Representation 1: {representation1}\n\n{text1_content}",
            f"### Representation 2: {representation2}\n\n{text2_content}",
        )

    initial_text1, initial_text2 = update_representations(
        default_repo, default_representation1, default_representation2
    )

    with gr.Row():
        repo = gr.Dropdown(choices=repos, label="Repository", value=default_repo)
        representation1 = gr.Dropdown(
            choices=representation_types,
            label="Representation 1",
            value=default_representation1,
        )
        representation2 = gr.Dropdown(
            choices=representation_types,
            label="Representation 2",
            value=default_representation2,
        )
    with gr.Row():
        with gr.Column(
            elem_id="column1",
            variant="panel",
            scale=1,
            min_width=300,
        ):
            # Pass the initial content to the constructor (the supported API)
            # rather than assigning ``.value`` after the component is built.
            text1 = gr.Markdown(value=initial_text1)
        with gr.Column(
            elem_id="column2",
            variant="panel",
            scale=1,
            min_width=300,
        ):
            text2 = gr.Markdown(value=initial_text2)

    for component in [repo, representation1, representation2]:
        component.change(
            fn=update_representations,
            inputs=[repo, representation1, representation2],
            outputs=[text1, text2],
        )
# ---- main: load the data shared by both tabs --------------------------------
repos_df = load_repo_df(AppConfig.repo_representations_path)
repos = repos_df["repo_name"].unique().tolist()
representation_types = repos_df["representation"].unique().tolist()
logging.info("found %d repositories", len(repos))
logging.info("representation types: %s", representation_types)

# Task-count plots for the PapersWithCode tab, built from the three task CSVs.
task_visualizations = TaskVisualizations(
    AppConfig.task_counts_path,
    AppConfig.selected_task_counts_path,
    AppConfig.tasks_path,
)
with gr.Blocks() as demo:
    with gr.Tab("Explore Repository Representations"):
        setup_repository_representations_tab(repos, representation_types)
    with gr.Tab("Explore PapersWithCode Tasks"):
        # Built from adjacent literals so the Markdown text does not depend on
        # the indentation of this block.
        task_counts_description = (
            "## PapersWithCode Tasks Visualization\n"
            "PapersWithCode tasks are grouped by area.\n"
            "In addition to showing task distribution across the original dataset"
            " we display task counts in the repositories we selected."
        )
        gr.Markdown(task_counts_description)
        with gr.Row():
            min_task_counts_slider_all = gr.Slider(
                minimum=50,
                maximum=1000,
                value=150,
                step=50,
                label="Minimum Task Count (All Repositories)",
            )
            update_button = gr.Button("Update Plots")
            min_task_counts_slider_selected = gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                label="Minimum Task Count (Selected Repositories)",
            )
            update_selected_button = gr.Button("Update Plots")
        # Task-count plots. gr.Row has no title parameter (its options are
        # layout settings), so no "Task Counts" string is passed to it.
        with gr.Row():
            all_repos_tasks_plot = gr.Plot(label="All Repositories")
            selected_repos_tasks_plot = gr.Plot(label="Selected Repositories")
        plot_all_tasks = partial(task_visualizations.get_tasks_sunburst, which_df="all")
        plot_selected_tasks = partial(
            task_visualizations.get_tasks_sunburst, which_df="selected"
        )
        update_button.click(
            fn=plot_all_tasks,
            inputs=[min_task_counts_slider_all],
            outputs=[all_repos_tasks_plot],
        )
        update_selected_button.click(
            fn=plot_selected_tasks,
            inputs=[min_task_counts_slider_selected],
            outputs=[selected_repos_tasks_plot],
        )
    # Draw both plots once when the page loads, using the sliders' initial
    # values, so they are not blank until a button is pressed.
    demo.load(
        fn=plot_all_tasks,
        inputs=[min_task_counts_slider_all],
        outputs=[all_repos_tasks_plot],
    )
    demo.load(
        fn=plot_selected_tasks,
        inputs=[min_task_counts_slider_selected],
        outputs=[selected_repos_tasks_plot],
    )
# Start the server only when this file is executed as a script, so importing
# the module (for example to reuse ``demo``) does not also launch it.
if __name__ == "__main__":
    demo.launch()