Spaces:
Runtime error
Runtime error
File size: 5,040 Bytes
92808fd 2563fc9 92808fd 93b6d4c 92808fd 1915fad 92808fd 93b6d4c 92808fd 93b6d4c 92808fd 93b6d4c 1915fad 93b6d4c 92808fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import numpy as np
import streamlit_scrollable_textbox as stx
from ast import literal_eval
import pinecone
import streamlit as st
st.set_page_config(layout="wide") # isort: split
from utils import nltkmodules
from utils.models import (
get_bm25_model,
tokenizer,
get_data,
get_instructor_embedding_model,
preprocess_text,
)
from utils.retriever import (
query_pinecone,
format_context,
format_query,
get_bm25_search_hits,
retrieve_transcript,
)
# --- Page header and shared resources ----------------------------------------
_INTRO_TEXT = (
    "The app compares the performance of the Instructor-XL Embedding Model "
    "on the text from AMD's Q1 2020 Earnings Call Transcript."
)

st.title("Instructor XL Embeddings")
st.write(_INTRO_TEXT)

# Transcript data loaded through utils.models.get_data; it is later handed to
# get_bm25_model to build the sparse index.
data = get_data()

# Two equal-width columns: inputs on the left, retrieval results on the right.
col1, col2 = st.columns([3, 3], gap="medium")

# Embedding client; its .predict(instruction, text, api_name="/predict") call
# is used further down to embed the user's question.
instructor_model = get_instructor_embedding_model()
# Preset questions offered in the "Question Examples" dropdown.
question_choice = [
    "What was discussed regarding Ryzen revenue performance?",
    "What is the impact of the enterprise and cloud on AMD's growth",
    "What was the impact of situation in China on the sales and revenue?",
]

# Preset query-side embedding instructions offered in the
# "Question Embedding-Instruction Examples" dropdown; the selected text is sent
# to the embedding model together with the question.
question_instruction_choice = [
    "Represent the financial question for retrieving supporting documents:",
    "Represent the financial question for retrieving supporting sentences:",
    "Represent the finance query for retrieving supporting documents:",
    "Represent the finance query for retrieving related documents:",
    "Represent a finance query for retrieving relevant documents:",
]
# --- Left column: user inputs -------------------------------------------------
with col1:
    # Question: picking a preset pre-fills an editable text area, and the text
    # area's content is what actually gets searched.
    st.subheader("Question")
    st.write(
        "Choose a preset question example from the dropdown or enter a question in the text box."
    )
    default_query = st.selectbox("Question Examples", question_choice)
    query_text = st.text_area(
        "Question",
        value=default_query,
    )

    # Query-side embedding instruction, passed with the question to the
    # embedding model.
    st.subheader("Question Embedding-Instruction")
    st.write(
        # Grammar fix: "a instruction" -> "an instruction".
        "Choose a preset instruction example from the dropdown or enter an instruction in the text box."
    )
    default_query_embedding_instruction = st.selectbox(
        "Question Embedding-Instruction Examples", question_instruction_choice
    )
    query_embedding_instruction = st.text_area(
        "Question Embedding-Instruction",
        value=default_query_embedding_instruction,
    )

    # How many matches to request from Pinecone (bounded to 1..15, default 5).
    num_results = int(
        st.number_input("Number of Results to query", 1, 15, value=5)
    )
# --- Query scoring ------------------------------------------------------------
# Sparse side: BM25-score the preprocessed, whitespace-tokenized question
# against the corpus. Note that `sparse_scores` holds corpus *positions*
# ordered from highest to lowest BM25 score, not the scores themselves.
corpus, bm25 = get_bm25_model(data)
tokenized_query = preprocess_text(query_text).split()
_bm25_scores = bm25.get_scores(tokenized_query)
sparse_scores = np.argsort(_bm25_scores, axis=0)[::-1]
# Presumably the top-50 BM25 hits (confirm in utils.retriever); these are
# handed to query_pinecone together with the dense embedding.
indices = get_bm25_search_hits(corpus, sparse_scores, 50)

# Dense side: the embedding endpoint returns a Python-literal string, which
# ast.literal_eval (safe for literals only) converts back into an object.
_raw_embedding = instructor_model.predict(
    query_embedding_instruction,
    query_text,
    api_name="/predict",
)
dense_embedding = literal_eval(_raw_embedding)
# Text-side embedding instruction -> Pinecone index name. Each index was
# presumably built by embedding the transcript with that instruction
# (confirm against the indexing job).
index_mapping = {
    "Represent the financial statement for retrieval:": "week14-instructor-xl-amd-fsr-1",
    "Represent the financial document for retrieval:": "week14-instructor-xl-amd-fdr-2",
    "Represent the finance passage for retrieval:": "week14-instructor-xl-amd-fpr-3",
    "Represent the earnings call transcript for retrieval:": "week14-instructor-xl-amd-ectr-4",
    "Represent the earnings call transcript sentence for retrieval:": "week14-instructor-xl-amd-ects-5",
    "Represent the earnings call transcript answer for retrieval:": "week14-instructor-xl-amd-ecta-6",
}

# Dropdown choices are exactly the mapping's keys, in insertion order, so the
# two can never drift apart.
text_embedding_instructions_choice = list(index_mapping)
# --- Right column: index selection and retrieval ------------------------------
with col2:
    with st.form("my_form"):
        text_embedding_instruction = st.selectbox(
            "Select instruction for Text Embedding",
            text_embedding_instructions_choice,
        )
        # Each index has its own API key in Streamlit secrets, looked up as
        # "pinecone_<index name>".
        pinecone_index_name = index_mapping[text_embedding_instruction]
        # NOTE(review): pinecone.init / pinecone.Index is the legacy
        # pinecone-client (v2) API; newer clients replace init with a Pinecone
        # class. Confirm the pinned client version.
        pinecone.init(
            api_key=st.secrets[f"pinecone_{pinecone_index_name}"],
            environment="asia-southeast1-gcp-free",
        )
        pinecone_index = pinecone.Index(pinecone_index_name)
        submitted = st.form_submit_button("Submit")

    if submitted:
        # Hybrid lookup: dense query vector restricted/combined with the BM25
        # candidate indices (semantics live in utils.retriever.query_pinecone).
        matches = query_pinecone(
            dense_embedding, num_results, pinecone_index, indices
        )
        output_text = format_context(format_query(matches))

        st.subheader("Retrieved Text:")
        for passage in output_text:
            # Render each retrieved passage as its own bullet; HTML is allowed
            # so any markup produced by format_context is kept.
            st.write(
                f"<ul><li><p>{passage}</p></li></ul>",
                unsafe_allow_html=True,
            )
# --- Transcript viewer --------------------------------------------------------
# st.tabs returns a *list* of containers, one per label. The original
# `tab1 = st.tabs([...])` bound the list itself, and `with tab1:` then failed
# at runtime because a list is not a context manager. Unpack the single tab.
(tab1,) = st.tabs(["View transcript"])
with tab1:
    file_text = retrieve_transcript()
    with st.expander("See Transcript"):
        st.subheader("AMD Q1 2020 Earnings Call Transcript:")
        stx.scrollableTextbox(
            file_text, height=700, border=False, fontFamily="Helvetica"
        )
|