# Spaces: Build error
import openai
import pinecone
import streamlit as st
import streamlit_scrollable_textbox as stx

from utils import (
    clean_entities,
    create_dense_embeddings,
    create_sparse_embeddings,
    extract_entities,
    format_query,
    generate_flant5_prompt,
    generate_gpt_prompt,
    get_context_list_prompt,
    get_data,
    get_flan_t5_model,
    get_mpnet_embedding_model,
    get_sgpt_embedding_model,
    get_spacy_model,
    get_splade_sparse_embedding_model,
    get_t5_model,
    gpt_model,
    hybrid_score_norm,
    query_pinecone,
    query_pinecone_sparse,
    retrieve_transcript,
    save_key,
    sentence_id_combine,
    text_lookup,
)
# Wide layout; kept ahead of every other st.* call in the script.
st.set_page_config(layout="wide")  # isort: skip

st.title("Abstractive Question Answering")
st.write(
    "The app uses the quarterly earnings call transcripts for 10 companies "
    "(Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, "
    "Micron) for the years 2016 to 2020."
)

# Two equal-width columns: col1 holds the question, filters and retrieved
# text; col2 holds the model prompt form and the generated answer.
col1, col2 = st.columns([3, 3], gap="medium")

# spaCy model handed to extract_entities() to find company/quarter/year
# mentions in the question.
spacy_model = get_spacy_model()
col1.subheader("Question")
query_text = col1.text_input(
    "Input Query",
    value="What was discussed regarding Wearables revenue performance in Q1 2020?",
)

# Extract company / quarter / year mentions from the question; the cleaned
# results are used as default positions for the filter selectboxes below.
company_ent, quarter_ent, year_ent = extract_entities(query_text, spacy_model)
ticker_index, quarter_index, year_index = clean_entities(
    company_ent, quarter_ent, year_ent
)
# Metadata filters for retrieval, all drawn in the left column. Their
# default selections come from the entities found in the question.
years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
year = col1.selectbox("Year", years_choice, index=year_index)

quarter = col1.selectbox(
    "Quarter", ["Q1", "Q2", "Q3", "Q4", "All"], index=quarter_index
)

participant_type = col1.selectbox("Speaker", ["Company Speaker", "Analyst"])

# Order matters: ticker_index from clean_entities() is used as a position
# in this list.
ticker_choice = [
    "AAPL", "CSCO", "MSFT", "ASML", "NVDA",
    "GOOGL", "MU", "INTC", "AMZN", "AMD",
]
ticker = col1.selectbox("Company", ticker_choice, index=ticker_index)
# Sidebar: retrieval size and the encoder/decoder model selection.
st.sidebar.subheader("Select Options:")
num_results = int(
    st.sidebar.number_input("Number of Results to query", 1, 15, value=6)
)

# Choose encoder model
encoder_models_choice = ["MPNET", "SGPT", "Hybrid MPNET - SPLADE"]
encoder_model = st.sidebar.selectbox(
    "Select Encoder Model", encoder_models_choice
)

# Choose decoder model
decoder_models_choice = ["GPT3 - (text-davinci-003)", "T5", "FLAN-T5"]
decoder_model = st.sidebar.selectbox(
    "Select Decoder Model", decoder_models_choice
)
# Pinecone connection settings per encoder choice:
#   encoder name -> (st.secrets key, Pinecone environment, index name,
#                    dense embedding model loader)
_ENCODER_SETTINGS = {
    "MPNET": (
        "pinecone_mpnet",
        "us-east1-gcp",
        "week2-all-mpnet-base",
        get_mpnet_embedding_model,
    ),
    "SGPT": (
        "pinecone_sgpt",
        "us-east1-gcp",
        "week2-sgpt-125m",
        get_sgpt_embedding_model,
    ),
    "Hybrid MPNET - SPLADE": (
        "pinecone_hybrid_splade_mpnet",
        "us-central1-gcp",
        "splade-mpnet",
        get_mpnet_embedding_model,
    ),
}

if encoder_model in _ENCODER_SETTINGS:
    (
        _secret_name,
        _environment,
        pinecone_index_name,
        _load_dense_model,
    ) = _ENCODER_SETTINGS[encoder_model]
    # Connect to pinecone environment
    pinecone.init(api_key=st.secrets[_secret_name], environment=_environment)
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = _load_dense_model()
    # The hybrid index additionally needs the SPLADE sparse encoder.
    if encoder_model == "Hybrid MPNET - SPLADE":
        (
            sparse_retriever_model,
            sparse_retriever_tokenizer,
        ) = get_splade_sparse_embedding_model()
# Sidebar: how many neighbouring sentences to pull in around each hit, and
# the minimum similarity score passed to the Pinecone query helpers.
window = int(st.sidebar.number_input("Sentence Window Size", 0, 10, value=1))

threshold = float(
    st.sidebar.number_input(
        label="Similarity Score Threshold",
        value=0.25,
        step=0.05,
        format="%.2f",
    )
)
# Transcript data, used below by sentence_id_combine() and later by
# retrieve_transcript().
data = get_data()

# Compare against the exact option offered in encoder_models_choice. The
# previous check ("Hybrid SGPT - SPLADE") matched no option, so the hybrid
# branch never ran and the SPLADE model loaded for it was unused.
if encoder_model == "Hybrid MPNET - SPLADE":
    # Hybrid search: dense (MPNET) + sparse (SPLADE) query embeddings.
    dense_query_embedding = create_dense_embeddings(
        query_text, retriever_model
    )
    sparse_query_embedding = create_sparse_embeddings(
        query_text, sparse_retriever_model, sparse_retriever_tokenizer
    )
    # NOTE(review): alpha is 0. If hybrid_score_norm follows Pinecone's
    # convention (dense * alpha, sparse * (1 - alpha)), this query is
    # weighted entirely toward the sparse vector - confirm intent.
    dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
        dense_query_embedding, sparse_query_embedding, 0
    )
    query_results = query_pinecone_sparse(
        dense_query_embedding,
        sparse_query_embedding,
        num_results,
        pinecone_index,
        year,
        quarter,
        ticker,
        participant_type,
        threshold,
    )
else:
    # Dense-only search for the MPNET and SGPT encoders.
    dense_query_embedding = create_dense_embeddings(
        query_text, retriever_model
    )
    query_results = query_pinecone(
        dense_query_embedding,
        num_results,
        pinecone_index,
        year,
        quarter,
        ticker,
        participant_type,
        threshold,
    )

# At thresholds up to 0.90, hits are combined with their neighbouring
# sentences (lag=window); above that, results go through format_query().
if threshold <= 0.90:
    context_list = sentence_id_combine(data, query_results, lag=window)
else:
    context_list = format_query(query_results)
def _render_summarizer_form(
    container, summarizer, prompt, *, input_prefix="", skip_marker=None
):
    """Render the editable-prompt form for a local summarization decoder.

    Shared by the T5 and FLAN-T5 decoders, which previously duplicated this
    logic and differed only in the pipeline, the text prepended to each
    context, and an output filter.

    Args:
        container: Streamlit container (column) the form is drawn in.
        summarizer: Pipeline callable; called with one string and indexed
            as ``result[0]["summary_text"]``.
        prompt: Initial text of the editable "Model Prompt" area.
        input_prefix: String prepended to each context before summarizing.
        skip_marker: If given, answers containing this substring are not
            displayed.

    Returns:
        The contexts parsed from the (possibly edited) prompt by
        ``get_context_list_prompt``; the caller stores them in
        ``context_list`` so the "See Retrieved Text" panel reflects edits.
    """
    with container:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            contexts = get_context_list_prompt(edited_prompt)
            submitted = st.form_submit_button("Submit")
            if submitted:
                # Summarize every context first, then render the answers.
                answers = [
                    summarizer(input_prefix + context)[0]["summary_text"]
                    for context in contexts
                ]
                st.subheader("Answer:")
                for answer in answers:
                    if skip_marker is None or skip_marker not in answer:
                        st.markdown(f"- {answer}")
    return contexts


if decoder_model == "GPT3 - (text-davinci-003)":
    prompt = generate_gpt_prompt(query_text, context_list)
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            submitted = st.form_submit_button("Submit")
            if submitted:
                # save_key()'s return value becomes the OpenAI API key.
                openai.api_key = save_key(openai_key)
                generated_text = gpt_model(edited_prompt)
                st.subheader("Answer:")
                st.write(generated_text)
elif decoder_model == "T5":
    prompt = generate_flant5_prompt(query_text, context_list)
    t5_pipeline = get_t5_model()
    context_list = _render_summarizer_form(col2, t5_pipeline, prompt)
elif decoder_model == "FLAN-T5":
    prompt = generate_flant5_prompt(query_text, context_list)
    flan_t5_pipeline = get_flan_t5_model()
    # Answers containing "(iii)" are hidden, as before; presumably these
    # are degenerate enumerated generations - confirm.
    context_list = _render_summarizer_form(
        col2,
        flan_t5_pipeline,
        prompt,
        input_prefix="Summarize:",
        skip_marker="(iii)",
    )
# Left-column panels: the context passages (retrieved, or parsed back from
# the edited prompt for the T5 decoders) and the matching full transcript.
with col1.expander("See Retrieved Text"):
    for passage in context_list:
        st.markdown(f"- {passage}")

file_text = retrieve_transcript(data, year, quarter, ticker)

with col1.expander("See Transcript"):
    stx.scrollableTextbox(
        file_text, height=700, border=False, fontFamily="Helvetica"
    )