# Source: Hugging Face Space file view — uploaded by awinml
# ("Upload 7 files (#6)"), commit 93b6d4c, 5.21 kB.
import numpy as np
import streamlit_scrollable_textbox as stx
import pinecone
import streamlit as st
st.set_page_config(layout="wide") # isort: split
from utils import nltkmodules
from utils.models import (
get_bm25_model,
tokenizer,
get_data,
get_instructor_embedding_model,
preprocess_text,
)
from utils.retriever import (
query_pinecone,
format_context,
format_query,
get_bm25_search_hits,
retrieve_transcript,
)
# --- Page header -------------------------------------------------------------
st.title("Instructor XL Embeddings")
st.write(
    "The app compares the performance of the Instructor-XL Embedding Model "
    "on the text from AMD's Q1 2020 Earnings Call Transcript."
)

# --- Shared resources and layout ---------------------------------------------
# `data` is the input used to build the BM25 model further down.
data = get_data()
# Two equal-width columns: question inputs go in col1, the retrieval form and
# results go in col2.
col1, col2 = st.columns([3, 3], gap="medium")
# Client whose .predict(...) is later used to embed the user's query.
instructor_model = get_instructor_embedding_model()
# Preset questions offered in the "Question Examples" dropdown.
question_choice = [
    "What was discussed regarding Ryzen revenue performance?",
    "What is the impact of the enterprise and cloud on AMD's growth",
    "What was the impact of situation in China on the sales and revenue?",
]
# Preset query-side instructions passed to the embedding model together with
# the question text.
question_instruction_choice = [
    "Represent the financial question for retrieving supporting documents:",
    "Represent the financial question for retrieving supporting sentences:",
    "Represent the finance query for retrieving supporting documents:",
    "Represent the finance query for retrieving related documents:",
    "Represent a finance query for retrieving relevant documents:",
]

# Left column: the question, its embedding instruction and the result count.
with col1:
    st.subheader("Question")
    st.write(
        "Choose a preset question example from the dropdown or enter a question in the text box."
    )
    default_query = st.selectbox("Question Examples", question_choice)
    # The text area is pre-filled with the chosen preset but stays editable;
    # its content is what is actually used as the query.
    query_text = st.text_area(
        "Question",
        value=default_query,
    )

    st.subheader("Question Embedding-Instruction")
    st.write(
        "Choose a preset instruction example from the dropdown or enter an instruction in the text box."
    )
    default_query_embedding_instruction = st.selectbox(
        "Question Embedding-Instruction Examples", question_instruction_choice
    )
    query_embedding_instruction = st.text_area(
        "Question Embedding-Instruction",
        value=default_query_embedding_instruction,
    )

    # Bounded to 1..15 by the widget; default 5.
    num_results = int(
        st.number_input("Number of Results to query", 1, 15, value=5)
    )
# --- Sparse pre-filter (BM25) -----------------------------------------------
corpus, bm25 = get_bm25_model(data)
tokenized_query = preprocess_text(query_text).split()
bm25_raw_scores = bm25.get_scores(tokenized_query)
# argsort + reverse gives corpus positions ordered from highest to lowest BM25
# score; despite its name, `sparse_scores` holds indices, not scores.
sparse_scores = np.argsort(bm25_raw_scores, axis=0)[::-1]
# NOTE(review): presumably keeps the top-50 hits to restrict the dense search
# — confirm in utils.retriever.get_bm25_search_hits.
indices = get_bm25_search_hits(corpus, sparse_scores, 50)

# --- Dense query embedding ---------------------------------------------------
# Remote call to the "/predict" endpoint with (instruction, text).
# NOTE(review): this runs on every Streamlit rerun, not only on Submit.
dense_embedding = instructor_model.predict(
    query_embedding_instruction, query_text, api_name="/predict"
)
# Document-side instruction -> Pinecone index name. Each index is queried with
# a Pinecone API key stored in st.secrets under "pinecone_<index name>".
index_mapping = {
    "Represent the financial statement for retrieval:": "week14-instructor-xl-amd-fsr-1",
    "Represent the financial document for retrieval:": "week14-instructor-xl-amd-fdr-2",
    "Represent the finance passage for retrieval:": "week14-instructor-xl-amd-fpr-3",
    "Represent the earnings call transcript for retrieval:": "week14-instructor-xl-amd-ectr-4",
    "Represent the earnings call transcript sentence for retrieval:": "week14-instructor-xl-amd-ects-5",
    "Represent the earnings call transcript answer for retrieval:": "week14-instructor-xl-amd-ecta-6",
}
# Dropdown options derived from the mapping (dicts keep insertion order), so
# every selectable instruction is guaranteed to have an index and an
# `index_mapping[...]` lookup can never raise KeyError because of drift.
text_embedding_instructions_choice = list(index_mapping)
# Right column: choose the document-side instruction (and thus the Pinecone
# index), then query and show results when the form is submitted.
with col2:
    with st.form("my_form"):
        text_embedding_instruction = st.selectbox(
            "Select instruction for Text Embedding",
            text_embedding_instructions_choice,
        )
        # Each instruction has its own index and its own API-key secret.
        pinecone_index_name = index_mapping[text_embedding_instruction]
        pinecone.init(
            api_key=st.secrets[f"pinecone_{pinecone_index_name}"],
            environment="asia-southeast1-gcp-free",
        )
        pinecone_index = pinecone.Index(pinecone_index_name)
        submitted = st.form_submit_button("Submit")

    # Query and render exactly once, and only on the rerun triggered by the
    # Submit button. Rendering outside this guard would reference an
    # undefined `output_text` on every other rerun.
    if submitted:
        matches = query_pinecone(
            dense_embedding, num_results, pinecone_index, indices
        )
        context = format_query(matches)
        output_text = format_context(context)
        st.subheader("Retrieved Text:")
        for output in output_text:
            # NOTE(review): inserted as raw HTML; assumes format_context output
            # is trusted / markup-safe — confirm.
            st.write(
                f"<ul><li><p>{output}</p></li></ul>",
                unsafe_allow_html=True,
            )

# st.tabs returns a list with one container per label, so unpack the single
# tab before using it as a context manager.
(tab1,) = st.tabs(["View transcript"])
with tab1:
    file_text = retrieve_transcript()
    with st.expander("See Transcript"):
        st.subheader("AMD Q1 2020 Earnings Call Transcript:")
        stx.scrollableTextbox(
            file_text, height=700, border=False, fontFamily="Helvetica"
        )