Spaces:
Runtime error
Runtime error
Merge branch 'main' of https://github.com/pleonova/multi-label-summary-text
Browse files- app.py +57 -10
- examples.json +2 -1
- models.py +5 -1
- requirements.txt +1 -0
- utils.py +1 -1
app.py
CHANGED
@@ -5,12 +5,14 @@ import pandas as pd
|
|
5 |
import base64
|
6 |
from typing import Sequence
|
7 |
import streamlit as st
|
|
|
|
|
8 |
|
9 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
10 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
11 |
import json
|
12 |
|
13 |
-
ex_text, ex_license, ex_labels = examples_load()
|
14 |
ex_long_text = example_long_text_load()
|
15 |
|
16 |
|
@@ -18,18 +20,27 @@ ex_long_text = example_long_text_load()
|
|
18 |
st.header("Summzarization & Multi-label Classification for Long Text")
|
19 |
st.write("This app summarizes and then classifies your long text with multiple labels.")
|
20 |
st.write("__Inputs__: User enters their own custom text and labels.")
|
21 |
-
st.write("__Outputs__: A summary of the text,
|
|
|
22 |
|
23 |
with st.form(key='my_form'):
|
24 |
example_text = ex_long_text #ex_text
|
25 |
display_text = "[Excerpt from Project Gutenberg: Frankenstein]\n" + example_text + "\n\n" + ex_license
|
26 |
-
text_input = st.text_area("Input any text you want to
|
27 |
|
28 |
if text_input == display_text:
|
29 |
text_input = example_text
|
30 |
|
31 |
-
labels = st.text_input('
|
32 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
submit_button = st.form_submit_button(label='Submit')
|
34 |
|
35 |
|
@@ -93,15 +104,51 @@ if submit_button:
|
|
93 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
94 |
|
95 |
data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
|
|
|
96 |
data2 = pd.merge(data, data_ex_text, on = ['label'])
|
97 |
-
st.markdown("### Data Table")
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
with st.spinner('Generating a table of results and a download link...'):
|
100 |
-
coded_data = base64.b64encode(data2.to_csv(index = False). encode ()).decode()
|
101 |
-
st.markdown(
|
102 |
-
f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Click here to download the data</a>',
|
103 |
-
unsafe_allow_html = True
|
104 |
-
)
|
105 |
st.dataframe(data2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
st.success('All done!')
|
107 |
st.balloons()
|
|
|
5 |
import base64
|
6 |
from typing import Sequence
|
7 |
import streamlit as st
|
8 |
+
from sklearn.metrics import classification_report
|
9 |
+
|
10 |
|
11 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
12 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
13 |
import json
|
14 |
|
15 |
+
ex_text, ex_license, ex_labels, ex_glabels = examples_load()
|
16 |
ex_long_text = example_long_text_load()
|
17 |
|
18 |
|
|
|
20 |
st.header("Summzarization & Multi-label Classification for Long Text")
|
21 |
st.write("This app summarizes and then classifies your long text with multiple labels.")
|
22 |
st.write("__Inputs__: User enters their own custom text and labels.")
|
23 |
+
st.write("__Outputs__: A summary of the text, likelihood percentages for each label and a downloadable csv of the results. \
|
24 |
+
Option to evaluate results against a list of ground truth labels, if available.")
|
25 |
|
26 |
with st.form(key='my_form'):
|
27 |
example_text = ex_long_text #ex_text
|
28 |
display_text = "[Excerpt from Project Gutenberg: Frankenstein]\n" + example_text + "\n\n" + ex_license
|
29 |
+
text_input = st.text_area("Input any text you want to summarize & classify here (keep in mind very long text will take a while to process):", display_text)
|
30 |
|
31 |
if text_input == display_text:
|
32 |
text_input = example_text
|
33 |
|
34 |
+
labels = st.text_input('Enter possible labels (comma-separated):',ex_labels, max_chars=1000)
|
35 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
36 |
+
|
37 |
+
glabels = st.text_input('If available, enter ground truth labels to evaluate results, otherwise leave blank (comma-separated):',ex_glabels, max_chars=1000)
|
38 |
+
glabels = list(set([x.strip() for x in glabels.strip().split(',') if len(x.strip()) > 0]))
|
39 |
+
|
40 |
+
threshold_value = st.slider(
|
41 |
+
'Select a threshold cutoff for matching percentage (used for ground truth label evaluation)',
|
42 |
+
0.0, 1.0, (0.5))
|
43 |
+
|
44 |
submit_button = st.form_submit_button(label='Submit')
|
45 |
|
46 |
|
|
|
104 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
105 |
|
106 |
data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
|
107 |
+
|
108 |
data2 = pd.merge(data, data_ex_text, on = ['label'])
|
|
|
109 |
|
110 |
+
if len(glabels) > 0:
|
111 |
+
gdata = pd.DataFrame({'label': glabels})
|
112 |
+
gdata['is_true_label'] = int(1)
|
113 |
+
|
114 |
+
data2 = pd.merge(data2, gdata, how = 'left', on = ['label'])
|
115 |
+
data2['is_true_label'].fillna(0, inplace = True)
|
116 |
+
|
117 |
+
st.markdown("### Data Table")
|
118 |
with st.spinner('Generating a table of results and a download link...'):
|
|
|
|
|
|
|
|
|
|
|
119 |
st.dataframe(data2)
|
120 |
+
|
121 |
+
@st.cache
|
122 |
+
def convert_df(df):
|
123 |
+
# IMPORTANT: Cache the conversion to prevent computation on every rerun
|
124 |
+
return df.to_csv().encode('utf-8')
|
125 |
+
csv = convert_df(data2)
|
126 |
+
st.download_button(
|
127 |
+
label="Download data as CSV",
|
128 |
+
data=csv,
|
129 |
+
file_name='text_labels.csv',
|
130 |
+
mime='text/csv',
|
131 |
+
)
|
132 |
+
# coded_data = base64.b64encode(data2.to_csv(index = False). encode ()).decode()
|
133 |
+
# st.markdown(
|
134 |
+
# f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Click here to download the data</a>',
|
135 |
+
# unsafe_allow_html = True
|
136 |
+
# )
|
137 |
+
|
138 |
+
if len(glabels) > 0:
|
139 |
+
st.markdown("### Evaluation Metrics")
|
140 |
+
with st.spinner('Evaluating output against ground truth...'):
|
141 |
+
|
142 |
+
section_header_description = ['Summary Label Performance', 'Original Full Text Label Performance']
|
143 |
+
data_headers = ['scores_from_summary', 'scores_from_full_text']
|
144 |
+
for i in range(0,2):
|
145 |
+
st.markdown(f"##### {section_header_description[i]}")
|
146 |
+
report = classification_report(y_true = data2[['is_true_label']],
|
147 |
+
y_pred = (data2[[data_headers[i]]] >= threshold_value) * 1.0,
|
148 |
+
output_dict=True)
|
149 |
+
df_report = pd.DataFrame(report).transpose()
|
150 |
+
st.markdown(f"Threshold set for: {threshold_value}")
|
151 |
+
st.dataframe(df_report)
|
152 |
+
|
153 |
st.success('All done!')
|
154 |
st.balloons()
|
examples.json
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
{
|
2 |
"text": "Such were the professor’s words—rather let me say such the words of the fate—enounced to destroy me. As he went on I felt as if my soul were grappling with a palpable enemy; one by one the various keys were touched which formed the mechanism of my being; chord after chord was sounded, and soon my mind was filled with one thought, one conception, one purpose. So much has been done, exclaimed the soul of Frankenstein—more, far more, will I achieve; treading in the steps already marked, I will pioneer a new way, explore unknown powers, and unfold to the world the deepest mysteries of creation.",
|
3 |
"long_text_license": "[This eBook is for the use of anyone anywhere in the United States and most other parts of the world at no cost and with almost no restrictions whatsoever. You may copy it, give it away or re-use it under the terms of the Project Gutenberg License included with this eBook or online at www.gutenberg.org. If you are not located in the United States, you will have to check the laws of the country where you are located before using this eBook.]",
|
4 |
-
"labels":"Batman,Science,Sound,Light,Creation,Optics,Eyes,Engineering,Color,Communication,Death"
|
|
|
5 |
}
|
|
|
1 |
{
|
2 |
"text": "Such were the professor’s words—rather let me say such the words of the fate—enounced to destroy me. As he went on I felt as if my soul were grappling with a palpable enemy; one by one the various keys were touched which formed the mechanism of my being; chord after chord was sounded, and soon my mind was filled with one thought, one conception, one purpose. So much has been done, exclaimed the soul of Frankenstein—more, far more, will I achieve; treading in the steps already marked, I will pioneer a new way, explore unknown powers, and unfold to the world the deepest mysteries of creation.",
|
3 |
"long_text_license": "[This eBook is for the use of anyone anywhere in the United States and most other parts of the world at no cost and with almost no restrictions whatsoever. You may copy it, give it away or re-use it under the terms of the Project Gutenberg License included with this eBook or online at www.gutenberg.org. If you are not located in the United States, you will have to check the laws of the country where you are located before using this eBook.]",
|
4 |
+
"labels":"Batman,Science,Sound,Light,Creation,Optics,Eyes,Engineering,Color,Communication,Death",
|
5 |
+
"ground_labels":"Science,Sound,Light,Creation,Engineering,Communication,Death"
|
6 |
}
|
models.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import torch
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
|
|
|
|
3 |
|
4 |
import spacy
|
5 |
nlp = spacy.load('en_core_web_sm')
|
@@ -28,6 +30,7 @@ def create_nest_sentences(document:str, token_max_length = 1024):
|
|
28 |
return nested
|
29 |
|
30 |
# Reference: https://huggingface.co/facebook/bart-large-mnli
|
|
|
31 |
def load_summary_model():
|
32 |
model_name = "facebook/bart-large-mnli"
|
33 |
summarizer = pipeline(task='summarization', model=model_name)
|
@@ -41,7 +44,7 @@ def load_summary_model():
|
|
41 |
# return summarizer
|
42 |
|
43 |
def summarizer_gen(summarizer, sequence:str, maximum_tokens:int, minimum_tokens:int):
|
44 |
-
output = summarizer(sequence, num_beams=4, max_length=maximum_tokens, min_length=minimum_tokens, do_sample=False)
|
45 |
return output[0].get('summary_text')
|
46 |
|
47 |
|
@@ -57,6 +60,7 @@ def summarizer_gen(summarizer, sequence:str, maximum_tokens:int, minimum_tokens:
|
|
57 |
|
58 |
|
59 |
# Reference: https://huggingface.co/spaces/team-zero-shot-nli/zero-shot-nli/blob/main/utils.py
|
|
|
60 |
def load_model():
|
61 |
model_name = "facebook/bart-large-mnli"
|
62 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
1 |
import torch
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
|
6 |
import spacy
|
7 |
nlp = spacy.load('en_core_web_sm')
|
|
|
30 |
return nested
|
31 |
|
32 |
# Reference: https://huggingface.co/facebook/bart-large-mnli
|
33 |
+
@st.cache(allow_output_mutation=True)
|
34 |
def load_summary_model():
|
35 |
model_name = "facebook/bart-large-mnli"
|
36 |
summarizer = pipeline(task='summarization', model=model_name)
|
|
|
44 |
# return summarizer
|
45 |
|
46 |
def summarizer_gen(summarizer, sequence:str, maximum_tokens:int, minimum_tokens:int):
|
47 |
+
output = summarizer(sequence, num_beams=4, max_length=maximum_tokens, min_length=minimum_tokens, do_sample=False, early_stopping = True)
|
48 |
return output[0].get('summary_text')
|
49 |
|
50 |
|
|
|
60 |
|
61 |
|
62 |
# Reference: https://huggingface.co/spaces/team-zero-shot-nli/zero-shot-nli/blob/main/utils.py
|
63 |
+
@st.cache(allow_output_mutation=True)
|
64 |
def load_model():
|
65 |
model_name = "facebook/bart-large-mnli"
|
66 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
requirements.txt
CHANGED
@@ -3,5 +3,6 @@ pandas
|
|
3 |
streamlit
|
4 |
plotly
|
5 |
torch
|
|
|
6 |
spacy>=2.2.0,<3.0.0
|
7 |
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz#egg=en_core_web_sm
|
|
|
3 |
streamlit
|
4 |
plotly
|
5 |
torch
|
6 |
+
sklearn
|
7 |
spacy>=2.2.0,<3.0.0
|
8 |
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz#egg=en_core_web_sm
|
utils.py
CHANGED
@@ -77,7 +77,7 @@ def plot_dual_bar_chart(topics_summary, scores_summary, topics_text, scores_text
|
|
77 |
def examples_load():
|
78 |
with open("examples.json") as f:
|
79 |
data=json.load(f)
|
80 |
-
return data['text'], data['long_text_license'], data['labels']
|
81 |
|
82 |
def example_long_text_load():
|
83 |
with open("example_long_text.txt", "r") as f:
|
|
|
77 |
def examples_load():
|
78 |
with open("examples.json") as f:
|
79 |
data=json.load(f)
|
80 |
+
return data['text'], data['long_text_license'], data['labels'], data['ground_labels']
|
81 |
|
82 |
def example_long_text_load():
|
83 |
with open("example_long_text.txt", "r") as f:
|