Spaces:
Runtime error
Runtime error
File size: 3,454 Bytes
01d9aad a0c7a3f 01d9aad 96aa704 01d9aad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import streamlit as st
import numpy as np
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import json
# Reference: https://huggingface.co./spaces/team-zero-shot-nli/zero-shot-nli/blob/main/utils.py
def plot_result(top_topics, scores):
    """Draw a horizontal bar chart of prediction scores in the Streamlit app.

    Args:
        top_topics: Sequence of labels, one per bar (shown on the y-axis).
        scores: Sequence of numeric scores aligned with ``top_topics``. The
            values are multiplied by 100 before plotting, so they are
            presumably fractions in [0, 1] -- confirm against the caller.

    Returns:
        None. The figure is handed to ``st.plotly_chart``.
    """
    labels = np.array(top_topics)
    # np.array makes a copy, so the in-place scaling below does not touch
    # the caller's data.
    percentages = np.array(scores)
    percentages *= 100

    chart = px.bar(
        x=np.around(percentages, 2),
        y=labels,
        orientation='h',
        labels={'x': 'Confidence Score', 'y': 'Label'},
        text=percentages,
        # Axis extends past 100; presumably to leave room for the value
        # labels drawn outside the bars.
        range_x=(0, 115),
        title='Predictions',
        # One evenly spaced colour value per bar, mapped onto the GnBu scale.
        color=np.linspace(0, 1, len(percentages)),
        color_continuous_scale='GnBu',
    )
    # The colour axis only carries bar position, so its legend is hidden.
    chart.update(layout_coloraxis_showscale=False)
    chart.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
    st.plotly_chart(chart)
def plot_dual_bar_chart(topics_summary, scores_summary, topics_text, scores_text):
    """Show summary and full-text prediction scores as two side-by-side bar charts.

    The two label/score sets are joined on the label (``pd.merge`` default,
    i.e. an inner join, so a label absent from either input is dropped) and
    ordered by ascending summary score. Both subplots therefore list the
    labels in the same order.

    Args:
        topics_summary: Labels predicted on the summary.
        scores_summary: Scores aligned with ``topics_summary``.
        topics_text: Labels predicted on the full text.
        scores_text: Scores aligned with ``topics_text``.

    Returns:
        None. The combined figure is handed to ``st.plotly_chart``.
    """
    summary_col = 'scores on summary'
    full_text_col = 'scores on full text'

    summary_frame = pd.DataFrame({'label': topics_summary, summary_col: scores_summary})
    full_text_frame = pd.DataFrame({'label': topics_text, full_text_col: scores_text})
    merged = pd.merge(summary_frame, full_text_frame, on=['label'])
    merged = merged.sort_values(summary_col, ascending=True)

    def as_percent(column):
        # Scores are scaled by 100 and rounded to two decimals for display.
        return round(merged[column] * 100, 2)

    figure = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Predictions on Summary", "Predictions on Full Text"),
    )

    # Build each bar chart with Plotly Express, then move its single trace
    # into the matching subplot column.
    for position, column in enumerate((summary_col, full_text_col), start=1):
        bars = px.bar(
            x=as_percent(column),
            y=merged['label'],
            orientation='h',
            text=as_percent(column),
        )
        figure.add_trace(bars['data'][0], row=1, col=position)

    figure.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
    figure.update_layout(height=600, width=700)
    # Same fixed range on both x-axes, linked so they stay in sync.
    figure.update_xaxes(range=[0, 115], matches='x')
    # Hide the y tick labels everywhere, then re-enable them on the left
    # subplot only; the right subplot shares the same label order.
    figure.update_yaxes(showticklabels=False)
    figure.update_yaxes(showticklabels=True, row=1, col=1)
    st.plotly_chart(figure)
# def plot_dual_bar_chart(topics_summary, scores_summary, topics_text, scores_text):
# data1 = pd.DataFrame({'label': topics_summary, 'scores': scores_summary})
# data1['classification_on'] = 'summary'
# data2 = pd.DataFrame({'label': topics_text, 'scores': scores_text})
# data2['classification_on'] = 'full text'
# data = pd.concat([data1, data2])
# data['scores'] = round(data['scores']*100,2)
# fig = px.bar(
# data, x="scores", y="label", #orientation = 'h',
# labels={'x': 'Confidence Score', 'y': 'Label'},
# text=data['scores'],
# range_x=(0,115),
# color="label", barmode="group",
# facet_col="classification_on",
# category_orders={"classification_on": ["summary", "full text"]}
# )
# fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
# st.plotly_chart(fig)
def examples_load(path="examples.json"):
    """Load the example inputs from a JSON file.

    Args:
        path: Location of the JSON file. Defaults to ``examples.json``,
            resolved against the current working directory.

    Returns:
        A 4-tuple holding the values stored under the keys ``'text'``,
        ``'long_text_license'``, ``'labels'`` and ``'ground_labels'``,
        in that order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If one of the four expected keys is missing.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale default, which would garble or reject non-ASCII examples on
    # systems whose default encoding is not UTF-8.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return data['text'], data['long_text_license'], data['labels'], data['ground_labels']
def example_long_text_load(path="example_long_text.txt"):
    """Read the long example text from a file and return it as one string.

    Args:
        path: Location of the text file. Defaults to
            ``example_long_text.txt``, resolved against the current working
            directory.

    Returns:
        The full contents of the file as a ``str``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(path, "r", encoding="utf-8") as f:
        text_data = f.read()
    return text_data
|