Saiteja Solleti commited on
Commit
a46269a
·
1 Parent(s): 748ac82

Fine-tuning and reranking are pushed

Browse files
Files changed (4) hide show
  1. app.py +7 -2
  2. finetuneresults.py +61 -0
  3. generationhelper.py +8 -0
  4. requirements.txt +2 -1
app.py CHANGED
@@ -6,6 +6,7 @@ from createmilvusschema import CreateMilvusDbSchema
6
  from insertmilvushelper import EmbedAllDocumentsAndInsert
7
  from sentence_transformers import SentenceTransformer
8
  from searchmilvushelper import SearchTopKDocuments
 
9
 
10
  from model import generate_response
11
  from huggingface_hub import login
@@ -15,6 +16,7 @@ from huggingface_hub import dataset_info
15
 
16
  # Load embedding model
17
  QUERY_EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
 
18
  WINDOW_SIZE = 5
19
  OVERLAP = 2
20
  RETRIVE_TOP_K_SIZE=10
@@ -38,8 +40,11 @@ EmbedAllDocumentsAndInsert(QUERY_EMBEDDING_MODEL, rag_extracted_data, db_collect
38
  """
39
  query = "what would the net revenue have been in 2015 if there wasn't a stipulated settlement from the business combination in october 2015?"
40
 
41
- results_for_top5_chunks = SearchTopKDocuments(db_collection, query, QUERY_EMBEDDING_MODEL, top_k=RETRIVE_TOP_K_SIZE)
42
- print(results_for_top5_chunks)
 
 
 
43
 
44
 
45
  def chatbot(prompt):
 
6
  from insertmilvushelper import EmbedAllDocumentsAndInsert
7
  from sentence_transformers import SentenceTransformer
8
  from searchmilvushelper import SearchTopKDocuments
9
+ from finetuneresults import FineTuneAndRerankSearchResults
10
 
11
  from model import generate_response
12
  from huggingface_hub import login
 
16
 
17
  # Load embedding model
18
  QUERY_EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
19
+ RERANKING_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
20
  WINDOW_SIZE = 5
21
  OVERLAP = 2
22
  RETRIVE_TOP_K_SIZE=10
 
40
  """
41
  query = "what would the net revenue have been in 2015 if there wasn't a stipulated settlement from the business combination in october 2015?"
42
 
43
+ results_for_top10_chunks = SearchTopKDocuments(db_collection, query, QUERY_EMBEDDING_MODEL, top_k=RETRIVE_TOP_K_SIZE)
44
+
45
+ reranked_results = FineTuneAndRerankSearchResults(results_for_top10_chunks, rag_extracted_data, query, RERANKING_MODEL)
46
+
47
+ print(reranked_results)
48
 
49
 
50
  def chatbot(prompt):
finetuneresults.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import CrossEncoder
2
+
3
+ """
4
+ Retrieves unique full documents based on the top-ranked document IDs.
5
+
6
+ Args:
7
+ top_documents (list): List of dictionaries containing 'doc_id'.
8
+ df (pd.DataFrame): The dataset containing document IDs and text.
9
+
10
+ Returns:
11
+ pd.DataFrame: A DataFrame with 'doc_id' and 'document'.
12
+ """
13
+ def retrieve_full_documents(top_documents, df):
14
+
15
+ # Extract unique doc_ids
16
+ unique_doc_ids = list(set(doc["doc_id"] for doc in top_documents))
17
+
18
+ # Print for debugging
19
+ print(f"Extracted Doc IDs: {unique_doc_ids}")
20
+
21
+ # Filter DataFrame where 'id' matches any of the unique_doc_ids
22
+ filtered_df = df[df["id"].isin(unique_doc_ids)][["id", "documents"]].drop_duplicates(subset="id")
23
+
24
+ # Rename columns for clarity
25
+ filtered_df = filtered_df.rename(columns={"id": "doc_id", "documents": "document"})
26
+
27
+ return filtered_df
28
+
29
+ """
30
+ Reranks the retrieved documents based on their relevance to the query using a Cross-Encoder model.
31
+ Args:
32
+ query (str): The search query.
33
+ retrieved_docs (pd.DataFrame): DataFrame with 'doc_id' and 'document'.
34
+ model_name (str): Name of the Cross-Encoder model.
35
+ Returns:
36
+ pd.DataFrame: A sorted DataFrame with doc_id, document, and reranking score.
37
+ """
38
+
39
+ def rerank_documents(query, retrieved_docs_df, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
40
+
41
+ # Load Cross-Encoder model
42
+ model = CrossEncoder(model_name)
43
+
44
+ # Prepare query-document pairs
45
+ query_doc_pairs = [(query, " ".join(doc)) for doc in retrieved_docs_df["document"]]
46
+
47
+ # Compute relevance scores
48
+ scores = model.predict(query_doc_pairs)
49
+
50
+ # Add scores to the DataFrame
51
+ retrieved_docs_df["relevance_score"] = scores
52
+
53
+ # Sort by score in descending order (higher score = more relevant)
54
+ reranked_docs_df = retrieved_docs_df.sort_values(by="relevance_score", ascending=False).reset_index(drop=True)
55
+
56
+ return reranked_docs_df
57
+
58
def FineTuneAndRerankSearchResults(top_10_chunk_results, rag_extarcted_data, question, reranking_model):
    """Expand top chunk hits to their full documents and rerank them for the question.

    Args:
        top_10_chunk_results (list): Chunk search hits, each a dict with a 'doc_id'.
        rag_extarcted_data (pd.DataFrame): Dataset with 'id' and 'documents' columns.
        question (str): The user query to rerank against.
        reranking_model (str): Cross-Encoder model name passed to rerank_documents.

    Returns:
        pd.DataFrame: Unique documents sorted by descending relevance score.
    """
    unique_docs = retrieve_full_documents(top_10_chunk_results, rag_extarcted_data)
    reranked_results = rerank_documents(question, unique_docs, reranking_model)
    # BUG FIX: the original returned the `rerank_documents` function object
    # instead of the computed DataFrame, so callers printed/used a function.
    return reranked_results
generationhelper.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import os

from groq import Groq

# Groq API token is supplied via the environment; None if unset.
groq_token = os.getenv("GROQ_TOKEN")

# Module-level Groq client shared by this module's consumers.
groq_client = Groq(api_key=groq_token)
requirements.txt CHANGED
@@ -4,4 +4,5 @@ torch
4
  huggingface_hub
5
  pymilvus
6
  nltk
7
- sentence-transformers
 
 
4
  huggingface_hub
5
  pymilvus
6
  nltk
7
+ sentence-transformers
8
+ Groq