Spaces:
Running
Running
from sentence_transformers import CrossEncoder | |
def retrieve_full_documents(top_documents, df):
    """Retrieve the unique full documents referenced by the top-ranked results.

    Args:
        top_documents (list): List of dictionaries, each containing a
            'doc_id' key.
        df (pd.DataFrame): The dataset; must have an 'id' column and a
            'documents' column.

    Returns:
        pd.DataFrame: One row per matching id, with columns 'doc_id' and
            'document'. Rows keep the order (and index) they have in
            ``df``, not the ranking order of ``top_documents``.
    """
    # De-duplicate while keeping first-seen order so the debug output below
    # is deterministic (a plain set would print in arbitrary order).
    unique_doc_ids = list(dict.fromkeys(doc["doc_id"] for doc in top_documents))

    # Print for debugging
    print(f"Extracted Doc IDs: {unique_doc_ids}")

    # Keep only the rows whose 'id' is one of the requested ids, and only
    # the two columns that are returned.
    matches = df[df["id"].isin(unique_doc_ids)][["id", "documents"]]

    # The dataset may hold several rows per id; keep the first of each.
    matches = matches.drop_duplicates(subset="id")

    # Rename columns for clarity
    return matches.rename(columns={"id": "doc_id", "documents": "document"})
def rerank_documents(query, retrieved_docs_df, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
    """Rerank retrieved documents by relevance to the query with a Cross-Encoder.

    Args:
        query (str): The search query.
        retrieved_docs_df (pd.DataFrame): DataFrame with a 'document' column.
            Each cell is either a single string or an iterable of strings
            (joined with spaces before scoring). The input frame is not
            modified.
        model_name (str): Name of the Cross-Encoder model to load.

    Returns:
        pd.DataFrame: A new DataFrame with the input columns plus
            'relevance_score', sorted by score in descending order with a
            fresh 0..n-1 index.
    """
    # Nothing to score: skip the model load and return an empty result that
    # still carries the 'relevance_score' column callers expect.
    if len(retrieved_docs_df) == 0:
        return retrieved_docs_df.assign(relevance_score=[]).reset_index(drop=True)

    # NOTE: the model is loaded on every call.
    model = CrossEncoder(model_name)

    # Prepare query-document pairs. A plain string is used as-is; joining it
    # would insert a space between every character.
    query_doc_pairs = [
        (query, doc if isinstance(doc, str) else " ".join(doc))
        for doc in retrieved_docs_df["document"]
    ]

    # Compute relevance scores
    scores = model.predict(query_doc_pairs)

    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    scored_df = retrieved_docs_df.assign(relevance_score=scores)

    # Sort by score in descending order (higher score = more relevant)
    return scored_df.sort_values(by="relevance_score", ascending=False).reset_index(drop=True)
def FineTuneAndRerankSearchResults(top_10_chunk_results, rag_extarcted_data, question, reranking_model):
    """Look up the full documents for the given results and rerank them.

    Runs two steps: ``retrieve_full_documents`` followed by
    ``rerank_documents``. This function itself only chains those two calls.

    Args:
        top_10_chunk_results: Passed to ``retrieve_full_documents`` as its
            first argument (the ranked results to resolve).
        rag_extarcted_data: Passed to ``retrieve_full_documents`` as its
            second argument (the dataset to look documents up in).
        question: The query handed to ``rerank_documents``.
        reranking_model: Passed to ``rerank_documents`` as its third
            positional argument (the model name).

    Returns:
        Whatever ``rerank_documents`` returns for the retrieved documents.
    """
    return rerank_documents(
        question,
        retrieve_full_documents(top_10_chunk_results, rag_extarcted_data),
        reranking_model,
    )