qsaheeb
committed on
Commit
·
594e600
1
Parent(s):
b667f07
Add some changes 5
Browse files
app.py
CHANGED
@@ -12,9 +12,25 @@ with open("model/sbert_embeddings2.pkl", "rb") as f:
|
|
12 |
book_embeddings = pickle.load(f)
|
13 |
|
14 |
# Load models
|
|
|
15 |
reranker_model = CrossEncoder("cross-encoder/stsb-roberta-large") # More accurate ranking
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def rerank_books(query_title, candidates):
|
20 |
"""Re-rank books using a cross-encoder"""
|
@@ -28,15 +44,15 @@ def rerank_books(query_title, candidates):
|
|
28 |
|
29 |
def recommend_books(book_title):
|
30 |
"""Complete recommendation pipeline"""
|
31 |
-
candidates =
|
32 |
if isinstance(candidates, list) and "Error" in candidates[0]:
|
33 |
return candidates[0]
|
34 |
|
35 |
-
return candidates
|
36 |
|
37 |
# Gradio Interface
|
38 |
with gr.Blocks() as demo:
|
39 |
-
gr.Markdown("#
|
40 |
gr.Markdown("Enter a book title to find similar books based on summaries.")
|
41 |
|
42 |
with gr.Row():
|
|
|
12 |
book_embeddings = pickle.load(f)
|
13 |
|
14 |
# Load models
|
15 |
+
retriever_model = SentenceTransformer("all-mpnet-base-v2") # More accurate than MiniLM
|
16 |
reranker_model = CrossEncoder("cross-encoder/stsb-roberta-large") # More accurate ranking
|
17 |
|
18 |
+
def retrieve_candidates(book_title, top_n=10):
    """Retrieve the top-N books most similar to ``book_title``.

    Similarity is the cosine similarity between the stored embedding of the
    query book and every entry of ``book_embeddings``.

    Args:
        book_title: Title to look up in the ``"title"`` column of ``df``.
        top_n: Maximum number of similar books to return.

    Returns:
        A list of ``[title, summary]`` pairs ordered from most to least
        similar, or a one-element list holding an error message when the
        title is not present in the dataset.
    """
    # Look the title up in the same column that is validated and returned.
    # The previous code checked "title" but then indexed via "book_name".
    matches = (df["title"] == book_title).to_numpy().nonzero()[0]
    if len(matches) == 0:
        return ["Error: Book title not found in dataset!"]

    # Use the positional row number rather than the index label, because it
    # is used to index ``book_embeddings`` and is compared with argsort output.
    # NOTE(review): assumes book_embeddings is row-aligned with df - confirm.
    book_idx = int(matches[0])

    # Compute cosine similarity of the query against every book.
    query_embedding = book_embeddings[book_idx]
    scores = util.cos_sim(query_embedding, book_embeddings)[0]

    # Take one extra hit so that dropping the query book still leaves top_n.
    # The query is removed explicitly instead of assuming it ranks first,
    # which does not hold when another book has an identical embedding.
    ranked = torch.argsort(scores, descending=True)[: top_n + 1].tolist()
    top_indices = [idx for idx in ranked if idx != book_idx][:top_n]

    return df.iloc[top_indices][["title", "summary"]].values.tolist()
|
34 |
|
35 |
def rerank_books(query_title, candidates):
|
36 |
"""Re-rank books using a cross-encoder"""
|
|
|
44 |
|
45 |
def recommend_books(book_title):
    """Run the complete recommendation pipeline for ``book_title``.

    Candidates are fetched with ``retrieve_candidates`` and then passed to
    ``rerank_books``.

    Args:
        book_title: Title of the book to find recommendations for.

    Returns:
        The error message string when retrieval reports an error, otherwise
        whatever ``rerank_books`` returns for the retrieved candidates.
    """
    candidates = retrieve_candidates(book_title, top_n=10)

    # Retrieval signals failure with a one-element list holding a message
    # string. Only a string first element is treated as an error: real
    # candidates are [title, summary] pairs, and a bare ``in`` test on such a
    # pair would match a title or summary that is exactly "Error".
    # The emptiness check avoids an IndexError on an empty candidate list,
    # which is forwarded to rerank_books unchanged.
    if isinstance(candidates, list) and candidates:
        first = candidates[0]
        if isinstance(first, str) and "Error" in first:
            return first

    return rerank_books(book_title, candidates)
|
52 |
|
53 |
# Gradio Interface
|
54 |
with gr.Blocks() as demo:
|
55 |
+
gr.Markdown("# Content-Based Book Recommendation")
|
56 |
gr.Markdown("Enter a book title to find similar books based on summaries.")
|
57 |
|
58 |
with gr.Row():
|