Files changed (2) hide show
  1. app.py +37 -29
  2. utils/retriever.py +16 -8
app.py CHANGED
@@ -30,9 +30,15 @@ st.title("Instructor XL Embeddings")
30
 
31
 
32
  st.write(
33
- "The app compares the performance of the Instructor-XL Embedding Model on the text from AMD's Q1 2020 Earnings Call Transcript."
 
 
 
34
  )
35
 
 
 
 
36
  data = get_data()
37
 
38
 
@@ -86,6 +92,7 @@ with col1:
86
  st.number_input("Number of Results to query", 1, 15, value=5)
87
  )
88
 
 
89
  corpus, bm25 = get_bm25_model(data)
90
 
91
  tokenized_query = preprocess_text(query_text).split()
@@ -127,35 +134,36 @@ index_mapping = {
127
  }
128
 
129
  with col2:
130
- with st.form("my_form"):
131
- text_embedding_instruction = st.selectbox(
132
- "Select instruction for Text Embedding",
133
- text_embedding_instructions_choice,
134
- )
 
 
 
 
135
 
136
- submitted = st.form_submit_button("Submit")
137
- if submitted:
138
- pinecone_index_name = index_mapping[text_embedding_instruction]
139
- pinecone.init(
140
- api_key=st.secrets[f"pinecone_{pinecone_index_name}"],
141
- environment="asia-southeast1-gcp-free",
142
- )
143
-
144
- pinecone_index = pinecone.Index(pinecone_index_name)
145
-
146
- matches = query_pinecone(
147
- dense_vec=dense_embedding_api, top_k=num_results, index=pinecone_index, indices=indices
148
- )
149
- context = format_query(matches)
150
- output_text = format_context(context)
151
-
152
- st.subheader("Retrieved Text:")
153
- for output in output_text:
154
- output = f"""{output}"""
155
- st.write(
156
- f"<ul><li><p>{output}</p></li></ul>",
157
- unsafe_allow_html=True,
158
- )
159
 
160
 
161
  file_text = retrieve_transcript()
 
30
 
31
 
32
  st.write(
33
+ """The app compares the performance of different instructions using the Instructor-XL Embedding Model on the text from AMD's Q1 2020 Earnings Call Transcript.
34
+ The app uses a two stage retreival process:
35
+ 1. BM-25 to filter the results based on keyword matching,
36
+ 2. Instructor-XL to perform Semantic Search."""
37
  )
38
 
39
+ use_bm25 = st.checkbox('Use BM25 for filtering results')
40
+
41
+
42
  data = get_data()
43
 
44
 
 
92
  st.number_input("Number of Results to query", 1, 15, value=5)
93
  )
94
 
95
+
96
  corpus, bm25 = get_bm25_model(data)
97
 
98
  tokenized_query = preprocess_text(query_text).split()
 
134
  }
135
 
136
  with col2:
137
+ text_embedding_instruction = st.selectbox(
138
+ "Select instruction for Text Embedding",
139
+ text_embedding_instructions_choice,
140
+ )
141
+ pinecone_index_name = index_mapping[text_embedding_instruction]
142
+ pinecone.init(
143
+ api_key=st.secrets[f"pinecone_{pinecone_index_name}"],
144
+ environment="asia-southeast1-gcp-free",
145
+ )
146
 
147
+ pinecone_index = pinecone.Index(pinecone_index_name)
148
+
149
+ if use_bm25==True:
150
+ matches = query_pinecone(
151
+ dense_vec=dense_embedding_api, top_k=num_results, index=pinecone_index, indices=indices
152
+ )
153
+ else:
154
+ matches = query_pinecone(
155
+ dense_vec=dense_embedding_api, top_k=num_results, index=pinecone_index, indices=None
156
+ )
157
+ context = format_query(matches)
158
+ output_text = format_context(context)
159
+
160
+ st.subheader("Retrieved Text:")
161
+ for output in output_text:
162
+ output = f"""{output}"""
163
+ st.write(
164
+ f"<ul><li><p>{output}</p></li></ul>",
165
+ unsafe_allow_html=True,
166
+ )
 
 
 
167
 
168
 
169
  file_text = retrieve_transcript()
utils/retriever.py CHANGED
@@ -6,15 +6,23 @@ def query_pinecone(
6
  dense_vec,
7
  top_k,
8
  index,
9
- indices
10
  ):
11
- xc = index.query(
12
- vector=dense_vec,
13
- top_k=top_k,
14
- filter={"QA_Flag": {"$eq": "Answer"},
15
- "index": {"$in": indices}},
16
- include_metadata=True,
17
- )
 
 
 
 
 
 
 
 
18
  return xc["matches"]
19
 
20
 
 
6
  dense_vec,
7
  top_k,
8
  index,
9
+ indices=None
10
  ):
11
+ if indices != None:
12
+ xc = index.query(
13
+ vector=dense_vec,
14
+ top_k=top_k,
15
+ filter={"QA_Flag": {"$eq": "Answer"},
16
+ "index": {"$in": indices}},
17
+ include_metadata=True,
18
+ )
19
+ else:
20
+ xc = index.query(
21
+ vector=dense_vec,
22
+ top_k=top_k,
23
+ filter={"QA_Flag": {"$eq": "Answer"}},
24
+ include_metadata=True,
25
+ )
26
  return xc["matches"]
27
 
28