Spaces:
Build error
Build error
Upload app.py
Browse files
app.py
CHANGED
@@ -40,6 +40,7 @@ from utils.prompts import (
|
|
40 |
generate_gpt_j_two_shot_prompt_2,
|
41 |
generate_gpt_prompt_alpaca,
|
42 |
generate_gpt_prompt_alpaca_multi_doc,
|
|
|
43 |
generate_gpt_prompt_original,
|
44 |
generate_multi_doc_context,
|
45 |
get_context_list_prompt,
|
@@ -74,6 +75,11 @@ with st.sidebar:
|
|
74 |
document_type = st.selectbox(
|
75 |
"Select Query Type", ["Single-Document", "Multi-Document"]
|
76 |
)
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
if ner_choice == "Spacy":
|
79 |
ner_model = get_spacy_model()
|
@@ -86,11 +92,16 @@ with col1:
|
|
86 |
value="What was discussed regarding Wearables revenue performance?",
|
87 |
)
|
88 |
else:
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
|
96 |
quarters_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
|
@@ -145,32 +156,76 @@ if document_type == "Single-Document":
|
|
145 |
|
146 |
else:
|
147 |
# Multi-Document Case
|
148 |
-
|
149 |
with col1:
|
150 |
-
#
|
151 |
-
if
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
"Start
|
158 |
-
|
|
|
|
|
159 |
|
160 |
-
|
161 |
-
|
|
|
|
|
162 |
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
participant_type = st.selectbox(
|
176 |
"Speaker", ["Company Speaker", "Analyst"]
|
@@ -186,7 +241,7 @@ with st.sidebar:
|
|
186 |
)
|
187 |
else:
|
188 |
num_results = int(
|
189 |
-
st.number_input("Number of Results to query", 1, 15, value=
|
190 |
)
|
191 |
|
192 |
|
@@ -252,7 +307,7 @@ with st.sidebar:
|
|
252 |
)
|
253 |
)
|
254 |
else:
|
255 |
-
window = int(st.number_input("Sentence Window Size", 0, 10, value=
|
256 |
|
257 |
threshold = float(
|
258 |
st.number_input(
|
@@ -310,69 +365,191 @@ if document_type == "Single-Document":
|
|
310 |
|
311 |
else:
|
312 |
# Multi-Document Retreival
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
dense_query_embedding,
|
331 |
-
sparse_query_embedding,
|
332 |
-
num_results,
|
333 |
-
pinecone_index,
|
334 |
-
year,
|
335 |
-
quarter,
|
336 |
-
ticker,
|
337 |
-
participant_type,
|
338 |
-
threshold,
|
339 |
)
|
340 |
-
|
341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
|
|
|
|
|
343 |
else:
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
dense_query_embedding,
|
355 |
-
num_results,
|
356 |
-
pinecone_index,
|
357 |
-
year,
|
358 |
-
quarter,
|
359 |
-
ticker,
|
360 |
-
participant_type,
|
361 |
-
threshold,
|
362 |
)
|
363 |
-
|
364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
|
366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
|
369 |
if decoder_model == "GPT-3.5 Turbo":
|
370 |
if document_type == "Single-Document":
|
371 |
prompt = generate_gpt_prompt_alpaca(query_text, context_list)
|
372 |
else:
|
373 |
-
|
374 |
-
|
375 |
-
|
|
|
|
|
|
|
|
|
|
|
376 |
|
377 |
with col2:
|
378 |
with st.form("my_form"):
|
@@ -527,6 +704,11 @@ with tab1:
|
|
527 |
else:
|
528 |
with st.expander("See Retrieved Text"):
|
529 |
st.subheader("Retrieved Text:")
|
|
|
|
|
|
|
|
|
|
|
530 |
sections = [
|
531 |
s.strip()
|
532 |
for s in multi_doc_context.split("Document: ")
|
@@ -554,10 +736,26 @@ with tab2:
|
|
554 |
file_text, height=700, border=False, fontFamily="Helvetica"
|
555 |
)
|
556 |
else:
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
st.
|
561 |
-
|
562 |
-
|
563 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
generate_gpt_j_two_shot_prompt_2,
|
41 |
generate_gpt_prompt_alpaca,
|
42 |
generate_gpt_prompt_alpaca_multi_doc,
|
43 |
+
generate_gpt_prompt_alpaca_multi_doc_multi_company,
|
44 |
generate_gpt_prompt_original,
|
45 |
generate_multi_doc_context,
|
46 |
get_context_list_prompt,
|
|
|
75 |
document_type = st.selectbox(
|
76 |
"Select Query Type", ["Single-Document", "Multi-Document"]
|
77 |
)
|
78 |
+
if document_type == "Multi-Document":
|
79 |
+
multi_company_choice = st.selectbox(
|
80 |
+
"Select Company Query Type",
|
81 |
+
["Single-Company", "Compare Companies"],
|
82 |
+
)
|
83 |
|
84 |
if ner_choice == "Spacy":
|
85 |
ner_model = get_spacy_model()
|
|
|
92 |
value="What was discussed regarding Wearables revenue performance?",
|
93 |
)
|
94 |
else:
|
95 |
+
if multi_company_choice == "Single-Company":
|
96 |
+
query_text = st.text_area(
|
97 |
+
"Input Query",
|
98 |
+
value="What was the reported revenue for Wearables over the last 2 years?",
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
query_text = st.text_area(
|
102 |
+
"Input Query",
|
103 |
+
value="How was AAPL's capex spend compared to GOOGL?",
|
104 |
+
)
|
105 |
|
106 |
years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
|
107 |
quarters_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
|
|
|
156 |
|
157 |
else:
|
158 |
# Multi-Document Case
|
|
|
159 |
with col1:
|
160 |
+
# Single Company Summary
|
161 |
+
if multi_company_choice == "Single-Company":
|
162 |
+
# Hardcoding the defaults for a question without metadata
|
163 |
+
if (
|
164 |
+
query_text
|
165 |
+
== "What was the reported revenue for Wearables over the last 2 years?"
|
166 |
+
):
|
167 |
+
start_year = st.selectbox("Start Year", years_choice, index=2)
|
168 |
+
start_quarter = st.selectbox(
|
169 |
+
"Start Quarter", quarters_choice, index=0
|
170 |
+
)
|
171 |
|
172 |
+
end_year = st.selectbox("End Year", years_choice, index=0)
|
173 |
+
end_quarter = st.selectbox(
|
174 |
+
"End Quarter", quarters_choice, index=0
|
175 |
+
)
|
176 |
|
177 |
+
ticker = st.selectbox("Company", ticker_choice, index=0)
|
178 |
+
else:
|
179 |
+
start_year = st.selectbox("Start Year", years_choice, index=2)
|
180 |
+
start_quarter = st.selectbox(
|
181 |
+
"Start Quarter", quarters_choice, index=0
|
182 |
+
)
|
183 |
+
|
184 |
+
end_year = st.selectbox("End Year", years_choice, index=0)
|
185 |
+
end_quarter = st.selectbox(
|
186 |
+
"End Quarter", quarters_choice, index=0
|
187 |
+
)
|
188 |
+
|
189 |
+
ticker = st.selectbox("Company", ticker_choice, index=0)
|
190 |
+
|
191 |
+
# Single Company Summary
|
192 |
+
if multi_company_choice == "Compare Companies":
|
193 |
+
# Hardcoding the defaults for a question without metadata
|
194 |
+
if query_text == "How was AAPL's capex spend compared to GOOGL?":
|
195 |
+
start_year = st.selectbox("Start Year", years_choice, index=1)
|
196 |
+
start_quarter = st.selectbox(
|
197 |
+
"Start Quarter", quarters_choice, index=0
|
198 |
+
)
|
199 |
|
200 |
+
end_year = st.selectbox("End Year", years_choice, index=0)
|
201 |
+
end_quarter = st.selectbox(
|
202 |
+
"End Quarter", quarters_choice, index=0
|
203 |
+
)
|
204 |
+
|
205 |
+
ticker_first = st.selectbox(
|
206 |
+
"First Company", ticker_choice, index=0
|
207 |
+
)
|
208 |
+
ticker_second = st.selectbox(
|
209 |
+
"Second Company", ticker_choice, index=5
|
210 |
+
)
|
211 |
+
|
212 |
+
else:
|
213 |
+
start_year = st.selectbox("Start Year", years_choice, index=2)
|
214 |
+
start_quarter = st.selectbox(
|
215 |
+
"Start Quarter", quarters_choice, index=0
|
216 |
+
)
|
217 |
|
218 |
+
end_year = st.selectbox("End Year", years_choice, index=0)
|
219 |
+
end_quarter = st.selectbox(
|
220 |
+
"End Quarter", quarters_choice, index=0
|
221 |
+
)
|
222 |
+
|
223 |
+
ticker_first = st.selectbox(
|
224 |
+
"First Company", ticker_choice, index=0
|
225 |
+
)
|
226 |
+
ticker_second = st.selectbox(
|
227 |
+
"Second Company", ticker_choice, index=1
|
228 |
+
)
|
229 |
|
230 |
participant_type = st.selectbox(
|
231 |
"Speaker", ["Company Speaker", "Analyst"]
|
|
|
241 |
)
|
242 |
else:
|
243 |
num_results = int(
|
244 |
+
st.number_input("Number of Results to query", 1, 15, value=4)
|
245 |
)
|
246 |
|
247 |
|
|
|
307 |
)
|
308 |
)
|
309 |
else:
|
310 |
+
window = int(st.number_input("Sentence Window Size", 0, 10, value=1))
|
311 |
|
312 |
threshold = float(
|
313 |
st.number_input(
|
|
|
365 |
|
366 |
else:
|
367 |
# Multi-Document Retreival
|
368 |
+
# Single Company
|
369 |
+
if multi_company_choice == "Single-Company":
|
370 |
+
if encoder_model == "Hybrid SGPT - SPLADE":
|
371 |
+
dense_query_embedding = create_dense_embeddings(
|
372 |
+
query_text, retriever_model
|
373 |
+
)
|
374 |
+
sparse_query_embedding = create_sparse_embeddings(
|
375 |
+
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
376 |
+
)
|
377 |
+
dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
|
378 |
+
dense_query_embedding, sparse_query_embedding, 0
|
379 |
+
)
|
380 |
+
year_quarter_list = year_quarter_range(
|
381 |
+
start_quarter, start_year, end_quarter, end_year
|
382 |
+
)
|
383 |
+
|
384 |
+
context_group = []
|
385 |
+
for year, quarter in year_quarter_list:
|
386 |
+
query_results = query_pinecone_sparse(
|
387 |
+
dense_query_embedding,
|
388 |
+
sparse_query_embedding,
|
389 |
+
num_results,
|
390 |
+
pinecone_index,
|
391 |
+
year,
|
392 |
+
quarter,
|
393 |
+
ticker,
|
394 |
+
participant_type,
|
395 |
+
threshold,
|
396 |
+
)
|
397 |
+
results_list = sentence_id_combine(
|
398 |
+
data, query_results, lag=window
|
399 |
+
)
|
400 |
+
context_group.append((results_list, year, quarter, ticker))
|
401 |
|
402 |
+
else:
|
403 |
+
dense_query_embedding = create_dense_embeddings(
|
404 |
+
query_text, retriever_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
)
|
406 |
+
year_quarter_list = year_quarter_range(
|
407 |
+
start_quarter, start_year, end_quarter, end_year
|
408 |
+
)
|
409 |
+
|
410 |
+
context_group = []
|
411 |
+
for year, quarter in year_quarter_list:
|
412 |
+
query_results = query_pinecone(
|
413 |
+
dense_query_embedding,
|
414 |
+
num_results,
|
415 |
+
pinecone_index,
|
416 |
+
year,
|
417 |
+
quarter,
|
418 |
+
ticker,
|
419 |
+
participant_type,
|
420 |
+
threshold,
|
421 |
+
)
|
422 |
+
results_list = sentence_id_combine(
|
423 |
+
data, query_results, lag=window
|
424 |
+
)
|
425 |
+
context_group.append((results_list, year, quarter, ticker))
|
426 |
|
427 |
+
multi_doc_context = generate_multi_doc_context(context_group)
|
428 |
+
# Companies Comparison
|
429 |
else:
|
430 |
+
if encoder_model == "Hybrid SGPT - SPLADE":
|
431 |
+
dense_query_embedding = create_dense_embeddings(
|
432 |
+
query_text, retriever_model
|
433 |
+
)
|
434 |
+
sparse_query_embedding = create_sparse_embeddings(
|
435 |
+
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
436 |
+
)
|
437 |
+
dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
|
438 |
+
dense_query_embedding, sparse_query_embedding, 0
|
439 |
+
)
|
440 |
+
year_quarter_list = year_quarter_range(
|
441 |
+
start_quarter, start_year, end_quarter, end_year
|
442 |
+
)
|
443 |
+
|
444 |
+
# First Company Context
|
445 |
+
context_group_first = []
|
446 |
+
for year, quarter in year_quarter_list:
|
447 |
+
query_results = query_pinecone_sparse(
|
448 |
+
dense_query_embedding,
|
449 |
+
sparse_query_embedding,
|
450 |
+
num_results,
|
451 |
+
pinecone_index,
|
452 |
+
year,
|
453 |
+
quarter,
|
454 |
+
ticker_first,
|
455 |
+
participant_type,
|
456 |
+
threshold,
|
457 |
+
)
|
458 |
+
results_list = sentence_id_combine(
|
459 |
+
data, query_results, lag=window
|
460 |
+
)
|
461 |
+
context_group_first.append(
|
462 |
+
(results_list, year, quarter, ticker_first)
|
463 |
+
)
|
464 |
+
|
465 |
+
# Second Company Context
|
466 |
+
context_group_second = []
|
467 |
+
for year, quarter in year_quarter_list:
|
468 |
+
query_results = query_pinecone_sparse(
|
469 |
+
dense_query_embedding,
|
470 |
+
sparse_query_embedding,
|
471 |
+
num_results,
|
472 |
+
pinecone_index,
|
473 |
+
year,
|
474 |
+
quarter,
|
475 |
+
ticker_second,
|
476 |
+
participant_type,
|
477 |
+
threshold,
|
478 |
+
)
|
479 |
+
results_list = sentence_id_combine(
|
480 |
+
data, query_results, lag=window
|
481 |
+
)
|
482 |
+
context_group_second.append(
|
483 |
+
(results_list, year, quarter, ticker_second)
|
484 |
+
)
|
485 |
|
486 |
+
else:
|
487 |
+
dense_query_embedding = create_dense_embeddings(
|
488 |
+
query_text, retriever_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
489 |
)
|
490 |
+
year_quarter_list = year_quarter_range(
|
491 |
+
start_quarter, start_year, end_quarter, end_year
|
492 |
+
)
|
493 |
+
|
494 |
+
# First Company Context
|
495 |
+
context_group_first = []
|
496 |
+
for year, quarter in year_quarter_list:
|
497 |
+
query_results = query_pinecone(
|
498 |
+
dense_query_embedding,
|
499 |
+
num_results,
|
500 |
+
pinecone_index,
|
501 |
+
year,
|
502 |
+
quarter,
|
503 |
+
ticker_first,
|
504 |
+
participant_type,
|
505 |
+
threshold,
|
506 |
+
)
|
507 |
+
results_list = sentence_id_combine(
|
508 |
+
data, query_results, lag=window
|
509 |
+
)
|
510 |
+
context_group_first.append(
|
511 |
+
(results_list, year, quarter, ticker_first)
|
512 |
+
)
|
513 |
|
514 |
+
# Second Company Context
|
515 |
+
context_group_second = []
|
516 |
+
for year, quarter in year_quarter_list:
|
517 |
+
query_results = query_pinecone(
|
518 |
+
dense_query_embedding,
|
519 |
+
num_results,
|
520 |
+
pinecone_index,
|
521 |
+
year,
|
522 |
+
quarter,
|
523 |
+
ticker_second,
|
524 |
+
participant_type,
|
525 |
+
threshold,
|
526 |
+
)
|
527 |
+
results_list = sentence_id_combine(
|
528 |
+
data, query_results, lag=window
|
529 |
+
)
|
530 |
+
context_group_second.append(
|
531 |
+
(results_list, year, quarter, ticker_second)
|
532 |
+
)
|
533 |
|
534 |
+
multi_doc_context_first = generate_multi_doc_context(
|
535 |
+
context_group_first
|
536 |
+
)
|
537 |
+
multi_doc_context_second = generate_multi_doc_context(
|
538 |
+
context_group_second
|
539 |
+
)
|
540 |
|
541 |
if decoder_model == "GPT-3.5 Turbo":
|
542 |
if document_type == "Single-Document":
|
543 |
prompt = generate_gpt_prompt_alpaca(query_text, context_list)
|
544 |
else:
|
545 |
+
if multi_company_choice == "Single-Company":
|
546 |
+
prompt = generate_gpt_prompt_alpaca_multi_doc(
|
547 |
+
query_text, context_group
|
548 |
+
)
|
549 |
+
else:
|
550 |
+
prompt = generate_gpt_prompt_alpaca_multi_doc_multi_company(
|
551 |
+
query_text, context_group_first, context_group_second
|
552 |
+
)
|
553 |
|
554 |
with col2:
|
555 |
with st.form("my_form"):
|
|
|
704 |
else:
|
705 |
with st.expander("See Retrieved Text"):
|
706 |
st.subheader("Retrieved Text:")
|
707 |
+
if multi_company_choice == "Compare Companies":
|
708 |
+
multi_doc_context = (
|
709 |
+
multi_doc_context_first + multi_doc_context_second
|
710 |
+
)
|
711 |
+
|
712 |
sections = [
|
713 |
s.strip()
|
714 |
for s in multi_doc_context.split("Document: ")
|
|
|
736 |
file_text, height=700, border=False, fontFamily="Helvetica"
|
737 |
)
|
738 |
else:
|
739 |
+
if multi_company_choice == "Single-Company":
|
740 |
+
for year, quarter in year_quarter_list:
|
741 |
+
file_text = retrieve_transcript(data, year, quarter, ticker)
|
742 |
+
with st.expander(f"See Transcript - {quarter} {year}"):
|
743 |
+
st.subheader("Earnings Call Transcript - {quarter} {year}:")
|
744 |
+
stx.scrollableTextbox(
|
745 |
+
file_text, height=700, border=False, fontFamily="Helvetica"
|
746 |
+
)
|
747 |
+
else:
|
748 |
+
for year, quarter in year_quarter_list:
|
749 |
+
file_text = retrieve_transcript(data, year, quarter, ticker_first)
|
750 |
+
with st.expander(f"See Transcript - {quarter} {year}"):
|
751 |
+
st.subheader("Earnings Call Transcript - {quarter} {year}:")
|
752 |
+
stx.scrollableTextbox(
|
753 |
+
file_text, height=700, border=False, fontFamily="Helvetica"
|
754 |
+
)
|
755 |
+
for year, quarter in year_quarter_list:
|
756 |
+
file_text = retrieve_transcript(data, year, quarter, ticker_second)
|
757 |
+
with st.expander(f"See Transcript - {quarter} {year}"):
|
758 |
+
st.subheader("Earnings Call Transcript - {quarter} {year}:")
|
759 |
+
stx.scrollableTextbox(
|
760 |
+
file_text, height=700, border=False, fontFamily="Helvetica"
|
761 |
+
)
|