import os

import gradio as gr
import pandas as pd
import plotly
import plotly.graph_objects as go

from assets.color import color_dict
from assets.path import SEASON


def read_testset(season: str) -> pd.DataFrame:
    """Load the test dataset of a season as a DataFrame.

    Args:
        season: Key looked up in ``SEASON`` to obtain the season's
            sub-directory under ``results``.

    Returns:
        The contents of ``results/<SEASON[season]>/test_dataset.json``
        parsed by ``pandas.read_json``.
    """
    return pd.read_json(os.path.join("results", SEASON[season], "test_dataset.json"))


def build_keypoint_plot(dataset: pd.DataFrame) -> go.Figure:
    """Build a sunburst chart of how often each keypoint occurs.

    Each entry of ``dataset['categories']`` is iterated as a collection of
    categories, and each category as a sequence of keypoints in which every
    keypoint becomes the sunburst parent of the one that follows it.
    Empty (falsy) keypoints are replaced by the label "未分类"
    ("uncategorized").

    Args:
        dataset: DataFrame with a ``categories`` column.

    Returns:
        A Plotly figure holding a single ``Sunburst`` trace.

    Raises:
        KeyError: If the first keypoint of a category (or "未分类") has no
            entry in ``color_dict``.
    """
    keypoint_set = {}
    for categories in dataset['categories']:
        for category in categories:
            # The first keypoint of a category has no parent (sunburst root).
            parent = ""
            for keypoint in category:
                if not keypoint:
                    keypoint = "未分类"
                if keypoint not in keypoint_set:
                    keypoint_set[keypoint] = {"value": 0}
                entry = keypoint_set[keypoint]
                entry['value'] += 1
                # NOTE(review): parent and color are overwritten on every
                # occurrence, so a keypoint that appears under different
                # parents keeps only the last one seen.
                entry['parent'] = parent
                # Every keypoint is colored by the first keypoint of its
                # category.
                entry['color'] = category[0] if category[0] else "未分类"
                parent = keypoint

    labels, parents, values, colors = [], [], [], []
    for label, info in keypoint_set.items():
        labels.append(label)
        parents.append(info['parent'])
        values.append(info['value'])
        colors.append(color_dict[info['color']])

    fig = go.Figure(go.Sunburst(
        labels=labels,
        parents=parents,
        values=values,
        branchvalues="total",
        insidetextorientation='radial',
        marker={"colors": colors},
    ))
    return fig


def build_difficulty_plot(dataset: pd.DataFrame) -> go.Figure:
    """Build a bar chart of the number of rows per difficulty value.

    Args:
        dataset: DataFrame with a ``difficulty`` column.

    Returns:
        A Plotly figure with one bar per distinct difficulty (sorted
        ascending), colored by count, on a logarithmic y axis.
    """
    difficulty = dataset['difficulty']
    xs = sorted(difficulty.unique())
    # Count matches with a boolean mask instead of materializing a filtered
    # DataFrame for every difficulty value.
    ys = [int((difficulty == x).sum()) for x in xs]
    fig = go.Figure([
        go.Bar(
            x=xs,
            y=ys,
            marker={"color": ys, "colorscale": "Viridis", "colorbar": {"title": "Total"}},
        )
    ])
    fig.update_layout(yaxis=dict(type='log'))
    return fig


def build_plot(season: str):
    """Build both distribution plots for the test dataset of a season.

    Args:
        season: Key looked up in ``SEASON``.

    Returns:
        A ``(keypoint_figure, difficulty_figure)`` tuple.
    """
    dataset = read_testset(season)
    return build_keypoint_plot(dataset), build_difficulty_plot(dataset)


def _load_figure(path: str) -> go.Figure:
    """Read a Plotly figure from a UTF-8 JSON file, closing the file afterwards."""
    with open(path, encoding="utf-8") as f:
        return plotly.io.from_json(f.read())


def create_data(top_components) -> dict:
    """Create the "All data" and "Test Data" tabs with their plots.

    Must be called inside an active Gradio layout context, since it opens
    ``gr.Tab`` / ``gr.Row`` blocks.

    Args:
        top_components: Not used by this function; kept for interface
            compatibility with callers.

    Returns:
        A dict mapping ``"all_keypoint"``, ``"all_difficulty"``,
        ``"test_keypoint"`` and ``"test_difficulty"`` to the created
        ``gr.Plot`` components.
    """
    with gr.Tab("All data"):
        with gr.Row():
            # Figures for the full dataset are loaded from JSON files under
            # assets/ rather than rebuilt here.
            all_keypoint_plot = gr.Plot(
                _load_figure("assets/keypoint_distribution.json"),
                label="Keypoint Distribution")
            all_difficulty_plot = gr.Plot(
                _load_figure("assets/difficulty_distribution.json"),
                label="Difficulty Distribution")

    with gr.Tab("Test Data"):
        with gr.Row():
            k_fig, d_fig = build_plot("latest")
            test_keypoint_plot = gr.Plot(k_fig, label="Keypoint Distribution")
            test_difficulty_plot = gr.Plot(d_fig, label="Difficulty Distribution")

    return {
        "all_keypoint": all_keypoint_plot,
        "all_difficulty": all_difficulty_plot,
        "test_keypoint": test_keypoint_plot,
        "test_difficulty": test_difficulty_plot,
    }