# Book recommendation app — author: qsaheeb, commit 54a9816 ("Final changes"), 4.58 kB.
import gradio as gr
import pandas as pd
import pickle
import torch
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from duckduckgo_search import DDGS
from fuzzywuzzy import process
# ---- One-time data and model loading (executed at import) ----

# Book catalogue; later code reads its "book_name" and "summaries" columns.
df = pd.read_csv("data/books_summary_cleaned.csv")

# Precomputed sentence embeddings for the catalogue. Later code indexes this by
# dataset row and compares it against retriever_model outputs, so it is assumed
# to be row-aligned with the CSV above and produced by the same model —
# NOTE(review): confirm both when regenerating either file.
# pickle.load can execute arbitrary code: only load a trusted, shipped artifact.
with open("model/sbert_embeddings2.pkl", "rb") as embeddings_file:
    book_embeddings = pickle.load(embeddings_file)

# Bi-encoder used to embed web-fetched summaries for cosine-similarity search.
retriever_model = SentenceTransformer("all-mpnet-base-v2")
# Cross-encoder that scores (query summary, candidate summary) pairs for re-ranking.
reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
def fetch_summary_duckduckgo(book_title, log, max_results=3, skip_results=2):
    """Fetch a book summary snippet from a DuckDuckGo text search.

    Args:
        book_title: Title to search for; the query sent is
            "<book_title> book summary".
        log: List of progress messages; entries are appended in place.
        max_results: Number of search results to request.
        skip_results: Number of leading results to ignore before looking for a
            snippet. The default of 2 keeps the original behaviour of only
            considering results after the first two.

    Returns:
        ``(summary, log)`` where ``summary`` is the ``"body"`` text of the first
        eligible result, or ``None`` if no eligible result has a body or the
        search itself fails.
    """
    log.append(f"Searching the internet for '{book_title}' summary...")
    try:
        with DDGS() as ddgs:
            # `or []` guards against library versions that return None for
            # "no results" instead of an empty list.
            search_results = list(
                ddgs.text(f"{book_title} book summary", max_results=max_results) or []
            )
    except Exception as exc:  # network / rate-limit failures must not crash the UI request
        log.append(f"Web search failed: {exc}")
        return None, log

    # Use the first result after the skipped prefix that carries a snippet.
    for result in search_results[skip_results:]:
        if "body" in result:
            log.append("Summary found from the web.")
            return result["body"], log
    log.append("No summary found on the web.")
    return None, log
def get_best_match(book_title, book_list, log, threshold=90):
    """Fuzzy-correct a user-entered title against the dataset titles.

    Args:
        book_title: Title as typed by the user.
        book_list: Candidate titles to match against.
        log: List of progress messages; exactly one entry is appended.
        threshold: Score that the best match must strictly exceed (fuzzywuzzy
            scores run 0-100) before the input is replaced by it.

    Returns:
        ``(title, log)`` where ``title`` is the closest candidate when its score
        exceeds ``threshold`` and it differs from the input, otherwise the
        input unchanged.
    """
    # extractOne returns None when book_list is empty; otherwise the first two
    # items of its result are the matched choice and its score.
    match = process.extractOne(book_title, book_list)
    if match is not None:
        best_match, score = match[0], match[1]
        if score > threshold and best_match != book_title:
            log.append(f"Typo detected! Corrected '{book_title}' to '{best_match}'.")
            return best_match, log
    log.append(f"No correction needed for '{book_title}'.")
    return book_title, log
def retrieve_candidates(book_title, top_n=10):
    """Retrieve the top-N dataset books most similar to the query book.

    The query title is first fuzzy-corrected against the dataset titles. If it
    matches a dataset row, that row's precomputed embedding and summary are
    used; otherwise a summary is fetched from the web and embedded with
    ``retriever_model``.

    Args:
        book_title: Title entered by the user.
        top_n: Maximum number of candidate books to return.

    Returns:
        ``(title, summary, candidates, log)``:
        - ``title`` is the (possibly corrected) query title.
        - ``summary`` is its summary text.
        - ``candidates`` is a list of ``[book_name, summaries]`` rows from
          ``df``, ordered by descending cosine similarity. The query book's own
          rows are never included.
        - ``log`` holds the progress messages.
        If no summary can be obtained, the first three items are ``None``.
    """
    log = ["Starting book recommendation process..."]
    titles = df["book_name"]
    book_title, log = get_best_match(book_title, titles.values.tolist(), log)

    # Positional row numbers of every dataset entry with exactly this title.
    # Positions (not index labels) are used because book_embeddings and the
    # df.iloc lookup below are both addressed by row position.
    query_rows = [pos for pos, name in enumerate(titles) if name == book_title]

    if query_rows:
        first_row = query_rows[0]
        query_embedding = book_embeddings[first_row]
        summary = df["summaries"].iloc[first_row]
        log.append(f"Book '{book_title}' found in the dataset.")
    else:
        log.append(f"Book '{book_title}' not found in the dataset.")
        summary, log = fetch_summary_duckduckgo(book_title, log)
        if summary is None:
            log.append("No summary found. Cannot proceed with recommendation.")
            return None, None, None, log
        query_embedding = retriever_model.encode(summary, convert_to_tensor=True)

    scores = util.cos_sim(query_embedding, book_embeddings)[0]
    ranked_rows = torch.argsort(scores, descending=True).tolist()
    # Exclude the query book itself, including duplicate rows with the same title,
    # rather than blindly dropping the first hit. For a web-fetched query the
    # first hit is a genuine recommendation, not the query.
    excluded = set(query_rows)
    top_indices = [row for row in ranked_rows if row not in excluded][:top_n]
    log.append(f"Top {len(top_indices)} similar books retrieved from the dataset.")
    return book_title, summary, df.iloc[top_indices][["book_name", "summaries"]].values.tolist(), log
def rerank_books(query_title, query_summary, candidates, log, top_k=5):
    """Re-rank candidate books with the cross-encoder and keep the best ones.

    Args:
        query_title: Title of the query book. It is not used for scoring (only
            summaries are compared) and is kept for interface compatibility.
        query_summary: Summary text of the query book.
        candidates: Sequence of ``(book_name, summary)`` pairs to score.
        log: List of progress messages; one entry is appended.
        top_k: Number of top-scoring titles to return.

    Returns:
        ``(titles, log)`` where ``titles`` lists up to ``top_k`` candidate
        names, ordered by descending cross-encoder score.
    """
    # CrossEncoder.predict cannot handle an empty batch, so short-circuit.
    if not candidates:
        log.append("No candidate books to re-rank.")
        return [], log

    pairs = [(query_summary, cand_summary) for _, cand_summary in candidates]
    scores = reranker_model.predict(pairs)
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    log.append(f"Books re-ranked based on cross-encoder model and returning top {top_k} books")
    return [candidate[0] for candidate, _ in ranked[:top_k]], log
def recommend_books(book_title):
    """Run the full recommendation pipeline for a user-entered title.

    Retrieves similar-book candidates by embedding similarity, re-ranks them
    with the cross-encoder, and returns display-ready strings.

    Args:
        book_title: Raw text from the input box. Surrounding whitespace is
            ignored, and ``None`` is treated like an empty string.

    Returns:
        ``(recommendations, log_text)``: a comma-separated string of
        recommended titles (or a short status message) and the pipeline log
        joined with newlines.
    """
    # Strip stray whitespace so the fuzzy matcher does not report it as a
    # "typo", and skip the pipeline (and its web search) for an empty query.
    book_title = (book_title or "").strip()
    if not book_title:
        return "Please enter a book title.", "No book title entered. Exiting recommendation process."

    book_title, summary, candidates, log = retrieve_candidates(book_title, top_n=10)
    if book_title is None:
        log.append("Book not found. Exiting recommendation process.")
        return "Book not found", "\n".join(log)

    recommendations, log = rerank_books(book_title, summary, candidates, log)
    log.append("Recommendation process complete.")
    return ", ".join(recommendations), "\n".join(log)
# Gradio Interface
# Builds the web UI: a title textbox plus a button wired to recommend_books,
# whose two return values (recommendation string, newline-joined log) fill the
# two read-only textboxes below.
# NOTE(review): indentation was lost in this copy of the file; the Row is
# assumed to hold only the input and the button — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Book Recommendation System")
    gr.Markdown("Enter a book title to find similar books based on summaries.")
    # Input field and trigger button placed in one horizontal row.
    with gr.Row():
        book_input = gr.Textbox(label="Enter Book Title")
        submit_btn = gr.Button("Recommend")
    # Read-only outputs, filled in order from recommend_books' returned tuple.
    output = gr.Textbox(label="Recommended Books", interactive=False)
    log_output = gr.Textbox(label="Logs", interactive=False, lines=10)  # Log display
    # Clicking the button runs the whole pipeline on the textbox value.
    submit_btn.click(recommend_books, inputs=book_input, outputs=[output, log_output])
# Run the app
# Launch the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()