"""Streamlit app for experimenting with Instructor-XL embedding instructions.

The page lets the user pick a question and an embedding instruction, ranks
the transcript corpus with BM25, embeds the question with the Instructor-XL
model and queries a Pinecone index chosen by the passage-embedding
instruction. Results come from AMD's Q1 2020 earnings call transcript.
"""
import numpy as np
import pinecone
import streamlit as st
import streamlit_scrollable_textbox as stx

# Streamlit has historically required st.set_page_config to be the first
# Streamlit command, so it runs before the project imports below. The isort
# directive stops import sorting from hoisting those imports above this call.
st.set_page_config(layout="wide")  # isort: split

# NOTE(review): nltkmodules is not referenced in the code shown here; it is
# presumably imported for its import-time side effects (e.g. NLTK data setup)
# -- TODO confirm.
from utils import nltkmodules  # noqa: F401
from utils.models import (
    get_bm25_model,
    tokenizer,  # NOTE(review): not used in the code shown here.
    get_data,
    get_instructor_embedding_model,
    preprocess_text,
)
from utils.retriever import (
    query_pinecone,
    format_context,
    format_query,
    get_bm25_search_hits,
    retrieve_transcript,
)
# Page heading and a one-line description of what the app demonstrates.
# (The description previously ended with a stray apostrophe inside the string.)
st.title("Instructor XL Embeddings")
st.write(
    "The app compares the performance of the Instructor-XL Embedding Model "
    "on the text from AMD's Q1 2020 Earnings Call Transcript."
)
# Shared resources and layout: the transcript data, two equally weighted
# columns (question inputs go in col1, retrieved text in col2) and the
# Instructor-XL embedding model handle.
data = get_data()
col1, col2 = st.columns([3, 3], gap="medium")
instructor_model = get_instructor_embedding_model()

# Preset questions offered in the "Question Examples" dropdown.
question_choice = (
    "What was discussed regarding Ryzen revenue performance?",
    "What is the impact of the enterprise and cloud on AMD's growth",
    "What was the impact of situation in China on the sales and revenue?",
)

# Preset instructions used when embedding the question.
question_instruction_choice = (
    "Represent the financial question for retrieving supporting documents:",
    "Represent the financial question for retrieving supporting sentences:",
    "Represent the finance query for retrieving supporting documents:",
    "Represent the finance query for retrieving related documents:",
    "Represent a finance query for retrieving relevant documents:",
)
# Left column: the question, the instruction used to embed it, and how many
# results to fetch. Each free-text field is pre-filled from a preset dropdown.
with col1:
    st.subheader("Question")
    st.write(
        "Choose a preset question example from the dropdown or enter a question in the text box."
    )
    default_query = st.selectbox("Question Examples", question_choice)
    # Pre-filled with the selected preset; the user may edit it freely.
    query_text = st.text_area(
        "Question",
        value=default_query,
    )

    st.subheader("Question Embedding-Instruction")
    st.write(
        "Choose a preset instruction example from the dropdown or enter an instruction in the text box."
    )
    default_query_embedding_instruction = st.selectbox(
        "Question Embedding-Instruction Examples", question_instruction_choice
    )
    query_embedding_instruction = st.text_area(
        "Question Embedding-Instruction",
        value=default_query_embedding_instruction,
    )

    # Bounded to 1..15 with a default of 5; int() is a defensive cast so the
    # value is always an integer when passed on as a result count.
    num_results = int(
        st.number_input("Number of Results to query", 1, 15, value=5)
    )
# Sparse stage: score the corpus with BM25 for the current question.
BM25_CANDIDATES = 50  # number of top BM25 hits passed to get_bm25_search_hits
corpus, bm25 = get_bm25_model(data)
tokenized_query = preprocess_text(query_text).split()
bm25_scores = bm25.get_scores(tokenized_query)

# argsort orders ascending; reversing gives corpus indices best-first.
# Despite its name, sparse_scores holds these ranked indices, not scores.
ranked_ascending = np.argsort(bm25_scores, axis=0)
sparse_scores = ranked_ascending[::-1]
indices = get_bm25_search_hits(corpus, sparse_scores, BM25_CANDIDATES)

# Dense stage: embed the question together with its embedding instruction.
dense_embedding = instructor_model.predict(
    query_embedding_instruction, query_text, api_name="/predict"
)
# Passage-embedding instruction -> Pinecone index name. Each index is
# presumably populated with passages embedded using its instruction -- TODO
# confirm. The dict is the single source of truth: the dropdown options below
# are derived from its keys (insertion order), so the two cannot drift apart.
index_mapping = {
    "Represent the financial statement for retrieval:": "week14-instructor-xl-amd-fsr-1",
    "Represent the financial document for retrieval:": "week14-instructor-xl-amd-fdr-2",
    "Represent the finance passage for retrieval:": "week14-instructor-xl-amd-fpr-3",
    "Represent the earnings call transcript for retrieval:": "week14-instructor-xl-amd-ectr-4",
    "Represent the earnings call transcript sentence for retrieval:": "week14-instructor-xl-amd-ects-5",
    "Represent the earnings call transcript answer for retrieval:": "week14-instructor-xl-amd-ecta-6",
}

# Options for the "Select instruction for Text Embedding" dropdown.
text_embedding_instructions_choice = list(index_mapping)
# Form: choose the passage-embedding instruction, which selects the Pinecone
# index to query. Inside a form, widget changes take effect only on Submit.
with st.form("my_form"):
    text_embedding_instruction = st.selectbox(
        "Select instruction for Text Embedding",
        text_embedding_instructions_choice,
    )
    pinecone_index_name = index_mapping[text_embedding_instruction]

    # Each index has its own API key in Streamlit secrets, stored under
    # "pinecone_<index name>".
    # NOTE(review): pinecone.init is the pinecone-client < 3 API; newer
    # clients removed it (pinecone.Pinecone(api_key=...)). Pin the dependency
    # or migrate -- TODO confirm which client version is deployed.
    pinecone_api_key = st.secrets[f"pinecone_{pinecone_index_name}"]
    pinecone.init(
        api_key=pinecone_api_key,
        environment="asia-southeast1-gcp-free",
    )
    pinecone_index = pinecone.Index(pinecone_index_name)

    submitted = st.form_submit_button("Submit")
# On submit: query Pinecone, show the matches in the right-hand column and
# offer the full transcript in a tab.
if submitted:
    # The BM25 hit indices are passed along with the dense embedding
    # (presumably to restrict the dense search -- TODO confirm).
    matches = query_pinecone(
        dense_embedding, num_results, pinecone_index, indices
    )
    context = format_query(matches)
    output_text = format_context(context)

    # st.tabs returns one container per label. Unpack the single tab:
    # binding the whole list and entering it with `with` fails at runtime.
    (tab1,) = st.tabs(["View transcript"])

    with col2:
        st.subheader("Retrieved Text:")
        for output in output_text:
            # NOTE(review): each result is rendered as raw HTML; escape it if
            # the indexed text can ever contain markup.
            st.write(
                f"<ul><li><p>{output}</p></li></ul>",
                unsafe_allow_html=True,
            )

    with tab1:
        file_text = retrieve_transcript()
        with st.expander("See Transcript"):
            st.subheader("AMD Q1 2020 Earnings Call Transcript:")
            stx.scrollableTextbox(
                file_text, height=700, border=False, fontFamily="Helvetica"
            )