# Provenance (copied from the hosting page's file header; not part of the app):
# author: awinml · commit: dfc16db ("Upload 16 files") · size: 27.5 kB
import re
import openai
import streamlit_scrollable_textbox as stx
import pinecone
import streamlit as st
st.set_page_config(layout="wide") # isort: split
from utils.entity_extraction import (
clean_entities,
extract_quarter_year,
extract_ticker_spacy,
format_entities_flan_alpaca,
generate_alpaca_ner_prompt,
)
from utils.models import (
generate_entities_flan_alpaca_checkpoint,
generate_entities_flan_alpaca_inference_api,
generate_text_flan_t5,
get_data,
get_flan_alpaca_xl_model,
get_flan_t5_model,
get_mpnet_embedding_model,
get_sgpt_embedding_model,
get_instructor_embedding_model,
get_spacy_model,
get_splade_sparse_embedding_model,
get_t5_model,
gpt_turbo_model,
save_key,
)
from utils.prompts import (
generate_flant5_prompt_instruct_chunk_context,
generate_flant5_prompt_instruct_chunk_context_single,
generate_flant5_prompt_instruct_complete_context,
generate_flant5_prompt_summ_chunk_context,
generate_flant5_prompt_summ_chunk_context_single,
generate_gpt_j_two_shot_prompt_1,
generate_gpt_j_two_shot_prompt_2,
generate_gpt_prompt_alpaca,
generate_gpt_prompt_alpaca_multi_doc,
generate_gpt_prompt_alpaca_multi_doc_multi_company,
generate_gpt_prompt_original,
generate_multi_doc_context,
get_context_list_prompt,
)
from utils.retriever import (
format_query,
query_pinecone,
query_pinecone_sparse,
sentence_id_combine,
text_lookup,
year_quarter_range,
)
from utils.transcript_retrieval import retrieve_transcript
from utils.vector_index import (
create_dense_embeddings,
create_sparse_embeddings,
hybrid_score_norm,
)
# Page header: app title and a one-line description of the transcript corpus.
st.title("Question Answering on Earnings Call Transcripts")
intro_text = "The app uses the quarterly earnings call transcripts for 10 companies (Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) for the years 2016 to 2020."
st.write(intro_text)

# Two equal-width columns: the left one (col1) holds the question and its
# filters, the right one (col2) the editable model prompt and the answer.
col1, col2 = st.columns([3, 3], gap="medium")
# Sidebar selectors for the NER backend and the query mode.
ner_options = ["Spacy", "Alpaca"]
query_type_options = ["Single-Document", "Multi-Document"]
company_query_options = ["Single-Company", "Compare Companies"]

with st.sidebar:
    ner_choice = st.selectbox("Select NER Model", ner_options)
    document_type = st.selectbox("Select Query Type", query_type_options)
    # The company-comparison selector only exists in multi-document mode, so
    # `multi_company_choice` is unbound for single-document queries.
    if document_type == "Multi-Document":
        multi_company_choice = st.selectbox(
            "Select Company Query Type", company_query_options
        )

# Load the spaCy pipeline only when it is the selected NER backend.
if ner_choice == "Spacy":
    ner_model = get_spacy_model()
# Pick the sample question that pre-fills the query box for the current mode.
if document_type == "Single-Document":
    default_query = "What was discussed regarding Wearables revenue performance?"
elif multi_company_choice == "Single-Company":
    default_query = (
        "What was the reported revenue for Wearables over the last 2 years?"
    )
else:
    default_query = "How was AAPL's capex spend compared to GOOGL?"

with col1:
    st.subheader("Question")
    query_text = st.text_area("Input Query", value=default_query)
# Options for the metadata filter dropdowns. Each list ends with "All";
# NOTE(review): presumably the retriever treats "All" as "no filter" — confirm
# against utils.retriever.
years_choice = [str(year_value) for year_value in range(2020, 2015, -1)] + ["All"]
quarters_choice = [f"Q{quarter_number}" for quarter_number in range(1, 5)] + ["All"]
# Order matters: other code selects defaults by position (e.g. index 5 = "GOOGL").
ticker_choice = "AAPL CSCO MSFT ASML NVDA GOOGL MU INTC AMZN AMD".split()
# Metadata filter dropdowns shown under the question; defaults depend on mode.
if document_type == "Single-Document":
    # Extract company / quarter / year mentions from the question so the
    # dropdowns can be pre-selected to match it.
    if ner_choice == "Alpaca":
        ner_prompt = generate_alpaca_ner_prompt(query_text)
        entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
        company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(
            entity_text
        )
    else:
        company_ent = extract_ticker_spacy(query_text, ner_model)
        quarter_ent, year_ent = extract_quarter_year(query_text)
    # Map the extracted entities to positions in the dropdown option lists.
    ticker_index, quarter_index, year_index = clean_entities(
        company_ent, quarter_ent, year_ent
    )
    with col1:
        # Hardcoded defaults for the sample question, which carries no
        # metadata: every dropdown keeps its first option instead of the
        # NER-derived index.
        if (
            query_text
            == "What was discussed regarding Wearables revenue performance?"
        ):
            year = st.selectbox("Year", years_choice)
            quarter = st.selectbox("Quarter", quarters_choice)
            ticker = st.selectbox("Company", ticker_choice)
        else:
            year = st.selectbox("Year", years_choice, index=year_index)
            quarter = st.selectbox(
                "Quarter", quarters_choice, index=quarter_index
            )
            ticker = st.selectbox("Company", ticker_choice, index=ticker_index)
        participant_type = st.selectbox(
            "Speaker", ["Company Speaker", "Analyst"]
        )
else:
    # Multi-document queries span a range of quarters. The dropdowns are not
    # pre-filled from NER; they use fixed positional defaults.
    with col1:
        if multi_company_choice == "Single-Company":
            # Every single-company question, the sample one included, defaults
            # to years_choice[2] .. years_choice[0], Q1 to Q1, for
            # ticker_choice[0]. (Previously this was an if/else whose two
            # branches were identical.)
            start_year = st.selectbox("Start Year", years_choice, index=2)
            start_quarter = st.selectbox(
                "Start Quarter", quarters_choice, index=0
            )
            end_year = st.selectbox("End Year", years_choice, index=0)
            end_quarter = st.selectbox(
                "End Quarter", quarters_choice, index=0
            )
            ticker = st.selectbox("Company", ticker_choice, index=0)
        elif multi_company_choice == "Compare Companies":
            # Company comparison. The sample question (AAPL vs GOOGL) gets a
            # shorter range and GOOGL (ticker_choice[5]) as the second company.
            is_sample_query = (
                query_text == "How was AAPL's capex spend compared to GOOGL?"
            )
            start_year = st.selectbox(
                "Start Year", years_choice, index=1 if is_sample_query else 2
            )
            start_quarter = st.selectbox(
                "Start Quarter", quarters_choice, index=0
            )
            end_year = st.selectbox("End Year", years_choice, index=0)
            end_quarter = st.selectbox(
                "End Quarter", quarters_choice, index=0
            )
            ticker_first = st.selectbox(
                "First Company", ticker_choice, index=0
            )
            ticker_second = st.selectbox(
                "Second Company",
                ticker_choice,
                index=5 if is_sample_query else 1,
            )
        participant_type = st.selectbox(
            "Speaker", ["Company Speaker", "Analyst"]
        )
# Multi-document mode runs one query per quarter in the selected range, so it
# defaults to fewer results per query (presumably to keep the prompt short).
default_num_results = 5 if document_type == "Single-Document" else 4

with st.sidebar:
    st.subheader("Select Options:")
    num_results = int(
        st.number_input(
            "Number of Results to query", 1, 15, value=default_num_results
        )
    )
# Encoder (retrieval) models offered in the sidebar.
encoder_models_choice = ["MPNET", "Instructor", "SGPT", "Hybrid MPNET - SPLADE"]

# Decoder (generation) models. Multi-document prompts exist only for
# GPT-3.5 Turbo, so the other decoders are offered for single-document queries.
decoder_models_choice = ["GPT-3.5 Turbo"]
if document_type == "Single-Document":
    decoder_models_choice += ["T5", "FLAN-T5", "GPT-J"]

with st.sidebar:
    encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
    decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
# Pinecone connection settings for each encoder:
#   encoder -> (secrets key, Pinecone environment, index name, embedding loader)
# The table stores secret *names*, so only the selected encoder's secret is
# read from st.secrets.
PINECONE_ENCODER_CONFIG = {
    "MPNET": (
        "pinecone_mpnet",
        "us-east1-gcp",
        "week2-all-mpnet-base",
        get_mpnet_embedding_model,
    ),
    "SGPT": (
        "pinecone_sgpt",
        "us-east1-gcp",
        "week2-sgpt-125m",
        get_sgpt_embedding_model,
    ),
    "Instructor": (
        "pinecone_instructor",
        "us-west4-gcp-free",
        "week13-instructor-xl",
        get_instructor_embedding_model,
    ),
    "Hybrid MPNET - SPLADE": (
        "pinecone_hybrid_splade_mpnet",
        "us-central1-gcp",
        "splade-mpnet",
        get_mpnet_embedding_model,
    ),
}

if encoder_model in PINECONE_ENCODER_CONFIG:
    (
        secret_name,
        pinecone_environment,
        pinecone_index_name,
        load_retriever_model,
    ) = PINECONE_ENCODER_CONFIG[encoder_model]
    # Connect to the Pinecone environment that hosts this encoder's index.
    pinecone.init(
        api_key=st.secrets[secret_name], environment=pinecone_environment
    )
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = load_retriever_model()
    # The hybrid index also needs the SPLADE sparse encoder and its tokenizer.
    if encoder_model == "Hybrid MPNET - SPLADE":
        (
            sparse_retriever_model,
            sparse_retriever_tokenizer,
        ) = get_splade_sparse_embedding_model()
# Multi-document mode uses a stricter default similarity-score threshold.
default_threshold = 0.25 if document_type == "Single-Document" else 0.6

with st.sidebar:
    window = int(st.number_input("Sentence Window Size", 0, 10, value=1))
    threshold = float(
        st.number_input(
            label="Similarity Score Threshold",
            step=0.05,
            format="%.2f",
            value=default_threshold,
        )
    )
data = get_data()

# Hybrid (dense + sparse) retrieval is used when the hybrid encoder is
# selected. This label must match the sidebar option exactly; otherwise the
# sparse path is unreachable and hybrid queries quietly run dense-only.
use_hybrid_retrieval = encoder_model == "Hybrid MPNET - SPLADE"


def _embed_query(text):
    """Embed ``text`` with the selected encoder.

    Returns a ``(dense, sparse)`` pair. ``sparse`` is ``None`` unless hybrid
    retrieval is enabled; in that case both embeddings are passed through
    ``hybrid_score_norm`` with weight ``0``.
    """
    dense = create_dense_embeddings(text, retriever_model)
    if not use_hybrid_retrieval:
        return dense, None
    sparse = create_sparse_embeddings(
        text, sparse_retriever_model, sparse_retriever_tokenizer
    )
    # NOTE(review): what weight 0 means depends on hybrid_score_norm (not
    # visible here). If it is the usual convex dense/sparse weight, 0 drops
    # the dense part entirely — confirm.
    dense, sparse = hybrid_score_norm(dense, sparse, 0)
    return dense, sparse


def _query_index(dense, sparse, period_year, period_quarter, company):
    """Run one metadata-filtered Pinecone query and return the raw results."""
    if use_hybrid_retrieval:
        return query_pinecone_sparse(
            dense,
            sparse,
            num_results,
            pinecone_index,
            period_year,
            period_quarter,
            company,
            participant_type,
            threshold,
        )
    return query_pinecone(
        dense,
        num_results,
        pinecone_index,
        period_year,
        period_quarter,
        company,
        participant_type,
        threshold,
    )


def _collect_context_group(dense, sparse, company, periods):
    """Retrieve windowed context for ``company`` in each (year, quarter).

    Returns a list of ``(results_list, year, quarter, ticker)`` tuples, the
    structure passed to ``generate_multi_doc_context`` and the multi-doc
    prompt builders.
    """
    group = []
    for period_year, period_quarter in periods:
        results = _query_index(
            dense, sparse, period_year, period_quarter, company
        )
        results_list = sentence_id_combine(data, results, lag=window)
        group.append((results_list, period_year, period_quarter, company))
    return group


dense_query_embedding, sparse_query_embedding = _embed_query(query_text)

if document_type == "Single-Document":
    query_results = _query_index(
        dense_query_embedding, sparse_query_embedding, year, quarter, ticker
    )
    # Thresholds above 0.90 skip the sentence-window expansion and use the
    # matches as returned.
    if threshold <= 0.90:
        context_list = sentence_id_combine(data, query_results, lag=window)
    else:
        context_list = format_query(query_results)
else:
    # Multi-document retrieval: one query per quarter in the selected range.
    year_quarter_list = year_quarter_range(
        start_quarter, start_year, end_quarter, end_year
    )
    if multi_company_choice == "Single-Company":
        context_group = _collect_context_group(
            dense_query_embedding,
            sparse_query_embedding,
            ticker,
            year_quarter_list,
        )
        multi_doc_context = generate_multi_doc_context(context_group)
    else:
        # Company comparison: retrieve the first company's context, then the
        # second's, over the same range.
        context_group_first = _collect_context_group(
            dense_query_embedding,
            sparse_query_embedding,
            ticker_first,
            year_quarter_list,
        )
        context_group_second = _collect_context_group(
            dense_query_embedding,
            sparse_query_embedding,
            ticker_second,
            year_quarter_list,
        )
        multi_doc_context_first = generate_multi_doc_context(
            context_group_first
        )
        multi_doc_context_second = generate_multi_doc_context(
            context_group_second
        )
# Build the decoder prompt, let the user edit it in col2, and show the answer.
# Only one branch runs per script run, so they can all use the form key "my_form".
if decoder_model == "GPT-3.5 Turbo":
    if document_type == "Single-Document":
        prompt = generate_gpt_prompt_alpaca(query_text, context_list)
    elif multi_company_choice == "Single-Company":
        prompt = generate_gpt_prompt_alpaca_multi_doc(query_text, context_group)
    else:
        prompt = generate_gpt_prompt_alpaca_multi_doc_multi_company(
            query_text, context_group_first, context_group_second
        )
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=400
            )
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            submitted = st.form_submit_button("Submit")
            if submitted:
                api_key = save_key(openai_key)
                openai.api_key = api_key
                generated_text = gpt_turbo_model(edited_prompt)
                st.subheader("Answer:")
                # Split the answer into sentences on whitespace that follows
                # "." or "?", but not after abbreviation-like tokens such as
                # "e.g." or "Mr.". It must be a raw string: "\w", "\s" etc. are
                # invalid escapes in a normal literal (SyntaxWarning on 3.12+).
                regex_pattern_sentences = (
                    r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s"
                )
                generated_text_list = re.split(
                    regex_pattern_sentences, generated_text
                )
                for answer_text in generated_text_list:
                    st.write(
                        f"<ul><li><p>{answer_text}</p></li></ul>",
                        unsafe_allow_html=True,
                    )
elif decoder_model == "T5":
    prompt = generate_flant5_prompt_instruct_complete_context(
        query_text, context_list
    )
    t5_pipeline = get_t5_model()
    output_text = []
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=400
            )
            # Re-derive the context chunks from the (possibly edited) prompt.
            # This also rebinds the module-level context_list, which the
            # "Retrieved Text" tab displays later.
            context_list = get_context_list_prompt(edited_prompt)
            submitted = st.form_submit_button("Submit")
            if submitted:
                # Summarize each chunk independently.
                for context_text in context_list:
                    output_text.append(
                        t5_pipeline(context_text)[0]["summary_text"]
                    )
                # Show the heading only after submission, like the other decoders.
                st.subheader("Answer:")
                for text in output_text:
                    st.markdown(f"- {text}")
elif decoder_model == "FLAN-T5":
    flan_t5_model, flan_t5_tokenizer = get_flan_t5_model()
    output_text = []
    with col2:
        prompt_type = st.selectbox(
            "Select prompt type",
            ["Complete Text QA", "Chunkwise QA", "Chunkwise Summarize"],
        )
        if prompt_type == "Complete Text QA":
            prompt = generate_flant5_prompt_instruct_complete_context(
                query_text, context_list
            )
        elif prompt_type == "Chunkwise QA":
            st.write("The following prompt is not editable.")
            prompt = generate_flant5_prompt_instruct_chunk_context(
                query_text, context_list
            )
        elif prompt_type == "Chunkwise Summarize":
            st.write("The following prompt is not editable.")
            prompt = generate_flant5_prompt_summ_chunk_context(
                query_text, context_list
            )
        else:
            prompt = ""
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=400
            )
            submitted = st.form_submit_button("Submit")
            if submitted:
                if prompt_type == "Complete Text QA":
                    # This is the editable mode, so the user's edited prompt
                    # (not the generated one) goes to the model.
                    output_text_string = generate_text_flan_t5(
                        flan_t5_model, flan_t5_tokenizer, edited_prompt
                    )
                    st.subheader("Answer:")
                    st.write(output_text_string)
                elif prompt_type in ("Chunkwise QA", "Chunkwise Summarize"):
                    # Chunkwise modes ignore prompt edits (the UI says so) and
                    # run one single-chunk prompt per retrieved context.
                    if prompt_type == "Chunkwise QA":
                        build_chunk_prompt = (
                            generate_flant5_prompt_instruct_chunk_context_single
                        )
                    else:
                        build_chunk_prompt = (
                            generate_flant5_prompt_summ_chunk_context_single
                        )
                    for context_text in context_list:
                        model_input = build_chunk_prompt(query_text, context_text)
                        output_text.append(
                            generate_text_flan_t5(
                                flan_t5_model, flan_t5_tokenizer, model_input
                            )
                        )
                    st.subheader("Answer:")
                    for text in output_text:
                        # Outputs containing "(iii)" are dropped.
                        # NOTE(review): presumably the model echoing an
                        # enumerated prompt — confirm against utils.prompts.
                        if "(iii)" not in text:
                            st.markdown(f"- {text}")
elif decoder_model == "GPT-J":
    # NVDA, INTC and AMZN use the second two-shot template; every other ticker
    # uses the first.
    if ticker in ["NVDA", "INTC", "AMZN"]:
        prompt = generate_gpt_j_two_shot_prompt_2(query_text, context_list)
    else:
        prompt = generate_gpt_j_two_shot_prompt_1(query_text, context_list)
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=400
            )
            st.write(
                "The app currently just shows the prompt. The app does not load the model due to memory limitations."
            )
            submitted = st.form_submit_button("Submit")
# Bottom tabs: retrieved text chunks (tab1) and full transcripts (tab2).
tab1, tab2 = st.tabs(["Retrieved Text", "Retrieved Documents"])
with tab1:
    with st.expander("See Retrieved Text"):
        st.subheader("Retrieved Text:")
        if document_type != "Single-Document":
            # Multi-document mode: rebuild the chunk list from the combined
            # context string, which separates documents with "Document: ".
            if multi_company_choice == "Compare Companies":
                multi_doc_context = (
                    multi_doc_context_first + multi_doc_context_second
                )
            sections = [
                s.strip()
                for s in multi_doc_context.split("Document: ")
                if s.strip()
            ]
            # Add "Document: " back to each section and put its first
            # 7 characters (presumably the document label) on their own line.
            context_list = [
                "Document: " + s[0:7] + "\n" + s[7:] for s in sections
            ]
        for context_text in context_list:
            st.write(
                f"<ul><li><p>{context_text}</p></li></ul>",
                unsafe_allow_html=True,
            )
def _show_transcript(period_year, period_quarter, company, title_suffix=""):
    """Render one earnings-call transcript inside a collapsible expander.

    ``title_suffix`` is appended to both the expander label and the subheader
    (e.g. " - Q1 2020") so that several transcripts can be told apart.
    """
    file_text = retrieve_transcript(data, period_year, period_quarter, company)
    with st.expander(f"See Transcript{title_suffix}"):
        # f-string: previously a plain string, which showed "{quarter} {year}"
        # literally in multi-document mode.
        st.subheader(f"Earnings Call Transcript{title_suffix}:")
        stx.scrollableTextbox(
            file_text, height=700, border=False, fontFamily="Helvetica"
        )


with tab2:
    if document_type == "Single-Document":
        _show_transcript(year, quarter, ticker)
    elif multi_company_choice == "Single-Company":
        for year, quarter in year_quarter_list:
            _show_transcript(year, quarter, ticker, f" - {quarter} {year}")
    else:
        # Compare Companies: show every period for the first company, then for
        # the second. The ticker goes in the label so the two sets of expanders
        # are distinguishable.
        for company in (ticker_first, ticker_second):
            for year, quarter in year_quarter_list:
                _show_transcript(
                    year, quarter, company, f" - {company} {quarter} {year}"
                )