# app.py -- Hugging Face file view: "Create app.py" by ashwinpatti (commit 6b17b03, 2.15 kB)
import streamlit as st
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from openai.embeddings_utils import get_embedding, cosine_similarity
# Module-level resources used by the search functions in this file.
_MOVIE_DATA_PATH = '/content/drive/MyDrive/apatti_movie_search/movie_data_embedding.pkl'
_BI_ENCODER_NAME = 'all-mpnet-base-v2'
_CROSS_ENCODER_NAME = 'cross-encoder/ms-marco-MiniLM-L-2-v2'

# Pickled DataFrame; the search code reads its ``plot_embedding`` column.
# NOTE(review): unpickling executes arbitrary code -- only load a trusted file.
df = pd.read_pickle(_MOVIE_DATA_PATH)

# Bi-encoder used to embed the search query.
embedder = SentenceTransformer(_BI_ENCODER_NAME)
# GPU placement is left disabled: embedder.to('cuda')

# Cross-encoder used to score (query, plot) pairs when re-ranking.
cross_encoder = CrossEncoder(_CROSS_ENCODER_NAME)
@st.experimental_memo(suppress_st_warning=True)
def search_bi_encoder(query,top_k=15):
    """Retrieve the movies whose plot embedding is most similar to ``query``.

    The query is embedded with the module-level ``embedder`` and compared,
    via ``cosine_similarity``, against every entry of ``df.plot_embedding``.
    The ``top_k`` best-scoring rows are kept and then de-duplicated by title
    (the highest-ranked occurrence wins), so fewer than ``top_k`` entries may
    be returned.

    Side effect: the per-row scores are stored in ``df["bi_similarity"]``.

    Args:
        query: Free-text search string.
        top_k: Number of best-scoring rows considered before de-duplication.

    Returns:
        A list of dicts in descending score order, each with the keys
        ``name``, ``bi_encoder_score``, ``year``, ``language``, ``cast``,
        ``plot`` and ``link``.
    """
    # Shape the query as a column vector of whatever width the embedder
    # produces (rather than assuming 768), and do it once instead of per row.
    query_column = embedder.encode(query).reshape(-1, 1)
    df["bi_similarity"] = df.plot_embedding.apply(
        lambda plot_vector: cosine_similarity(plot_vector, query_column)
    )

    results = (
        df.sort_values("bi_similarity", ascending=False)
        .head(top_k)
        # Fresh positional labels keep the scalar lookups below unambiguous
        # even if the stored frame carries repeated index labels.
        .reset_index(drop=True)
    )

    ranked = []
    seen_titles = set()  # titles are assumed hashable (e.g. strings)
    for row in results.index:
        title = results.at[row, "title"]
        if title in seen_titles:
            continue
        seen_titles.add(title)
        ranked.append(
            {
                "name": title,
                # The stored score is a one-element container; unwrap it.
                "bi_encoder_score": results.at[row, "bi_similarity"][0],
                "year": results.at[row, "year"],
                "language": results.at[row, "language"],
                "cast": results.at[row, "cast"],
                "plot": results.at[row, "plot"],
                "link": results.at[row, "link"],
            }
        )
    return ranked
@st.experimental_memo(suppress_st_warning=True)
def search_cross_encoder(query,candidates):
    """Re-rank ``candidates`` by their cross-encoder score against ``query``.

    Each candidate's ``'plot'`` text is paired with the query and scored by
    the module-level ``cross_encoder``. The score is written into the
    candidate dict under ``'cross-score'``, so the dicts passed in are
    updated in place.

    Args:
        query: Free-text search string.
        candidates: List of dicts, each providing a ``'plot'`` entry.

    Returns:
        A new list holding the same dicts, highest ``'cross-score'`` first.
        An empty list when there are no candidates.
    """
    if not candidates:
        # Nothing to score; the cross-encoder is not guaranteed to accept an
        # empty batch, so skip the model call entirely.
        return []

    pairs = [[query, candidate['plot']] for candidate in candidates]
    scores = cross_encoder.predict(pairs)
    for candidate, score in zip(candidates, scores):
        candidate['cross-score'] = score
    return sorted(candidates, key=lambda item: item['cross-score'], reverse=True)
@st.experimental_memo(suppress_st_warning=True)
def search(query,top_k=15):
    """Run the two-stage search for ``query``.

    ``search_bi_encoder`` produces the candidate list for ``query`` and
    ``top_k``; ``search_cross_encoder`` then re-ranks those candidates.

    Returns:
        The re-ranked candidate list from ``search_cross_encoder``.
    """
    return search_cross_encoder(query, search_bi_encoder(query, top_k))
# Page widgets, created at module level.
x = st.slider('Select a value')
#st.subheader(f"Search Query: {query}")

# Free-text query box with a pre-filled example.
# NOTE(review): the example text is about inflation, not movies -- confirm
# it is the intended default for this app.
search_query = st.text_input(
    "Please Enter your search query here",
    value="What are the expectations for inflation for Australia?",
    key="text_input",
)