Spaces:
Runtime error
Runtime error
Upload 7 files
#5
by
awinml
- opened
- app.py +160 -0
- requirements.txt +11 -0
- utils/__init__.py +0 -0
- utils/bm25.py +0 -0
- utils/models.py +58 -0
- utils/nltkmodules.py +5 -0
- utils/retriever.py +46 -0
app.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import streamlit_scrollable_textbox as stx
|
3 |
+
|
4 |
+
import pinecone
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
st.set_page_config(layout="wide") # isort: split
|
8 |
+
|
9 |
+
|
10 |
+
from utils import nltkmodules
|
11 |
+
from utils.models import (
|
12 |
+
get_bm25_model,
|
13 |
+
tokenizer,
|
14 |
+
get_data,
|
15 |
+
get_instructor_embedding_model,
|
16 |
+
preprocess_text,
|
17 |
+
)
|
18 |
+
from utils.retriever import (
|
19 |
+
query_pinecone,
|
20 |
+
format_context,
|
21 |
+
format_query,
|
22 |
+
get_bm25_search_hits,
|
23 |
+
retrieve_transcript,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
st.title("Instructor XL Embeddings")
|
28 |
+
|
29 |
+
|
30 |
+
st.write(
|
31 |
+
"The app compares the performance of the Instructor-XL Embedding Model on the text from AMD's Q1 2020 Earnings Call Transcript.'"
|
32 |
+
)
|
33 |
+
|
34 |
+
data = get_data()
|
35 |
+
|
36 |
+
|
37 |
+
col1, col2 = st.columns([3, 3], gap="medium")
|
38 |
+
|
39 |
+
instructor_model = get_instructor_embedding_model()
|
40 |
+
|
41 |
+
question_choice = [
|
42 |
+
"What was discussed regarding Ryzen revenue performance?",
|
43 |
+
"What is the impact of the enterprise and cloud on AMD's growth",
|
44 |
+
"What was the impact of situation in China on the sales and revenue?",
|
45 |
+
]
|
46 |
+
|
47 |
+
question_instruction_choice = [
|
48 |
+
"Represent the financial question for retrieving supporting documents:",
|
49 |
+
"Represent the financial question for retrieving supporting sentences:",
|
50 |
+
"Represent the finance query for retrieving supporting documents:",
|
51 |
+
"Represent the finance query for retrieving related documents:",
|
52 |
+
"Represent a finance query for retrieving relevant documents:",
|
53 |
+
]
|
54 |
+
|
55 |
+
|
56 |
+
with col1:
|
57 |
+
st.subheader("Question")
|
58 |
+
st.write(
|
59 |
+
"Choose a preset question example from the dropdown or enter a question in the text box."
|
60 |
+
)
|
61 |
+
default_query = st.selectbox("Question Examples", question_choice)
|
62 |
+
|
63 |
+
query_text = st.text_area(
|
64 |
+
"Question",
|
65 |
+
value=default_query,
|
66 |
+
)
|
67 |
+
|
68 |
+
st.subheader("Question Embedding-Instruction")
|
69 |
+
st.write(
|
70 |
+
"Choose a preset instruction example from the dropdown or enter a instruction in the text box."
|
71 |
+
)
|
72 |
+
default_query_embedding_instruction = st.selectbox(
|
73 |
+
"Question Embedding-Instruction Examples", question_instruction_choice
|
74 |
+
)
|
75 |
+
|
76 |
+
query_embedding_instruction = st.text_area(
|
77 |
+
"Question Embedding-Instruction",
|
78 |
+
value=default_query_embedding_instruction,
|
79 |
+
)
|
80 |
+
|
81 |
+
num_results = int(
|
82 |
+
st.number_input("Number of Results to query", 1, 15, value=5)
|
83 |
+
)
|
84 |
+
|
85 |
+
corpus, bm25 = get_bm25_model(data)
|
86 |
+
|
87 |
+
tokenized_query = preprocess_text(query_text).split()
|
88 |
+
sparse_scores = np.argsort(bm25.get_scores(tokenized_query), axis=0)[::-1]
|
89 |
+
indices = get_bm25_search_hits(corpus, sparse_scores, 50)
|
90 |
+
|
91 |
+
|
92 |
+
dense_embedding = instructor_model.predict(
|
93 |
+
query_embedding_instruction,
|
94 |
+
query_text,
|
95 |
+
api_name="/predict",
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
text_embedding_instructions_choice = [
|
100 |
+
"Represent the financial statement for retrieval:",
|
101 |
+
"Represent the financial document for retrieval:",
|
102 |
+
"Represent the finance passage for retrieval:",
|
103 |
+
"Represent the earnings call transcript for retrieval:",
|
104 |
+
"Represent the earnings call transcript sentence for retrieval:",
|
105 |
+
"Represent the earnings call transcript answer for retrieval:",
|
106 |
+
]
|
107 |
+
|
108 |
+
index_mapping = {
|
109 |
+
"Represent the financial statement for retrieval:": "week14-instructor-xl-amd-fsr-1",
|
110 |
+
"Represent the financial document for retrieval:": "week14-instructor-xl-amd-fdr-2",
|
111 |
+
"Represent the finance passage for retrieval:": "week14-instructor-xl-amd-fpr-3",
|
112 |
+
"Represent the earnings call transcript for retrieval:": "week14-instructor-xl-amd-ectr-4",
|
113 |
+
"Represent the earnings call transcript sentence for retrieval:": "week14-instructor-xl-amd-ects-5",
|
114 |
+
"Represent the earnings call transcript answer for retrieval:": "week14-instructor-xl-amd-ecta-6",
|
115 |
+
}
|
116 |
+
|
117 |
+
|
118 |
+
with st.form("my_form"):
|
119 |
+
text_embedding_instruction = st.selectbox(
|
120 |
+
"Select instruction for Text Embedding",
|
121 |
+
text_embedding_instructions_choice,
|
122 |
+
)
|
123 |
+
|
124 |
+
pinecone_index_name = index_mapping[text_embedding_instruction]
|
125 |
+
pinecone.init(
|
126 |
+
api_key=st.secrets[f"pinecone_{pinecone_index_name}"],
|
127 |
+
environment="asia-southeast1-gcp-free",
|
128 |
+
)
|
129 |
+
|
130 |
+
pinecone_index = pinecone.Index(pinecone_index_name)
|
131 |
+
|
132 |
+
submitted = st.form_submit_button("Submit")
|
133 |
+
if submitted:
|
134 |
+
matches = query_pinecone(
|
135 |
+
dense_embedding, num_results, pinecone_index, indices
|
136 |
+
)
|
137 |
+
context = format_query(matches)
|
138 |
+
output_text = format_context(context)
|
139 |
+
|
140 |
+
|
141 |
+
tab1 = st.tabs(["View transcript"])
|
142 |
+
|
143 |
+
|
144 |
+
with col2:
|
145 |
+
st.subheader("Retrieved Text:")
|
146 |
+
for output in output_text:
|
147 |
+
output = f"""{output}"""
|
148 |
+
st.write(
|
149 |
+
f"<ul><li><p>{output}</p></li></ul>",
|
150 |
+
unsafe_allow_html=True,
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
with tab1:
|
155 |
+
file_text = retrieve_transcript()
|
156 |
+
with st.expander("See Transcript"):
|
157 |
+
st.subheader("AMD Q1 2020 Earnings Call Transcript:")
|
158 |
+
stx.scrollableTextbox(
|
159 |
+
file_text, height=700, border=False, fontFamily="Helvetica"
|
160 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas
|
2 |
+
nltk
|
3 |
+
tqdm
|
4 |
+
pinecone-client
|
5 |
+
torch
|
6 |
+
git+https://github.com/UKPLab/sentence-transformers.git
|
7 |
+
streamlit
|
8 |
+
streamlit-scrollable-textbox
|
9 |
+
InstructorEmbedding
|
10 |
+
gradio_client
|
11 |
+
rank_bm25
|
utils/__init__.py
ADDED
File without changes
|
utils/bm25.py
ADDED
File without changes
|
utils/models.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from gradio_client import Client
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
from rank_bm25 import BM25Okapi, BM25L, BM25Plus
|
6 |
+
import numpy as np
|
7 |
+
import nltk
|
8 |
+
from nltk.tokenize import word_tokenize
|
9 |
+
from nltk.corpus import stopwords
|
10 |
+
from nltk.stem.porter import PorterStemmer
|
11 |
+
import re
|
12 |
+
|
13 |
+
|
14 |
+
def tokenizer(
|
15 |
+
string, reg="[a-zA-Z'-]+|[0-9]{1,}%|[0-9]{1,}\.[0-9]{1,}%|\d+\.\d+%}"
|
16 |
+
):
|
17 |
+
regex = reg
|
18 |
+
string = string.replace("-", " ")
|
19 |
+
return " ".join(re.findall(regex, string))
|
20 |
+
|
21 |
+
|
22 |
+
def preprocess_text(text):
|
23 |
+
# Convert to lowercase
|
24 |
+
text = text.lower()
|
25 |
+
# Tokenize the text
|
26 |
+
tokens = word_tokenize(text)
|
27 |
+
# Remove stop words
|
28 |
+
stop_words = set(stopwords.words("english"))
|
29 |
+
tokens = [token for token in tokens if token not in stop_words]
|
30 |
+
# Stem the tokens
|
31 |
+
porter_stemmer = PorterStemmer()
|
32 |
+
tokens = [porter_stemmer.stem(token) for token in tokens]
|
33 |
+
# Join the tokens back into a single string
|
34 |
+
preprocessed_text = " ".join(tokens)
|
35 |
+
preprocessed_text = tokenizer(preprocessed_text)
|
36 |
+
|
37 |
+
return preprocessed_text
|
38 |
+
|
39 |
+
|
40 |
+
@st.experimental_singleton
|
41 |
+
def get_data():
|
42 |
+
data = pd.read_csv("AMD_Q1_2020_earnings_call_data_keywords.csv")
|
43 |
+
return data
|
44 |
+
|
45 |
+
|
46 |
+
@st.experimental_singleton
|
47 |
+
def get_instructor_embedding_model():
|
48 |
+
client = Client("https://awinml-api-instructor-xl-1.hf.space/")
|
49 |
+
return client
|
50 |
+
|
51 |
+
|
52 |
+
@st.experimental_singleton
|
53 |
+
def get_bm25_model(data):
|
54 |
+
corpus = data.Text.tolist()
|
55 |
+
corpus_clean = [preprocess_text(x) for x in corpus]
|
56 |
+
tokenized_corpus = [doc.split(" ") for doc in corpus_clean]
|
57 |
+
bm25 = BM25Plus(tokenized_corpus)
|
58 |
+
return corpus, bm25
|
utils/nltkmodules.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
|
3 |
+
nltk.download("wordnet")
|
4 |
+
nltk.download("punkt")
|
5 |
+
nltk.download("stopwords")
|
utils/retriever.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def query_pinecone(dense_vec, top_k, index, indices):
|
5 |
+
xc = index.query(
|
6 |
+
vector=dense_vec,
|
7 |
+
top_k=top_k,
|
8 |
+
filter={"QA_Flag": {"$eq": "Answer"}, "index": {"$in": indices}},
|
9 |
+
include_metadata=True,
|
10 |
+
)
|
11 |
+
return xc["matches"]
|
12 |
+
|
13 |
+
|
14 |
+
def format_query(query_results):
|
15 |
+
# extract passage_text from Pinecone search result
|
16 |
+
context = [
|
17 |
+
(result["metadata"]["Text"], result["score"])
|
18 |
+
for result in query_results
|
19 |
+
]
|
20 |
+
return context
|
21 |
+
|
22 |
+
|
23 |
+
def format_context(context):
|
24 |
+
output_text = []
|
25 |
+
for text, score in context:
|
26 |
+
output_text.append(f"Text: {text}\nCosine Similarity: {score}")
|
27 |
+
return output_text
|
28 |
+
|
29 |
+
|
30 |
+
def get_bm25_search_hits(corpus, sparse_scores, top_n=50):
|
31 |
+
bm25_search = []
|
32 |
+
indices = []
|
33 |
+
for idx in sparse_scores:
|
34 |
+
if len(bm25_search) <= top_n:
|
35 |
+
bm25_search.append(corpus[idx])
|
36 |
+
indices.append(idx)
|
37 |
+
return indices
|
38 |
+
|
39 |
+
|
40 |
+
def retrieve_transcript():
|
41 |
+
open_file = open(
|
42 |
+
f"2020-Apr-28-AMD.txt",
|
43 |
+
"r",
|
44 |
+
)
|
45 |
+
file_text = open_file.read()
|
46 |
+
return f"""{file_text}"""
|