app.py
CHANGED
@@ -4,6 +4,7 @@ import faiss
|
|
4 |
import numpy as np
|
5 |
import os
|
6 |
from FlagEmbedding import BGEM3FlagModel
|
|
|
7 |
|
8 |
# Load the pre-trained embedding model
|
9 |
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
|
@@ -15,36 +16,9 @@ df['embeding_context'] = df['embeding_context'].astype(str).fillna('')
|
|
15 |
# Filter out any rows where 'embeding_context' might be empty or invalid
|
16 |
df = df[df['embeding_context'] != '']
|
17 |
|
18 |
-
#
|
19 |
-
# embedding_contexts = df['embeding_context'].tolist()
|
20 |
-
# embeddings_csv = model.encode(embedding_contexts, batch_size=12, max_length=1024)['dense_vecs']
|
21 |
-
|
22 |
-
# # Convert embeddings to numpy array
|
23 |
-
# embeddings_np = np.array(embeddings_csv).astype('float32')
|
24 |
-
|
25 |
-
# # FAISS index file path
|
26 |
-
# index_file_path = 'vector_store_bge_m3.index'
|
27 |
-
|
28 |
-
# # Check if FAISS index file already exists
|
29 |
-
# if os.path.exists(index_file_path):
|
30 |
-
# # Load the existing FAISS index from file
|
31 |
-
# index = faiss.read_index(index_file_path)
|
32 |
-
# print("FAISS index loaded from file.")
|
33 |
-
# else:
|
34 |
-
# # Initialize FAISS index (for L2 similarity)
|
35 |
-
# dim = embeddings_np.shape[1]
|
36 |
-
# index = faiss.IndexFlatL2(dim)
|
37 |
-
|
38 |
-
# # Add embeddings to the FAISS index
|
39 |
-
# index.add(embeddings_np)
|
40 |
-
|
41 |
-
# # Save the FAISS index to a file for future use
|
42 |
-
# faiss.write_index(index, index_file_path)
|
43 |
-
# print("FAISS index created and saved to file.")
|
44 |
-
|
45 |
index = faiss.read_index('vector_store_bge_m3.index')
|
46 |
|
47 |
-
|
48 |
# Function to perform search and return all columns
|
49 |
def search_query(query_text):
|
50 |
num_records = 50
|
@@ -64,7 +38,15 @@ def search_query(query_text):
|
|
64 |
# Gradio interface function
|
65 |
def gradio_interface(query_text):
|
66 |
search_results = search_query(query_text)
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
with gr.Blocks() as app:
|
70 |
gr.Markdown("<h1>White Stride Red Search (BEG-M3)</h1>")
|
@@ -78,10 +60,11 @@ with gr.Blocks() as app:
|
|
78 |
# Output table for displaying results
|
79 |
search_output = gr.DataFrame(label="Search Results")
|
80 |
|
81 |
-
#
|
82 |
-
|
83 |
-
|
84 |
|
|
|
|
|
85 |
|
86 |
# Launch the Gradio app
|
87 |
app.launch()
|
|
|
4 |
import numpy as np
|
5 |
import os
|
6 |
from FlagEmbedding import BGEM3FlagModel
|
7 |
+
from io import BytesIO
|
8 |
|
9 |
# Load the pre-trained embedding model
|
10 |
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
|
|
|
16 |
# Filter out any rows where 'embeding_context' might be empty or invalid
|
17 |
df = df[df['embeding_context'] != '']
|
18 |
|
19 |
+
# Load the FAISS index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
index = faiss.read_index('vector_store_bge_m3.index')
|
21 |
|
|
|
22 |
# Function to perform search and return all columns
|
23 |
def search_query(query_text):
|
24 |
num_records = 50
|
|
|
38 |
# Gradio interface function
|
39 |
def gradio_interface(query_text):
    """Run a search for *query_text* and prepare an Excel export of the hits.

    Args:
        query_text: The free-text query forwarded unchanged to
            ``search_query``.

    Returns:
        A 2-tuple ``(search_results, download_update)``:

        * ``search_results`` -- whatever ``search_query`` returned; it must
          provide ``to_excel`` (i.e. behave like a pandas DataFrame).
        * ``download_update`` -- a ``gr.update`` whose ``value`` is the path
          of an ``.xlsx`` file containing the same rows, for the download
          button wired as the second output.
    """
    # Function-scope imports keep this block self-contained.
    import os
    import tempfile

    search_results = search_query(query_text)

    # A download button serves a file by *path*; handing it raw bytes from an
    # in-memory buffer does not produce a downloadable file. Write the
    # workbook to disk instead. A fresh directory per call lets the file keep
    # the friendly name "search_results.xlsx" without concurrent requests
    # overwriting each other's export.
    # NOTE(review): these temp directories are not cleaned up here -- confirm
    # whether the deployment needs explicit cleanup.
    export_dir = tempfile.mkdtemp(prefix="search_results_")
    excel_path = os.path.join(export_dir, "search_results.xlsx")
    with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
        search_results.to_excel(writer, index=False)

    # Return the results table and point the download button at the file.
    return search_results, gr.update(value=excel_path)
|
50 |
|
51 |
with gr.Blocks() as app:
|
52 |
gr.Markdown("<h1>White Stride Red Search (BEG-M3)</h1>")
|
|
|
60 |
# Output table for displaying results
|
61 |
search_output = gr.DataFrame(label="Search Results")
|
62 |
|
63 |
+
# Download button for Excel file
|
64 |
+
download_button = gr.DownloadButton(label="Download Excel", file_name="search_results.xlsx")
|
|
|
65 |
|
66 |
+
# Link button click to action
|
67 |
+
search_button.click(fn=gradio_interface, inputs=search_input, outputs=[search_output, download_button])
|
68 |
|
69 |
# Launch the Gradio app
|
70 |
app.launch()
|