Spaces:
Running
Running
Saiteja Solleti
commited on
Commit
·
a46269a
1
Parent(s):
748ac82
fine tuning and reranking is pushed
Browse files- app.py +7 -2
- finetuneresults.py +61 -0
- generationhelper.py +8 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -6,6 +6,7 @@ from createmilvusschema import CreateMilvusDbSchema
|
|
6 |
from insertmilvushelper import EmbedAllDocumentsAndInsert
|
7 |
from sentence_transformers import SentenceTransformer
|
8 |
from searchmilvushelper import SearchTopKDocuments
|
|
|
9 |
|
10 |
from model import generate_response
|
11 |
from huggingface_hub import login
|
@@ -15,6 +16,7 @@ from huggingface_hub import dataset_info
|
|
15 |
|
16 |
# Load embedding model
|
17 |
QUERY_EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
|
|
|
18 |
WINDOW_SIZE = 5
|
19 |
OVERLAP = 2
|
20 |
RETRIVE_TOP_K_SIZE=10
|
@@ -38,8 +40,11 @@ EmbedAllDocumentsAndInsert(QUERY_EMBEDDING_MODEL, rag_extracted_data, db_collect
|
|
38 |
"""
|
39 |
query = "what would the net revenue have been in 2015 if there wasn't a stipulated settlement from the business combination in october 2015?"
|
40 |
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
43 |
|
44 |
|
45 |
def chatbot(prompt):
|
|
|
6 |
from insertmilvushelper import EmbedAllDocumentsAndInsert
|
7 |
from sentence_transformers import SentenceTransformer
|
8 |
from searchmilvushelper import SearchTopKDocuments
|
9 |
+
from finetuneresults import FineTuneAndRerankSearchResults
|
10 |
|
11 |
from model import generate_response
|
12 |
from huggingface_hub import login
|
|
|
16 |
|
17 |
# Load embedding model
|
18 |
QUERY_EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
|
19 |
+
RERANKING_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
20 |
WINDOW_SIZE = 5
|
21 |
OVERLAP = 2
|
22 |
RETRIVE_TOP_K_SIZE=10
|
|
|
40 |
"""
|
41 |
query = "what would the net revenue have been in 2015 if there wasn't a stipulated settlement from the business combination in october 2015?"
|
42 |
|
43 |
+
results_for_top10_chunks = SearchTopKDocuments(db_collection, query, QUERY_EMBEDDING_MODEL, top_k=RETRIVE_TOP_K_SIZE)
|
44 |
+
|
45 |
+
reranked_results = FineTuneAndRerankSearchResults(results_for_top10_chunks, rag_extracted_data, query, RERANKING_MODEL)
|
46 |
+
|
47 |
+
print(reranked_results)
|
48 |
|
49 |
|
50 |
def chatbot(prompt):
|
finetuneresults.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import CrossEncoder
|
2 |
+
|
3 |
+
"""
|
4 |
+
Retrieves unique full documents based on the top-ranked document IDs.
|
5 |
+
|
6 |
+
Args:
|
7 |
+
top_documents (list): List of dictionaries containing 'doc_id'.
|
8 |
+
df (pd.DataFrame): The dataset containing document IDs and text.
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
pd.DataFrame: A DataFrame with 'doc_id' and 'document'.
|
12 |
+
"""
|
13 |
+
def retrieve_full_documents(top_documents, df):
|
14 |
+
|
15 |
+
# Extract unique doc_ids
|
16 |
+
unique_doc_ids = list(set(doc["doc_id"] for doc in top_documents))
|
17 |
+
|
18 |
+
# Print for debugging
|
19 |
+
print(f"Extracted Doc IDs: {unique_doc_ids}")
|
20 |
+
|
21 |
+
# Filter DataFrame where 'id' matches any of the unique_doc_ids
|
22 |
+
filtered_df = df[df["id"].isin(unique_doc_ids)][["id", "documents"]].drop_duplicates(subset="id")
|
23 |
+
|
24 |
+
# Rename columns for clarity
|
25 |
+
filtered_df = filtered_df.rename(columns={"id": "doc_id", "documents": "document"})
|
26 |
+
|
27 |
+
return filtered_df
|
28 |
+
|
29 |
+
"""
|
30 |
+
Reranks the retrieved documents based on their relevance to the query using a Cross-Encoder model.
|
31 |
+
Args:
|
32 |
+
query (str): The search query.
|
33 |
+
retrieved_docs (pd.DataFrame): DataFrame with 'doc_id' and 'document'.
|
34 |
+
model_name (str): Name of the Cross-Encoder model.
|
35 |
+
Returns:
|
36 |
+
pd.DataFrame: A sorted DataFrame with doc_id, document, and reranking score.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def rerank_documents(query, retrieved_docs_df, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
|
40 |
+
|
41 |
+
# Load Cross-Encoder model
|
42 |
+
model = CrossEncoder(model_name)
|
43 |
+
|
44 |
+
# Prepare query-document pairs
|
45 |
+
query_doc_pairs = [(query, " ".join(doc)) for doc in retrieved_docs_df["document"]]
|
46 |
+
|
47 |
+
# Compute relevance scores
|
48 |
+
scores = model.predict(query_doc_pairs)
|
49 |
+
|
50 |
+
# Add scores to the DataFrame
|
51 |
+
retrieved_docs_df["relevance_score"] = scores
|
52 |
+
|
53 |
+
# Sort by score in descending order (higher score = more relevant)
|
54 |
+
reranked_docs_df = retrieved_docs_df.sort_values(by="relevance_score", ascending=False).reset_index(drop=True)
|
55 |
+
|
56 |
+
return reranked_docs_df
|
57 |
+
|
58 |
+
def FineTuneAndRerankSearchResults(top_10_chunk_results, rag_extarcted_data, question, reranking_model):
|
59 |
+
unique_docs= retrieve_full_documents(top_10_chunk_results, rag_extarcted_data)
|
60 |
+
reranked_results = rerank_documents(question, unique_docs, reranking_model)
|
61 |
+
return rerank_documents
|
generationhelper.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from groq import Groq
|
3 |
+
|
4 |
+
groq_token = os.getenv("GROQ_TOKEN")
|
5 |
+
|
6 |
+
groq_client = Groq(
|
7 |
+
api_key = groq_token
|
8 |
+
)
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ torch
|
|
4 |
huggingface_hub
|
5 |
pymilvus
|
6 |
nltk
|
7 |
-
sentence-transformers
|
|
|
|
4 |
huggingface_hub
|
5 |
pymilvus
|
6 |
nltk
|
7 |
+
sentence-transformers
|
8 |
+
Groq
|