qsaheeb committed on
Commit
594e600
·
1 Parent(s): b667f07

Add some changes 5

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -12,9 +12,25 @@ with open("model/sbert_embeddings2.pkl", "rb") as f:
12
  book_embeddings = pickle.load(f)
13
 
14
  # Load models
 
15
  reranker_model = CrossEncoder("cross-encoder/stsb-roberta-large") # More accurate ranking
16
 
17
- recommender = BookRecommender()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def rerank_books(query_title, candidates):
20
  """Re-rank books using a cross-encoder"""
@@ -28,15 +44,15 @@ def rerank_books(query_title, candidates):
28
 
29
  def recommend_books(book_title):
30
  """Complete recommendation pipeline"""
31
- candidates = recommender.recommend(book_title, top_n=5)
32
  if isinstance(candidates, list) and "Error" in candidates[0]:
33
  return candidates[0]
34
 
35
- return candidates
36
 
37
  # Gradio Interface
38
  with gr.Blocks() as demo:
39
- gr.Markdown("# 📚 Content-Based Book Recommendation")
40
  gr.Markdown("Enter a book title to find similar books based on summaries.")
41
 
42
  with gr.Row():
 
12
  book_embeddings = pickle.load(f)
13
 
14
  # Load models
15
+ retriever_model = SentenceTransformer("all-mpnet-base-v2") # More accurate than MiniLM
16
  reranker_model = CrossEncoder("cross-encoder/stsb-roberta-large") # More accurate ranking
17
 
18
def retrieve_candidates(book_title, top_n=10):
    """Retrieve the top-N books most similar to ``book_title``.

    Similarity is the cosine similarity between the query book's
    precomputed embedding and every row of ``book_embeddings``.

    Args:
        book_title: Exact title to look up in ``df["title"]``.
        top_n: Maximum number of similar books to return.

    Returns:
        A list of ``[title, summary]`` pairs ordered from most to least
        similar, or the one-element list
        ``["Error: Book title not found in dataset!"]`` when the title is
        not present.
    """
    # Use the same column for the existence check, the lookup and the
    # returned data; mixing "title" and "book_name" made the lookup fail
    # (or hit a different row) for titles that passed the check.
    title_matches = (df["title"] == book_title).to_numpy()
    if not title_matches.any():
        return ["Error: Book title not found in dataset!"]

    # Positional row of the first match (argmax of a boolean mask). A
    # position, not an index label, is needed because it is used with
    # ``book_embeddings[...]`` and ``df.iloc`` below.
    # NOTE(review): assumes book_embeddings is row-aligned with df - confirm.
    book_idx = int(title_matches.argmax())

    # Compute cosine similarity of the query against every book.
    query_embedding = book_embeddings[book_idx]
    scores = util.cos_sim(query_embedding, book_embeddings)[0]

    # Rank all books, then drop the query book explicitly instead of
    # assuming it is ranked first: with tied or duplicate embeddings it
    # may not be, and slicing [1:] would then discard a real candidate.
    ranked = torch.argsort(scores, descending=True)
    ranked = ranked[ranked != book_idx][:top_n]

    # Convert to plain ints so pandas receives ordinary positional indices.
    return df.iloc[ranked.tolist()][["title", "summary"]].values.tolist()
34
 
35
  def rerank_books(query_title, candidates):
36
  """Re-rank books using a cross-encoder"""
 
44
 
45
def recommend_books(book_title):
    """Run the complete recommendation pipeline for ``book_title``.

    Retrieves up to 10 candidates with ``retrieve_candidates`` and passes
    them to ``rerank_books`` for re-ranking.

    Args:
        book_title: Title of the book to find recommendations for.

    Returns:
        The error message string when retrieval reports a failure,
        otherwise whatever ``rerank_books`` returns for the candidates.
    """
    candidates = retrieve_candidates(book_title, top_n=10)

    # Retrieval signals failure with a list whose first element is an
    # "Error..." string; real candidates are [title, summary] pairs.
    # Checking the element type avoids a list-membership test against a
    # candidate pair, and the emptiness check avoids an IndexError when
    # no candidates were found.
    if isinstance(candidates, list) and candidates:
        first = candidates[0]
        if isinstance(first, str) and "Error" in first:
            return first

    return rerank_books(book_title, candidates)
52
 
53
  # Gradio Interface
54
  with gr.Blocks() as demo:
55
+ gr.Markdown("# Content-Based Book Recommendation")
56
  gr.Markdown("Enter a book title to find similar books based on summaries.")
57
 
58
  with gr.Row():