File size: 3,454 Bytes
01d9aad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0c7a3f
01d9aad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96aa704
01d9aad
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
import numpy as np
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import json

# Reference: https://huggingface.co/spaces/team-zero-shot-nli/zero-shot-nli/blob/main/utils.py
def plot_result(top_topics, scores):
    """Draw a horizontal bar chart of per-label confidence scores in Streamlit.

    Args:
        top_topics: Label names, one per bar (converted with ``np.array``).
        scores: Scores parallel to ``top_topics``. They are multiplied by 100
            before plotting, so presumably fractions in [0, 1] -- TODO confirm
            with callers.
    """
    labels = np.array(top_topics)
    pct = np.array(scores)
    pct *= 100  # scale the fresh copy to percentages in place

    # A 0..1 ramp gives each bar its own shade from the GnBu scale; the colour
    # bar is hidden below because these ramp values carry no meaning.
    shade = np.linspace(0, 1, len(pct))

    chart = px.bar(
        x=np.around(pct, 2),
        y=labels,
        orientation='h',
        labels={'x': 'Confidence Score', 'y': 'Label'},
        text=pct,
        range_x=(0, 115),
        title='Predictions',
        color=shade,
        color_continuous_scale='GnBu',
    )
    chart.update(layout_coloraxis_showscale=False)
    # Text sits outside each bar with one decimal place; the x range runs to
    # 115, presumably to leave room for that text beside a 100% bar.
    chart.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
    st.plotly_chart(chart)


def plot_dual_bar_chart(topics_summary, scores_summary, topics_text, scores_text):
    """Show side-by-side horizontal bar charts comparing two sets of label scores.

    The left panel plots the scores obtained on the summary and the right
    panel the scores obtained on the full text. Both are shown as percentages,
    and the figure is rendered with ``st.plotly_chart``.

    Args:
        topics_summary: Labels scored on the summary.
        scores_summary: Scores parallel to ``topics_summary``. They are
            multiplied by 100 for display, so presumably in [0, 1] -- TODO
            confirm with callers.
        topics_text: Labels scored on the full text.
        scores_text: Scores parallel to ``topics_text``, on the same scale.
    """
    data1 = pd.DataFrame({'label': topics_summary, 'scores on summary': scores_summary})
    data2 = pd.DataFrame({'label': topics_text, 'scores on full text': scores_text})
    # pd.merge defaults to an inner join: a label present in only one of the
    # two inputs is silently left out of both panels.
    data = pd.merge(data1, data2, on=['label'])
    # Both panels are built from this single row order, so every label appears
    # in the same position in each chart.
    data.sort_values('scores on summary', ascending=True, inplace=True)

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Predictions on Summary", "Predictions on Full Text"),
    )

    # One panel per score column; the rounded percentage series is computed
    # once and used for both the bar length and the bar text.
    panels = (data['scores on summary'], data['scores on full text'])
    for col, panel_scores in enumerate(panels, start=1):
        pct = round(panel_scores * 100, 2)
        bar = px.bar(x=pct, y=data['label'], orientation='h', text=pct)
        fig.add_trace(bar['data'][0], row=1, col=col)

    fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
    fig.update_layout(height=600, width=700)
    # Same fixed 0-115 range on both panels (linked via matches='x'), which
    # presumably leaves room for the outside text labels beside a 100% bar.
    fig.update_xaxes(range=[0, 115])
    fig.update_xaxes(matches='x')
    # Hide the y-axis tick labels (the label names) on every panel, then
    # re-enable them on the left panel only.
    fig.update_yaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=True, row=1, col=1)

    st.plotly_chart(fig)

# def plot_dual_bar_chart(topics_summary, scores_summary, topics_text, scores_text):
#     data1 = pd.DataFrame({'label': topics_summary, 'scores': scores_summary})
#     data1['classification_on'] = 'summary'
#     data2 = pd.DataFrame({'label': topics_text, 'scores': scores_text})
#     data2['classification_on'] = 'full text'
#     data = pd.concat([data1, data2])
#     data['scores'] = round(data['scores']*100,2)

#     fig = px.bar(
#         data, x="scores", y="label", #orientation = 'h',
#                  labels={'x': 'Confidence Score', 'y': 'Label'},
#                  text=data['scores'],
#                  range_x=(0,115),
#                  color="label", barmode="group", 
#                  facet_col="classification_on",
#                  category_orders={"classification_on": ["summary", "full text"]}
#        )
#     fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')

#     st.plotly_chart(fig)


def examples_load(path="examples.json"):
    """Load the demo examples from a JSON file.

    Args:
        path: JSON file to read. Defaults to ``examples.json``, resolved
            relative to the current working directory.

    Returns:
        A 4-tuple ``(text, long_text_license, labels, ground_labels)`` taken
        from the keys of the same names in the top-level JSON object.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If any of the four expected keys is missing.
    """
    # JSON text is UTF-8 (RFC 8259); pin the encoding rather than relying on
    # the locale-dependent default that open() would otherwise use.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return data['text'], data['long_text_license'], data['labels'], data['ground_labels']

def example_long_text_load(path="example_long_text.txt"):
    """Return the full contents of the example long-text file.

    Args:
        path: Text file to read. Defaults to ``example_long_text.txt``,
            resolved relative to the current working directory.

    Returns:
        The file's contents as a single string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Pin the encoding so the result does not depend on the platform locale;
    # assumes the bundled file is UTF-8 -- TODO confirm.
    with open(path, encoding="utf-8") as f:
        return f.read()