"""Streamlit app comparing Instructor-XL retrieval on AMD's Q1 2020 earnings call.

The left column collects a question and a query-embedding instruction. A
form selects the text-embedding instruction, which picks the Pinecone index
used for the dense search. On submit, the top BM25 hits for the question and
the dense query embedding are passed to ``query_pinecone``, and the matches
are listed in the right column.
"""
import numpy as np
import streamlit_scrollable_textbox as stx
import pinecone
import streamlit as st

# set_page_config must be the first Streamlit command. The isort split below
# keeps the project imports (which may themselves use Streamlit — TODO
# confirm) from being reordered above it.
st.set_page_config(layout="wide")

# isort: split

from utils import nltkmodules
from utils.models import (
    get_bm25_model,
    tokenizer,
    get_data,
    get_instructor_embedding_model,
    preprocess_text,
)
from utils.retriever import (
    query_pinecone,
    format_context,
    format_query,
    get_bm25_search_hits,
    retrieve_transcript,
)

st.title("Instructor XL Embeddings")

st.write(
    "The app compares the performance of the Instructor-XL Embedding Model on the text from AMD's Q1 2020 Earnings Call Transcript."
)

data = get_data()

col1, col2 = st.columns([3, 3], gap="medium")

instructor_model = get_instructor_embedding_model()

question_choice = [
    "What was discussed regarding Ryzen revenue performance?",
    "What is the impact of the enterprise and cloud on AMD's growth",
    "What was the impact of situation in China on the sales and revenue?",
]

question_instruction_choice = [
    "Represent the financial question for retrieving supporting documents:",
    "Represent the financial question for retrieving supporting sentences:",
    "Represent the finance query for retrieving supporting documents:",
    "Represent the finance query for retrieving related documents:",
    "Represent a finance query for retrieving relevant documents:",
]

with col1:
    st.subheader("Question")
    st.write(
        "Choose a preset question example from the dropdown or enter a question in the text box."
    )
    default_query = st.selectbox("Question Examples", question_choice)

    query_text = st.text_area(
        "Question",
        value=default_query,
    )

    st.subheader("Question Embedding-Instruction")
    st.write(
        "Choose a preset instruction example from the dropdown or enter an instruction in the text box."
    )
    default_query_embedding_instruction = st.selectbox(
        "Question Embedding-Instruction Examples", question_instruction_choice
    )

    query_embedding_instruction = st.text_area(
        "Question Embedding-Instruction",
        value=default_query_embedding_instruction,
    )

    num_results = int(
        st.number_input("Number of Results to query", 1, 15, value=5)
    )

corpus, bm25 = get_bm25_model(data)

# Each text-embedding instruction has its own Pinecone index, built with that
# instruction. The selectbox options are derived from this mapping so the two
# cannot drift apart (dicts preserve insertion order).
index_mapping = {
    "Represent the financial statement for retrieval:": "week14-instructor-xl-amd-fsr-1",
    "Represent the financial document for retrieval:": "week14-instructor-xl-amd-fdr-2",
    "Represent the finance passage for retrieval:": "week14-instructor-xl-amd-fpr-3",
    "Represent the earnings call transcript for retrieval:": "week14-instructor-xl-amd-ectr-4",
    "Represent the earnings call transcript sentence for retrieval:": "week14-instructor-xl-amd-ects-5",
    "Represent the earnings call transcript answer for retrieval:": "week14-instructor-xl-amd-ecta-6",
}
text_embedding_instructions_choice = list(index_mapping)

with st.form("my_form"):
    text_embedding_instruction = st.selectbox(
        "Select instruction for Text Embedding",
        text_embedding_instructions_choice,
    )

    pinecone_index_name = index_mapping[text_embedding_instruction]
    # Each index is read with its own API key from Streamlit secrets.
    pinecone.init(
        api_key=st.secrets[f"pinecone_{pinecone_index_name}"],
        environment="asia-southeast1-gcp-free",
    )
    pinecone_index = pinecone.Index(pinecone_index_name)

    submitted = st.form_submit_button("Submit")

if submitted:
    # Query-dependent work runs only on submit. Streamlit reruns the whole
    # script on every widget interaction, and the dense embedding is a remote
    # API call whose result is only needed here.
    tokenized_query = preprocess_text(query_text).split()
    # Document indices ordered from highest to lowest BM25 score.
    sparse_scores = np.argsort(bm25.get_scores(tokenized_query), axis=0)[::-1]
    indices = get_bm25_search_hits(corpus, sparse_scores, 50)

    dense_embedding = instructor_model.predict(
        query_embedding_instruction,
        query_text,
        api_name="/predict",
    )

    matches = query_pinecone(
        dense_embedding, num_results, pinecone_index, indices
    )
    context = format_query(matches)
    output_text = format_context(context)

    # st.tabs returns a list of containers, one per label; unpack the single
    # tab so it can be used as a context manager below.
    (tab1,) = st.tabs(["View transcript"])

    with col2:
        st.subheader("Retrieved Text:")
        for output in output_text:
            # Streamlit markdown renders "$...$" as inline LaTeX, and the
            # passages contain dollar amounts, so escape "$" before display.
            escaped_output = str(output).replace("$", "\\$")
            st.markdown(f"- {escaped_output}")

    with tab1:
        file_text = retrieve_transcript()
        with st.expander("See Transcript"):
            st.subheader("AMD Q1 2020 Earnings Call Transcript:")
            stx.scrollableTextbox(
                file_text, height=700, border=False, fontFamily="Helvetica"
            )