ashwinpatti commited on
Commit
6b17b03
·
1 Parent(s): 46f9c2e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import warnings
4
+
5
+ warnings.filterwarnings("ignore")
6
+
7
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
8
+ from openai.embeddings_utils import get_embedding, cosine_similarity
9
+
10
+ df = pd.read_pickle('/content/drive/MyDrive/apatti_movie_search/movie_data_embedding.pkl')
11
+
12
+
13
+ embedder = SentenceTransformer('all-mpnet-base-v2')
14
+ #embedder.to('cuda')
15
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')
16
+
17
+ @st.experimental_memo(suppress_st_warning=True)
18
+ def search_bi_encoder(query,top_k=15):
19
+ query_embedding = embedder.encode(query)
20
+ df["bi_similarity"] = df.plot_embedding.apply(lambda x: cosine_similarity(x, query_embedding.reshape(768,-1)))
21
+
22
+ results = (
23
+ df.sort_values("bi_similarity", ascending=False)
24
+ .head(top_k))
25
+
26
+ resultlist = []
27
+
28
+ hlist = []
29
+ for r in results.index:
30
+ if results.title[r] not in hlist:
31
+ resultlist.append(
32
+ {
33
+ "name":results.title[r],
34
+ "bi_encoder_score": results.bi_similarity[r][0],
35
+ "year": results.year[r],
36
+ "language": results.language[r],
37
+ "cast":results.cast[r],
38
+ "plot":results['plot'][r],
39
+ "link":results.link[r]
40
+ })
41
+ hlist.append(results.title[r])
42
+ return resultlist
43
+
44
+
45
+ @st.experimental_memo(suppress_st_warning=True)
46
+ def search_cross_encoder(query,candidates):
47
+ cross_inp = [[query, candidate['plot']] for candidate in candidates]
48
+ cross_scores = cross_encoder.predict(cross_inp)
49
+ for idx in range(len(cross_scores)):
50
+ candidates[idx]['cross-score'] = cross_scores[idx]
51
+
52
+ sortedResult = sorted(candidates, key=lambda x: x['cross-score'], reverse=True)
53
+ return sortedResult
54
+
55
+
56
+ @st.experimental_memo(suppress_st_warning=True)
57
+ def search(query,top_k=15):
58
+ candidates = search_bi_encoder(query,top_k)
59
+ rankedResult = search_cross_encoder(query,candidates)
60
+
61
+ return rankedResult
62
+
63
+
64
+
65
+
66
+ x = st.slider('Select a value')
67
+ #st.subheader(f"Search Query: {query}")
68
+ search_query = st.text_input("Please Enter your search query here",value="What are the expectations for inflation for Australia?",key="text_input")
69
+