# Source listing residue (kept for provenance): ango · upload all files · f1e1ac2 · raw · history blame · No virus · 3.89 kB
import json
import os
import gradio as gr
import plotly.graph_objects as go
from assets.constant import DELIMITER
from assets.path import SEASON
# Number of nested keypoint levels exposed in the detail view: create_detail
# builds one "Level{i + 1}" dropdown (and one event handler) per level.
DEEPEST = 4
def build_plot(category_result, columns):
    """Build the keypoint-accuracy and difficulty-accuracy bar charts.

    Args:
        category_result: Mapping from a DELIMITER-joined keypoint path to its
            statistics. Every entry plotted here is read for the keys
            ``"acc"``, ``"hit"``, ``"all"`` and ``"difficulty"``; the latter
            maps a difficulty label to a ``{"hit": ..., "all": ...}`` dict.
        columns: Keypoint paths (keys of ``category_result``) to draw as bars.

    Returns:
        A tuple ``(k_fig, d_fig, k_x)``:

        - ``k_fig``: one bar per column, coloured by its ``"all"`` count.
        - ``d_fig``: per-difficulty accuracy, aggregated over all ``columns``.
        - ``k_x``: the bar labels, i.e. the last path segment of each column.
          Callers reuse these as dropdown choices.
    """
    k_x, k_y, k_text, k_color = [], [], [], []
    d_xy = {}
    for c in columns:
        result = category_result[c]
        k_x.append(c.split(DELIMITER)[-1])
        k_y.append(round(result["acc"], 4))
        # Count true descendants only. A bare startswith(c) would also match
        # sibling keypoints that merely share a name prefix
        # (e.g. "Math" vs. "Mathematics").
        child_prefix = c + DELIMITER
        sub_count = sum(1 for k in category_result if k.startswith(child_prefix))
        k_text.append(
            f'hit:{result.get("hit")} sub_count:{sub_count}')
        k_color.append(result.get("all"))
        # Accumulate hit/total per difficulty label across all plotted columns.
        for d, v in result["difficulty"].items():
            totals = d_xy.setdefault(d, {"hit": 0, "all": 0})
            totals["hit"] += v["hit"]
            totals["all"] += v["all"]

    # Difficulty labels are displayed in descending sort order.
    d_x = sorted(d_xy, reverse=True)
    d_y, d_text, d_color = [], [], []
    for d in d_x:
        hit, total = d_xy[d]["hit"], d_xy[d]["all"]
        # An empty bucket would otherwise raise ZeroDivisionError.
        d_y.append(hit / total if total else 0.0)
        d_text.append(f"hit/total:{hit}/{total}")
        d_color.append(total)

    k_fig = go.Figure([go.Bar(x=k_x, y=k_y, hovertext=k_text,
                              marker={"color": k_color, "colorscale": "Viridis",
                                      "colorbar": {"title": "Total"}})])
    k_fig.update_layout(yaxis=dict(range=[0, 1]))
    d_fig = go.Figure([go.Bar(x=d_x, y=d_y, hovertext=d_text,
                              marker={"color": d_color, "colorscale": "Cividis",
                                      "colorbar": {"title": "Total"}})])
    d_fig.update_layout(yaxis=dict(range=[0, 1]))
    return k_fig, d_fig, k_x
def create_detail(top_components):
    """Build the per-model drill-down view of keypoint and difficulty accuracy.

    The function creates Gradio components and registers event listeners, so
    it is meant to be called inside a ``gr.Blocks`` context. It creates:

    - a model selector listing ``results/<SEASON["latest"]>/details``;
    - ``DEEPEST`` cascading "Level" dropdowns for walking down the keypoint
      hierarchy;
    - two plots produced by ``build_plot``.

    Args:
        top_components: Not used by this function; kept so existing callers
            keep working.
    """
    details_root = os.path.join("results", SEASON["latest"], "details")
    # os.listdir returns entries in arbitrary order; sort for a stable list.
    models = sorted(os.listdir(details_root))
    model_dropdown = gr.Dropdown(choices=models, label="Select Model")
    # Holds the selected model's parsed category_result-all.json.
    category_result = gr.State()
    with gr.Row():
        keypoint_dropdowns = [gr.Dropdown([], visible=False, label=f"Level{i + 1}") for i in range(DEEPEST)]
    keypoint_plot = gr.Plot(label="Keypoint Acc")
    difficulty_plot = gr.Plot(label="Difficulty Acc")

    def keypoint_dropdown_func(x, *args):
        """Redraw the plots for the keypoint path chosen in the level dropdowns.

        Args:
            x: The loaded category-result mapping (``category_result`` state).
            *args: Selected values of the level dropdowns, from level 1 up to
                and including the one the user just changed.

        Returns:
            Updates for the keypoint plot and the difficulty plot, followed by
            exactly one value/update per level dropdown (``DEEPEST`` of them).
        """
        depth = len(args)
        keypoints = DELIMITER.join(args)
        # Direct children only. Appending DELIMITER keeps children of sibling
        # keypoints that share a name prefix from matching.
        child_prefix = keypoints + DELIMITER
        columns = [k for k in x if k.startswith(child_prefix) and k.count(DELIMITER) == depth]
        has_children = bool(columns)
        if not has_children:
            # Leaf keypoint: plot the selected keypoint itself.
            columns = [keypoints]
        k_fig, d_fig, choices = build_plot(x, columns)
        # Re-emit the already-selected levels unchanged.
        level_updates = list(args)
        # Deeper dropdowns exist only below the last level. At depth == DEEPEST,
        # adding another update would return more values than there are outputs.
        if depth < DEEPEST:
            # Clear values so a selection from a previously visited branch
            # does not linger.
            level_updates.append(gr.update(choices=choices, value=None, visible=has_children))
            level_updates += [gr.update(choices=[], value=None, visible=False)] * (DEEPEST - depth - 1)
        return gr.update(value=k_fig), gr.update(value=d_fig), *level_updates

    # Level i's handler receives the selections of levels 1..i+1 and may
    # update every level dropdown. `.input` fires only on user interaction,
    # so the programmatic updates above do not cascade.
    for i, keypoint_dropdown in enumerate(keypoint_dropdowns):
        keypoint_dropdown.input(keypoint_dropdown_func, [category_result, *keypoint_dropdowns[:i + 1]],
                                [keypoint_plot, difficulty_plot, *keypoint_dropdowns])

    def model_dropdown_func(x):
        """Load model ``x``'s category results and reset the view to level 1.

        Returns:
            The new state value, both plot updates, and one update per level
            dropdown. Level 1 is shown with the top-level keypoints; the
            deeper levels are cleared and hidden.
        """
        model_dir = os.path.join(details_root, x)
        with open(os.path.join(model_dir, "category_result-all.json"), encoding="utf-8") as f:
            new_category_result = json.load(f)
        # Top-level keypoints (no DELIMITER), largest "all" count first.
        columns = sorted((k for k in new_category_result if k.count(DELIMITER) == 0),
                         key=lambda c: new_category_result[c]['all'], reverse=True)
        k_fig, d_fig, choices = build_plot(new_category_result, columns)
        return (new_category_result,
                gr.update(value=k_fig),
                gr.update(value=d_fig),
                gr.update(choices=choices, value=None, visible=True),
                *[gr.update(value=None, visible=False) for _ in range(DEEPEST - 1)])

    model_dropdown.change(model_dropdown_func, model_dropdown,
                          [category_result, keypoint_plot, difficulty_plot, *keypoint_dropdowns])