File size: 2,661 Bytes
f1e1ac2
 
 
 
 
 
 
 
cce0a52
f1e1ac2
 
 
 
 
 
 
 
cce0a52
 
f1e1ac2
 
 
 
 
cce0a52
 
 
 
 
 
f1e1ac2
 
 
cce0a52
f1e1ac2
 
 
 
 
 
 
 
 
 
cce0a52
 
 
 
f1e1ac2
 
 
 
 
 
 
 
 
 
 
 
 
cce0a52
f1e1ac2
 
 
cce0a52
f1e1ac2
 
cce0a52
f1e1ac2
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os

import gradio as gr
import pandas as pd
import plotly
import plotly.graph_objects as go

from assets.color import color_dict
from assets.content import KEYPOINT_DISTRIBUTION, DIFFICULTY_DISTRIBUTION
from assets.path import SEASON


def read_testset(season):
    """Load the test dataset for ``season`` into a DataFrame.

    The file read is ``results/<SEASON[season]>/test_dataset.json``; ``season``
    must be a key of ``SEASON`` (a missing key raises ``KeyError``).
    """
    json_path = os.path.join("results", SEASON[season], "test_dataset.json")
    return pd.read_json(json_path)


def build_keypoint_plot(dataset):
    """Build a sunburst figure of keypoint frequencies in ``dataset``.

    Each distinct value of ``dataset['categories']`` is iterated as a
    collection of category paths, and each path as a sequence of keypoint
    names from root to leaf. Every keypoint on a path gets that value's
    occurrence count added to its node, which is why the figure uses
    ``branchvalues="total"``.

    NOTE(review): ``value_counts`` needs hashable entries; confirm how
    ``categories`` is stored (e.g. tuples vs. lists) by the dataset loader.

    Behaviour worth knowing:
      * Falsy keypoint names are displayed as "未分类" ("Uncategorized").
      * Nodes are identified by label alone. A keypoint name that appears
        under several parents is attached only to the first parent seen.
      * A node's color is ``color_dict[path[0]]``, looked up with the raw
        first element of the path on which the node is first created (before
        any "未分类" substitution).

    Returns:
        A ``plotly.graph_objects.Figure`` containing one Sunburst trace.
    """
    position = {}       # keypoint label -> index into the parallel lists below
    parent_names = []   # parent label per node ("" marks a root)
    totals = []         # accumulated count per node
    fill_colors = []    # marker color per node

    for paths, freq in dataset['categories'].value_counts().items():
        for path in paths:
            previous = ""
            for raw_name in path:
                name = raw_name if raw_name else "未分类"
                if name not in position:
                    # First sighting: register the node under the current parent.
                    # The color lookup stays lazy so it only runs for new nodes.
                    position[name] = len(totals)
                    totals.append(0)
                    parent_names.append(previous)
                    fill_colors.append(color_dict[path[0]])
                totals[position[name]] += freq
                previous = name

    sunburst = go.Sunburst(
        labels=list(position),
        parents=parent_names,
        values=totals,
        branchvalues="total",
        insidetextorientation='radial',
        marker={"colors": fill_colors},
    )
    return go.Figure(sunburst)


def build_difficulty_plot(dataset):
    """Build a bar chart of how many items fall at each difficulty level.

    Bars are ordered by difficulty value (``sort_index``). Each bar is colored
    by its own count on the Viridis scale, and the y-axis is logarithmic.

    Returns:
        A ``plotly.graph_objects.Figure`` containing one Bar trace.
    """
    counts = dataset['difficulty'].value_counts().sort_index()
    levels = counts.index.tolist()
    totals = counts.tolist()

    bar = go.Bar(
        x=levels,
        y=totals,
        marker=dict(color=totals, colorscale="Viridis",
                    colorbar=dict(title="Total")),
    )
    fig = go.Figure([bar])
    fig.update_layout(yaxis={"type": 'log'})
    return fig


def build_plot(season):
    """Build the keypoint and difficulty figures for ``season``'s test set.

    The dataset is loaded through ``read_testset`` so the location of the
    test-set file is defined in exactly one place.

    Returns:
        ``(keypoint_figure, difficulty_figure)``, as produced by
        ``build_keypoint_plot`` and ``build_difficulty_plot``.
    """
    dataset = read_testset(season)
    return build_keypoint_plot(dataset), build_difficulty_plot(dataset)


def create_data(top_components):
    """Lay out the data-distribution tabs and return their plot components.

    The "All data" tab shows figures deserialized from the precomputed
    ``KEYPOINT_DISTRIBUTION`` / ``DIFFICULTY_DISTRIBUTION`` JSON. The
    "Test Data" tab shows figures built from the "latest" season's test set.

    NOTE(review): ``top_components`` is not used in this body; it is kept
    so existing callers keep working.

    Returns:
        A dict with keys ``all_keypoint``, ``all_difficulty``,
        ``test_keypoint`` and ``test_difficulty`` mapping to ``gr.Plot``
        components.
    """
    # Build the test-set figures before any layout is opened.
    test_kp_fig, test_diff_fig = build_plot("latest")

    components = {}
    with gr.Tab("All data"):
        with gr.Row():
            components["all_keypoint"] = gr.Plot(
                plotly.io.from_json(KEYPOINT_DISTRIBUTION),
                label="Keypoint Distribution",
            )
            components["all_difficulty"] = gr.Plot(
                plotly.io.from_json(DIFFICULTY_DISTRIBUTION),
                label="Difficulty Distribution",
            )
    with gr.Tab("Test Data"):
        with gr.Row():
            components["test_keypoint"] = gr.Plot(
                test_kp_fig, label="Keypoint Distribution")
            components["test_difficulty"] = gr.Plot(
                test_diff_fig, label="Difficulty Distribution")

    return components