Spaces:
Sleeping
Sleeping
import html
import re

import numpy as np
import openai
import pinecone
import streamlit as st
import streamlit_scrollable_textbox as stx

# Streamlit requires set_page_config to run before other Streamlit commands;
# the isort split keeps the project imports below this call.
st.set_page_config(layout="wide")  # isort: split

# NOTE(review): nltkmodules is not referenced below; presumably imported for
# its side effects — verify before removing.
from utils import nltkmodules  # noqa: F401
from utils.entity_extraction import (
    extract_entities_docs,
    year_quarter_range,
    clean_companies,
    ticker_year_quarter_tuples_creator,
    extract_entities_keywords,
    clean_keywords_all_combs,
)
from utils.models import (
    get_alpaca_model,
    get_vicuna_ner_1_model,
    get_vicuna_ner_2_model,
    get_vicuna_text_gen_model,
    get_data,
    get_instructor_embedding_model_api,
    gpt_turbo_model,
    vicuna_text_generate,
    save_key,
)
from utils.prompts import (
    generate_prompt_alpaca_style,
    generate_multi_doc_context,
)
from utils.retriever import (
    query_pinecone,
    sentence_id_combine,
    get_indices_bm25,
)
from utils.transcript_retrieval import retrieve_transcript
from utils.vector_index import create_dense_embeddings
# Page header: the app title followed by a one-paragraph description of the
# transcript corpus the app answers questions about.
st.title("Question Answering on Earnings Call Transcripts")
st.write(
    "The app uses the quarterly earnings call transcripts for 10 companies "
    "(Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) "
    "for the years 2016 to 2020."
)
# Load the transcript data and every model handle used further down the page.
# Whether these are cached across Streamlit reruns is decided inside
# utils.models (presumably via st.cache_* — confirm there).
(
    data,
    alpaca_model,  # NOTE(review): not referenced elsewhere in this script
    vicuna_ner_1_model,  # extracts companies and the quarter/year span
    vicuna_ner_2_model,  # extracts keywords from the query
    vicuna_text_gen_model,  # generates answers for the "Vicuna-7B" option
) = (
    get_data(),
    get_alpaca_model(),
    get_vicuna_ner_1_model(),
    get_vicuna_ner_2_model(),
    get_vicuna_text_gen_model(),
)
# Sidebar: user-tunable retrieval settings and the choice of answer generator.
decoder_models_choice = ["GPT-3.5 Turbo", "Vicuna-7B"]
with st.sidebar:
    st.subheader("Select Options:")

    # Number of matches requested from the vector index per query (1-15).
    num_results = int(
        st.number_input(
            "Number of Results to query", min_value=1, max_value=15, value=4
        )
    )
    # Passed as `lag` to sentence_id_combine when assembling the context.
    window = int(
        st.number_input(
            "Sentence Window Size", min_value=0, max_value=10, value=1
        )
    )
    # Forwarded to query_pinecone as its similarity-score threshold.
    threshold = float(
        st.number_input(
            label="Similarity Score Threshold",
            value=0.6,
            step=0.05,
            format="%.2f",
        )
    )
    # When enabled, get_indices_bm25 pre-selects `num_candidates` candidates
    # before the dense query.
    use_bm25 = st.checkbox("Use 2-Stage Retrieval (BM25)", value=True)
    num_candidates = int(
        st.number_input(
            "Number of Candidates to Generate:",
            min_value=25,
            max_value=200,
            value=50,
            step=25,
        )
    )
    decoder_model = st.selectbox(
        "Select Text Generation Model", options=decoder_models_choice
    )
# Main layout: two equal-width columns; the query (and later the prompt) goes
# in the left column, the generated answer in the right one.
col1, col2 = st.columns([3, 3], gap="medium")
with col1:
    query_text = st.text_area(
        "Input Query",
        value="How has the growth been for AMD in the PC market in 2020?",
    )

# Extract document entities from the question: the companies mentioned and
# the start/end quarter and year, expanded into (ticker, quarter, year) tuples.
(
    companies,
    start_quarter,
    start_year,
    end_quarter,
    end_year,
) = extract_entities_docs(query_text, vicuna_ner_1_model)
year_quarter_range_list = year_quarter_range(
    start_quarter, start_year, end_quarter, end_year
)
ticker_list = clean_companies(companies)
ticker_year_quarter_tuples_list = ticker_year_quarter_tuples_creator(
    ticker_list, year_quarter_range_list
)

# Extract keywords from the query. Truthiness (rather than `!= []`) also
# treats a None/empty-tuple result as "no keywords", instead of passing it on
# to clean_keywords_all_combs. keywords=None is forwarded to query_pinecone
# (presumably meaning "no keyword filter" — confirm in utils.retriever).
all_keywords = extract_entities_keywords(query_text, vicuna_ner_2_model)
keywords = clean_keywords_all_combs(all_keywords) if all_keywords else None
# Connect to the Pinecone index holding the Instructor-model embeddings.
# NOTE(review): this runs on every Streamlit rerun; consider caching the
# index handle (e.g. st.cache_resource).
pinecone.init(
    api_key=st.secrets["pinecone_instructor"],
    environment="us-west4-gcp-free",
)
pinecone_index_name = "week13-instructor-xl"
pinecone_index = pinecone.Index(pinecone_index_name)

# Embed the question with the retrieval instruction prefix.
retriever_model = get_instructor_embedding_model_api()
instruction = (
    "Represent the financial question for retrieving supporting documents:"
)
dense_query_embedding = create_dense_embeddings(
    query_text, retriever_model, instruction
)

context_group = []
if ticker_year_quarter_tuples_list:
    # One filtered query per (ticker, quarter, year) document. The
    # per-document contexts are merged once, after the loop.
    # NOTE(review): tuples are unpacked as (ticker, quarter, year) although
    # the creator's name says ticker_year_quarter — confirm the order.
    for ticker, quarter, year in ticker_year_quarter_tuples_list:
        # Optional first stage: BM25 pre-selects `num_candidates` candidates
        # that the dense query is then restricted to.
        if use_bm25:
            indices = get_indices_bm25(
                data, ticker, quarter, year, num_candidates
            )
        else:
            indices = None
        query_results = query_pinecone(
            dense_query_embedding,
            num_results,
            pinecone_index,
            year,
            quarter,
            ticker,
            keywords,
            indices,
            threshold,
        )
        context = sentence_id_combine(data, query_results, lag=window)
        context_group.append((context, year, quarter, ticker))
    multi_doc_context = generate_multi_doc_context(context_group)
else:
    # No ticker/period recognised: query without year/quarter/ticker filters
    # and without BM25 pre-selection.
    indices = None
    query_results = query_pinecone(
        dense_query_embedding,
        num_results,
        pinecone_index,
        None,
        None,
        None,
        keywords,
        indices,
        threshold,
    )
    multi_doc_context = sentence_id_combine(data, query_results, lag=window)

prompt = generate_prompt_alpaca_style(query_text, multi_doc_context)
# Editable copy of the generated prompt, shown beneath the query in the left
# column so the user can adjust it before generation.
edited_prompt = col1.text_area(label="Model Prompt", value=prompt, height=400)
# Sentence splitter for generated answers: split on whitespace that follows
# "." or "?", except after abbreviation-like tokens such as "e.g." or "Mr.".
# A raw string is used: the non-raw form relied on invalid escape sequences
# ("\w", "\s"), which raise SyntaxWarning on Python 3.12+.
ANSWER_SENTENCE_PATTERN = re.compile(
    r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s"
)


def _render_answer(generated_text):
    """Render *generated_text* under an "Answer:" header, one bullet per sentence.

    Each sentence is HTML-escaped before being embedded in the bullet markup,
    because the text comes from a language model and is rendered with
    ``unsafe_allow_html=True``. Blank fragments (e.g. from trailing
    whitespace) are skipped so they do not produce empty bullets.
    """
    st.subheader("Answer:")
    for sentence in ANSWER_SENTENCE_PATTERN.split(generated_text):
        sentence = sentence.strip()
        if not sentence:
            continue
        st.write(
            f"<ul><li><p>{html.escape(sentence)}</p></li></ul>",
            unsafe_allow_html=True,
        )


if decoder_model == "GPT-3.5 Turbo":
    with col2:
        # Generation only starts once the user submits an OpenAI key.
        with st.form("gpt_form"):
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            gpt_submitted = st.form_submit_button("Submit")
        if gpt_submitted:
            api_key = save_key(openai_key)
            openai.api_key = api_key
            generated_text = gpt_turbo_model(edited_prompt)
            _render_answer(generated_text)
elif decoder_model == "Vicuna-7B":
    with col2:
        # NOTE(review): unlike the GPT branch there is no submit step, so this
        # long generation restarts on every rerun of the script.
        st.write("The Vicuna Model is running: ...")
        st.write("The model takes 10-15 mins to generate the text.")
        # Use the text from the editable "Model Prompt" box so user edits take
        # effect, matching the GPT-3.5 branch.
        generated_text = vicuna_text_generate(
            edited_prompt, vicuna_text_gen_model
        )
        _render_answer(generated_text)
# Bottom tabs: the retrieved context passages, and the full transcripts of
# every (ticker, quarter, year) document recognised in the question.
tab1, tab2 = st.tabs(["Retrieved Text", "Retrieved Documents"])
with tab1:
    with st.expander("See Retrieved Text"):
        st.subheader("Retrieved Text:")
        # NOTE(review): rendered as raw HTML; the context helpers may emit
        # markup, so it is not escaped here — confirm before changing.
        st.write(
            f"<p>{multi_doc_context}</p>",
            unsafe_allow_html=True,
        )
with tab2:
    if ticker_year_quarter_tuples_list:
        for ticker, quarter, year in ticker_year_quarter_tuples_list:
            file_text = retrieve_transcript(data, year, quarter, ticker)
            # Include the ticker: a question naming several companies yields
            # several transcripts for the same quarter and year, which would
            # otherwise get identical labels.
            with st.expander(f"See Transcript - {ticker} {quarter} {year}"):
                st.subheader(
                    f"Earnings Call Transcript - {ticker} {quarter} {year}:"
                )
                stx.scrollableTextbox(
                    file_text,
                    height=700,
                    border=False,
                    fontFamily="Helvetica",
                )
    else:
        st.write(
            "No specific document/documents found. Please mention Ticker and Duration in the Question."
        )