# app.py (commit 35f96b9, "Update app.py", author: ashwinpatti, 2.29 kB)
import streamlit as st
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from openai.embeddings_utils import get_embedding, cosine_similarity
@st.cache_data
def _load_movie_data():
    """Read the pre-computed movie datastore (plots plus their embeddings).

    ``st.cache_data`` reads the pickle from disk only once and hands every
    caller its own copy, so code that adds columns to ``df`` during a query
    cannot leak into another session's copy.

    NOTE: unpickling can execute arbitrary code -- only ship a trusted
    ``movie_data_embedding.pkl`` with the app.
    """
    return pd.read_pickle('movie_data_embedding.pkl')


@st.cache_resource
def _load_models():
    """Load the bi-encoder and the cross-encoder once and reuse them on reruns.

    Returns:
        A ``(bi_encoder, cross_encoder)`` tuple.
    """
    bi_encoder = SentenceTransformer('all-mpnet-base-v2')
    # bi_encoder.to('cuda')  # uncomment to move the bi-encoder onto a GPU
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')
    return bi_encoder, reranker


# Streamlit re-executes this whole script on every widget interaction; the
# cached loaders above stop that from re-reading the pickle and re-loading
# both transformer models each time.
df = _load_movie_data()
embedder, cross_encoder = _load_models()
def search_bi_encoder(query, top_k=15, *, frame=None, encoder=None):
    """Retrieve candidate movies whose plots are closest to *query*.

    The query is embedded with the bi-encoder and scored against every
    pre-computed plot embedding by cosine similarity.  Rows are ranked by
    that score and de-duplicated by title (the best-scoring row for each
    title wins), so up to ``top_k`` *distinct* titles are returned.

    Args:
        query: Free-text search string.
        top_k: Maximum number of distinct titles to return.
        frame: DataFrame to search; defaults to the module-level ``df``.
            Needs ``plot_embedding``, ``title``, ``year``, ``language``,
            ``cast``, ``plot`` and ``link`` columns.  Each ``plot_embedding``
            entry must flatten to a vector as long as the encoder output.
        encoder: Object with an ``encode(text)`` method; defaults to the
            module-level ``embedder``.

    Returns:
        List of dicts with keys ``name``, ``bi_encoder_score``, ``year``,
        ``language``, ``cast``, ``plot`` and ``link``, best match first.
        Empty when ``top_k <= 0`` or ``frame`` has no rows.
    """
    import numpy as np  # local import: numpy is not among the file's top-level imports

    if frame is None:
        frame = df
    if encoder is None:
        encoder = embedder
    if top_k <= 0 or frame.empty:
        return []

    query_vec = np.asarray(encoder.encode(query), dtype=float).ravel()
    # Stack all plot embeddings into one (n_rows, dim) matrix so every cosine
    # similarity is computed in a single vectorised pass; the width comes from
    # the data instead of a hard-coded 768.
    plot_matrix = np.vstack(
        [np.asarray(vec, dtype=float).ravel() for vec in frame["plot_embedding"]]
    )
    norms = np.linalg.norm(plot_matrix, axis=1) * np.linalg.norm(query_vec)
    with np.errstate(divide="ignore", invalid="ignore"):
        # Zero-length vectors yield NaN, which argsort places last.
        scores = plot_matrix @ query_vec / norms
    # Scores stay local rather than being written into ``frame`` as a column,
    # so a query never mutates the shared DataFrame.

    results = []
    seen_titles = set()
    for pos in np.argsort(-scores, kind="stable"):
        row = frame.iloc[pos]
        title = row["title"]
        if title in seen_titles:
            continue
        seen_titles.add(title)
        results.append(
            {
                "name": title,
                "bi_encoder_score": float(scores[pos]),
                "year": row["year"],
                "language": row["language"],
                "cast": row["cast"],
                "plot": row["plot"],
                "link": row["link"],
            }
        )
        # De-duplicate before truncating so callers get top_k distinct titles.
        if len(results) >= top_k:
            break
    return results
def search_cross_encoder(query, candidates, *, model=None):
    """Re-rank *candidates* by cross-encoder relevance to *query*.

    Each candidate's ``'plot'`` is paired with the query and scored by the
    cross-encoder.  The score is stored on the candidate dict itself under
    ``'cross-score'`` (the input dicts are updated in place), and a new list
    ordered by that score, highest first, is returned.

    Args:
        query: Free-text search string.
        candidates: Sequence of dicts, each with at least a ``'plot'`` key.
        model: Object with a ``predict(pairs)`` method; defaults to the
            module-level ``cross_encoder``.

    Returns:
        New list of the same dict objects, sorted by descending
        ``'cross-score'``.  An empty input yields an empty list and the
        model is not called.
    """
    if not candidates:
        # Nothing to score; also avoids handing the model an empty batch.
        return []
    if model is None:
        model = cross_encoder

    pairs = [[query, candidate['plot']] for candidate in candidates]
    scores = model.predict(pairs)
    for candidate, score in zip(candidates, scores):
        candidate['cross-score'] = score
    return sorted(candidates, key=lambda candidate: candidate['cross-score'], reverse=True)
def search(query, top_k=15):
    """Run the two-stage search: bi-encoder recall, then cross-encoder re-rank.

    Args:
        query: Free-text search string.
        top_k: Passed to ``search_bi_encoder`` to bound the candidate set.

    Returns:
        The candidate dicts from ``search_bi_encoder``, each carrying a
        ``'cross-score'``, ordered from highest to lowest cross-encoder score.
    """
    return search_cross_encoder(query, search_bi_encoder(query, top_k))
# Introductory bullet list rendered under the page title.
_INTRO_MARKDOWN = """
- Search for movie names based on the plot
- The datastore is made up of Hindi, Telugu, Tamil, Kannada, Bengali, Malayalam, Odiya, Marathi, Punjabi & Gujarathi movies released between 1950 and 2023.
- The app understands the context of the query and returns the results from the datastore."""

# Page header and description.
st.title("Semantic Indian Movie Search")
st.markdown(_INTRO_MARKDOWN)

# Free-text query box; its current value is read back on every rerun.
search_query = st.text_input(
    "Please enter your search query here",
    key="text_input",
    value="",
)

st.divider()
st.header("Search details")