import os

import gradio as gr
import pandas as pd
import plotly
import plotly.graph_objects as go

from assets.color import color_dict
from assets.content import KEYPOINT_DISTRIBUTION, DIFFICULTY_DISTRIBUTION
from assets.path import SEASON


def read_testset(season):
    """Load the test dataset for *season* as a DataFrame.

    The file is read from ``results/<SEASON[season]>/test_dataset.json``.
    """
    return pd.read_json(os.path.join("results", SEASON[season], "test_dataset.json"))


def _to_hashable(value):
    """Recursively turn lists/tuples into tuples so pandas can count them.

    JSON arrays are parsed into Python lists, which are unhashable and make
    ``Series.value_counts`` raise ``TypeError``. Tuples iterate exactly like
    lists, so downstream iteration is unaffected. Any other value (strings,
    numbers, NaN) is returned unchanged.
    """
    if isinstance(value, (list, tuple)):
        return tuple(_to_hashable(item) for item in value)
    return value


def build_keypoint_plot(dataset):
    """Build a sunburst chart of the keypoint hierarchy found in *dataset*.

    Each value of ``dataset['categories']`` is consumed as a sequence of
    category paths, and each path as a sequence of keypoint names from the
    outermost level inwards. A node's value is the sum, over every path that
    passes through it, of the number of rows carrying that ``categories``
    value; this satisfies ``branchvalues="total"``. Empty keypoint names are
    shown as "未分类" ("unclassified"). A node's colour comes from
    ``color_dict`` keyed by the first element of the path on which the node
    was first encountered.

    Nodes are identified by their full path, not just their name. The same
    keypoint name under two different parents therefore yields two distinct
    sectors instead of one node with a single, wrong parent.

    Returns:
        plotly.graph_objects.Figure: the sunburst figure.
    """
    node_index = {}  # full path (tuple of names) -> position in the lists below
    ids, labels, parents, values, colors = [], [], [], [], []

    # Normalise list cells to tuples so value_counts can hash them.
    counts = dataset["categories"].map(_to_hashable).value_counts()
    for categories, count in counts.items():
        for category in categories:
            path = ()
            parent_id = ""  # Plotly treats "" as the root of the hierarchy.
            for keypoint in category:
                if not keypoint:
                    keypoint = "未分类"
                path += (keypoint,)
                node = node_index.get(path)
                if node is None:
                    node = node_index[path] = len(ids)
                    # repr() of the path tuple is unique per path and works for
                    # non-string keypoints too.
                    ids.append(repr(path))
                    labels.append(keypoint)
                    parents.append(parent_id)
                    values.append(0)
                    colors.append(color_dict[category[0]])
                values[node] += count
                parent_id = ids[node]

    fig = go.Figure(go.Sunburst(
        ids=ids,
        labels=labels,
        parents=parents,  # with `ids` set, these reference ids, not labels
        values=values,
        branchvalues="total",
        insidetextorientation="radial",
        marker={"colors": colors},
    ))
    return fig


def build_difficulty_plot(dataset):
    """Build a bar chart of how many rows fall at each difficulty level.

    Bars are ordered by difficulty value, coloured by their count on a
    Viridis scale, and drawn against a logarithmic y axis.

    Returns:
        plotly.graph_objects.Figure: the bar chart figure.
    """
    counts = dataset["difficulty"].value_counts().sort_index()
    xs = counts.index.tolist()
    ys = counts.tolist()
    fig = go.Figure([go.Bar(
        x=xs,
        y=ys,
        marker={"color": ys, "colorscale": "Viridis", "colorbar": {"title": "Total"}},
    )])
    fig.update_layout(yaxis=dict(type='log'))
    return fig


def build_plot(season):
    """Return ``(keypoint_figure, difficulty_figure)`` for *season*'s test set."""
    dataset = read_testset(season)
    return build_keypoint_plot(dataset), build_difficulty_plot(dataset)


def create_data(top_components):
    """Create the "All data" and "Test Data" tabs inside the current Gradio layout.

    The "All data" tab shows figures deserialised from the precomputed
    ``KEYPOINT_DISTRIBUTION`` / ``DIFFICULTY_DISTRIBUTION`` JSON. The
    "Test Data" tab shows figures built from the "latest" season's test set.

    Args:
        top_components: not used by this function; kept for interface
            compatibility with callers.

    Returns:
        dict: maps "all_keypoint", "all_difficulty", "test_keypoint" and
        "test_difficulty" to the corresponding ``gr.Plot`` components.
    """
    k_fig, d_fig = build_plot("latest")
    with gr.Tab("All data"):
        with gr.Row():
            all_keypoint_plot = gr.Plot(
                plotly.io.from_json(KEYPOINT_DISTRIBUTION),
                label="Keypoint Distribution")
            all_difficulty_plot = gr.Plot(
                plotly.io.from_json(DIFFICULTY_DISTRIBUTION),
                label="Difficulty Distribution")
    with gr.Tab("Test Data"):
        with gr.Row():
            test_keypoint_plot = gr.Plot(k_fig, label="Keypoint Distribution")
            test_difficulty_plot = gr.Plot(d_fig, label="Difficulty Distribution")
    return {
        "all_keypoint": all_keypoint_plot,
        "all_difficulty": all_difficulty_plot,
        "test_keypoint": test_keypoint_plot,
        "test_difficulty": test_difficulty_plot,
    }