Spaces:
Running
Running
import os | |
import gradio as gr | |
import pandas as pd | |
import plotly | |
import plotly.graph_objects as go | |
from assets.color import color_dict | |
from assets.path import SEASON | |
def read_testset(season):
    """Load the test dataset that belongs to ``season``.

    ``season`` is looked up in ``SEASON`` to find the sub-directory of
    ``results`` containing ``test_dataset.json``. That file is parsed with
    ``pandas.read_json`` and the resulting DataFrame is returned.
    """
    season_dir = SEASON[season]
    json_path = os.path.join("results", season_dir, "test_dataset.json")
    return pd.read_json(json_path)
def build_keypoint_plot(dataset):
    """Build a sunburst chart of the keypoint hierarchy in ``dataset``.

    ``dataset['categories']`` is iterated row by row. Each row is an iterable
    of category paths, and each path is a sequence of keypoint labels
    (presumably strings) ordered from the root (centre of the sunburst)
    outwards. Falsy keypoints are displayed as "未分类" ("uncategorised").

    Every node on a path is counted once per path that passes through it.
    A parent's value is therefore never smaller than the sum of its
    children's values, which ``branchvalues="total"`` requires.

    Nodes are identified by their full path rather than by their label.
    The same label (e.g. "未分类" under two different top-level categories)
    can then appear under several parents without the nodes being merged or
    their parent being silently overwritten.

    Each wedge is coloured via ``color_dict`` using the top-level entry of its
    path. Returns a ``plotly.graph_objects.Figure``.
    """
    uncategorised = "未分类"
    # path tuple -> node record. Dicts preserve insertion order, and a parent
    # path is always inserted before any of its children.
    nodes = {}
    for categories in dataset['categories']:
        for category in categories:
            path = ()
            parent_id = ""  # "" marks a root node for go.Sunburst
            for keypoint in category:
                label = keypoint if keypoint else uncategorised
                path = path + (label,)
                node = nodes.get(path)
                if node is None:
                    node = {
                        # Sequential ids are unique and cannot collide the way
                        # a separator-joined path string could.
                        "id": str(len(nodes)),
                        "parent": parent_id,
                        "label": label,
                        "value": 0,
                        # path[0] is the first keypoint (or the placeholder),
                        # i.e. the top-level category that picks the colour.
                        "color": path[0],
                    }
                    nodes[path] = node
                node["value"] += 1
                parent_id = node["id"]

    records = list(nodes.values())
    fig = go.Figure(go.Sunburst(
        ids=[n["id"] for n in records],
        labels=[n["label"] for n in records],
        parents=[n["parent"] for n in records],
        values=[n["value"] for n in records],
        branchvalues="total",
        insidetextorientation='radial',
        marker={"colors": [color_dict[n["color"]] for n in records]},
    ))
    return fig
def build_difficulty_plot(dataset):
    """Bar chart of how many rows of ``dataset`` fall into each difficulty.

    The x values are the sorted distinct values of ``dataset['difficulty']``.
    The y values are the number of rows having each of those values. Bars are
    coloured by their count on the "Viridis" scale, and the y-axis uses a
    logarithmic scale. Returns a ``plotly.graph_objects.Figure``.
    """
    column = dataset['difficulty']
    levels = sorted(column.unique())
    counts = []
    for level in levels:
        counts.append(len(dataset[column == level]))
    bar = go.Bar(
        x=levels,
        y=counts,
        marker={
            "color": counts,
            "colorscale": "Viridis",
            "colorbar": {"title": "Total"},
        },
    )
    figure = go.Figure([bar])
    figure.update_layout(yaxis={"type": "log"})
    return figure
def build_plot(season):
    """Build the keypoint and difficulty figures for ``season``'s test set.

    The dataset is loaded through ``read_testset``, so the results-path
    layout is defined in one place only.

    Returns a ``(keypoint_figure, difficulty_figure)`` tuple.
    """
    dataset = read_testset(season)
    return build_keypoint_plot(dataset), build_difficulty_plot(dataset)
def _read_plotly_json(path):
    """Load a Plotly figure serialised as JSON at ``path`` (UTF-8)."""
    # A context manager closes the handle; the previous inline
    # open(...).read() left it open until garbage collection.
    with open(path, encoding="utf-8") as handle:
        return plotly.io.from_json(handle.read())


def create_data(top_components):
    """Create the "All data" and "Test Data" tabs of data-distribution plots.

    "All data" shows precomputed figures loaded from JSON files under
    ``assets/``. "Test Data" shows figures built from the "latest" season's
    test dataset via ``build_plot``. Must be called inside a Gradio layout
    context, since it creates ``gr.Tab``/``gr.Row``/``gr.Plot`` components.

    Args:
        top_components: Not used by this function; kept so callers'
            signatures stay unchanged.

    Returns:
        dict mapping "all_keypoint", "all_difficulty", "test_keypoint" and
        "test_difficulty" to the corresponding ``gr.Plot`` components.
    """
    with gr.Tab("All data"):
        with gr.Row():
            all_keypoint_plot = gr.Plot(
                _read_plotly_json("assets/keypoint_distribution.json"),
                label="Keypoint Distribution")
            all_difficulty_plot = gr.Plot(
                _read_plotly_json("assets/difficulty_distribution.json"),
                label="Difficulty Distribution")
    with gr.Tab("Test Data"):
        with gr.Row():
            k_fig, d_fig = build_plot("latest")
            test_keypoint_plot = gr.Plot(k_fig, label="Keypoint Distribution")
            test_difficulty_plot = gr.Plot(d_fig, label="Difficulty Distribution")
    return {"all_keypoint": all_keypoint_plot, "all_difficulty": all_difficulty_plot,
            "test_keypoint": test_keypoint_plot, "test_difficulty": test_difficulty_plot}