File size: 2,154 Bytes
6b17b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import streamlit as st
import pandas as pd
import warnings

warnings.filterwarnings("ignore")

from sentence_transformers import SentenceTransformer, CrossEncoder, util
from openai.embeddings_utils import get_embedding, cosine_similarity

# Load the pickled movie table. search_bi_encoder reads its plot_embedding,
# title, year, language, cast, plot and link columns.
# NOTE(review): this is an absolute Google Drive mount path; it only resolves
# where that drive is mounted -- confirm the deployment target.
# NOTE(review): unpickling can execute arbitrary code; only load a trusted file.
df = pd.read_pickle('/content/drive/MyDrive/apatti_movie_search/movie_data_embedding.pkl')


# Bi-encoder used by search_bi_encoder to embed the query text.
embedder = SentenceTransformer('all-mpnet-base-v2')
#embedder.to('cuda')
# Cross-encoder used by search_cross_encoder to score (query, plot) pairs.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')

@st.experimental_memo(suppress_st_warning=True)
def search_bi_encoder(query,top_k=15):
  """Return the movies whose plot embedding is most similar to ``query``.

  The query is embedded with the module-level ``embedder`` and compared, by
  cosine similarity, against every entry of ``df.plot_embedding``. The
  ``top_k`` best-scoring rows are kept and then de-duplicated by title
  (first, i.e. best-scoring, occurrence wins), so fewer than ``top_k``
  results may be returned.

  Args:
    query: Free-text search string.
    top_k: Number of best-scoring rows to keep before de-duplication.

  Returns:
    A list of dicts ordered by descending similarity, each with the keys
    ``name``, ``bi_encoder_score``, ``year``, ``language``, ``cast``,
    ``plot`` and ``link``.
  """
  # Column vector sized from the embedding itself instead of a hard-coded 768,
  # so a different embedder model does not break the reshape.
  query_vector = embedder.encode(query).reshape(-1, 1)

  # NOTE(review): assumes each stored plot embedding is a flat vector, so the
  # similarity comes back as a one-element array -- hence the [0]. Confirm
  # against the code that produced the pickle.
  scores = df.plot_embedding.apply(
      lambda embedding: cosine_similarity(embedding, query_vector)[0])

  # assign() works on a copy: this function is memoized, so it must not leave
  # a per-query column behind on the shared module-level df.
  results = (
      df.assign(bi_similarity=scores)
      .sort_values("bi_similarity", ascending=False)
      .head(top_k))

  # (output key, DataFrame column) pairs, in output order.
  fields = (
      ("name", "title"),
      ("bi_encoder_score", "bi_similarity"),
      ("year", "year"),
      ("language", "language"),
      ("cast", "cast"),
      ("plot", "plot"),
      ("link", "link"),
  )

  resultlist = []
  seen_titles = set()  # O(1) membership test for de-duplication
  # Positional access (.iat) stays correct even if index labels repeat.
  for pos in range(len(results)):
      title = results["title"].iat[pos]
      if title in seen_titles:
          continue
      seen_titles.add(title)
      resultlist.append({key: results[column].iat[pos] for key, column in fields})
  return resultlist


@st.experimental_memo(suppress_st_warning=True)
def search_cross_encoder(query,candidates):
  """Re-rank ``candidates`` by cross-encoder relevance to ``query``.

  Each candidate's ``plot`` text is paired with the query and scored by the
  module-level ``cross_encoder``. The score is stored on the candidate dict
  itself under ``'cross-score'`` (the input dicts are modified in place).

  Args:
    query: Free-text search string.
    candidates: List of dicts, each containing at least a ``'plot'`` key.

  Returns:
    A new list holding the same dicts, sorted by ``'cross-score'`` in
    descending order. An empty input yields an empty list.
  """
  # Nothing to score: skip the model call rather than hand it an empty batch.
  if not candidates:
    return []

  pairs = [[query, candidate['plot']] for candidate in candidates]
  for candidate, score in zip(candidates, cross_encoder.predict(pairs)):
    candidate['cross-score'] = score

  return sorted(candidates, key=lambda item: item['cross-score'], reverse=True)


@st.experimental_memo(suppress_st_warning=True)
def search(query,top_k=15):
  """Run the two-stage search: bi-encoder retrieval, then cross-encoder re-ranking.

  Args:
    query: Free-text search string.
    top_k: Forwarded to ``search_bi_encoder`` as its ``top_k`` argument.

  Returns:
    The value returned by ``search_cross_encoder`` for the retrieved candidates.
  """
  return search_cross_encoder(query, search_bi_encoder(query, top_k))




# Slider widget; its value is bound to x, which nothing in the visible code reads.
# NOTE(review): looks like leftover demo code -- confirm before removing.
x = st.slider('Select a value')
#st.subheader(f"Search Query: {query}")
# Free-text query box with a pre-filled default.
# NOTE(review): the default text asks about inflation, which seems unrelated to
# the movie data loaded above -- confirm the intended default.
# NOTE(review): search_query is not passed to search() anywhere in the visible
# code; the call may live further down the file or still be missing.
search_query = st.text_input("Please Enter your search query here",value="What are the expectations for inflation for Australia?",key="text_input")