Spaces:
Running
Running
File size: 6,139 Bytes
6b17b03 38434f1 6b17b03 7f94736 6b17b03 7a39fe4 6b17b03 7a39fe4 6b17b03 7a39fe4 6b17b03 7a39fe4 704a86e 7a39fe4 6b17b03 704a86e 6b17b03 ea38858 69c7fed 1c5ff8b 2f85883 80f625a 69c7fed 51448c5 69c7fed 92c65e0 69c7fed 80f625a f03d1da ea38858 f03d1da ea38858 51448c5 e172747 29f34a8 6213b0d 29f34a8 e172747 29f34a8 e172747 29f34a8 6213b0d 29f34a8 1cd386f 29f34a8 87f318e 29f34a8 6213b0d 29f34a8 6213b0d 6b17b03 97746f8 81ab107 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import streamlit as st
import pandas as pd
import warnings
from PIL import Image
warnings.filterwarnings("ignore")
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from openai.embeddings_utils import get_embedding, cosine_similarity
# Corpus of movie records; the search code reads its ``plot_embedding`` column.
# NOTE(review): unpickling can execute arbitrary code, so this file must only
# ever be one that ships with the app.
# The frame is deliberately left uncached: objects stored by
# st.cache_resource are shared by every session, which is only safe for
# objects that no code mutates.
df = pd.read_pickle('movie_data_embedding.pkl')


@st.cache_resource
def _load_bi_encoder():
    """Build the retrieval (bi-encoder) model; cached so it is built once."""
    return SentenceTransformer('all-mpnet-base-v2')


@st.cache_resource
def _load_cross_encoder():
    """Build the re-ranking (cross-encoder) model; cached so it is built once."""
    return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')


# Streamlit re-executes this script on every widget interaction. Going through
# the cached loaders avoids re-instantiating both transformer models each time.
embedder = _load_bi_encoder()
#embedder.to('cuda')
cross_encoder = _load_cross_encoder()
def search_bi_encoder(query, top_k=15):
    """Retrieve the corpus entries whose plot embedding is closest to ``query``.

    The query is embedded with the module-level ``embedder`` and compared by
    cosine similarity against the ``plot_embedding`` of every row in ``df``.
    Rows are visited best-first and a title is only reported once.

    Args:
        query: Free-text search query.
        top_k: Maximum number of distinct titles to return.

    Returns:
        A list of dicts, best match first, with the keys ``name``,
        ``bi_encoder_score``, ``retrieval_rank`` (1-based), ``year``,
        ``language``, ``cast``, ``plot`` and ``link``. The list is empty when
        ``top_k`` is not positive.
    """
    if top_k <= 0:
        return []

    # Flatten rather than reshape to a hard-coded (768, 1) column: the score
    # then comes back as a scalar and the code no longer depends on the
    # embedding width of one particular model.
    query_vec = embedder.encode(query).reshape(-1)

    # Scores are kept in a local Series with a positional index, so the shared
    # corpus frame is never written to and duplicate index labels in ``df``
    # cannot confuse the lookups below.
    # Assumes every plot_embedding is a 1-D vector of the same length as the
    # query embedding -- TODO confirm against the pickle's schema.
    scores = pd.Series(
        [float(cosine_similarity(emb, query_vec)) for emb in df["plot_embedding"]],
        dtype="float64",
    )
    # Stable sort keeps corpus order among equal scores; NaN scores sort last.
    ranked_positions = scores.sort_values(ascending=False, kind="mergesort").index

    hits = []
    seen_titles = set()
    for pos in ranked_positions:
        row = df.iloc[pos]
        title = row["title"]
        # De-duplicate while scanning, *before* the top_k cut-off, so repeated
        # titles do not eat into the number of hits handed back.
        if title in seen_titles:
            continue
        seen_titles.add(title)
        hits.append(
            {
                "name": title,
                "bi_encoder_score": scores.iloc[pos],
                "retrieval_rank": len(hits) + 1,
                "year": row["year"],
                "language": row["language"],
                "cast": row["cast"],
                "plot": row["plot"],
                "link": row["link"],
            }
        )
        if len(hits) == top_k:
            break
    return hits
def search_cross_encoder(query, candidates):
    """Re-rank retrieval candidates by cross-encoder relevance to ``query``.

    Each candidate dict is updated in place with two extra keys:
    ``cross-score`` (the relevance score for the (query, plot) pair) and
    ``re-rank`` (its 1-based position after re-ranking).

    Args:
        query: Free-text search query.
        candidates: List of dicts, each holding at least a ``plot`` entry.

    Returns:
        A new list containing the same candidate dicts, highest
        ``cross-score`` first. Empty when ``candidates`` is empty.
    """
    # Nothing to score; skip the model call entirely.
    if not candidates:
        return []

    pairs = [[query, candidate['plot']] for candidate in candidates]
    scores = cross_encoder.predict(pairs)
    for candidate, score in zip(candidates, scores):
        # Store a plain float: NumPy scalars are not serialisable by the
        # standard json module, and these dicts are later shown as JSON.
        candidate['cross-score'] = float(score)

    reranked = sorted(candidates, key=lambda c: c['cross-score'], reverse=True)
    for position, candidate in enumerate(reranked, start=1):
        candidate['re-rank'] = position
    return reranked
def search(query, top_k=15):
    """Run the two-stage search: bi-encoder retrieval, then cross-encoder re-rank.

    Args:
        query: Free-text search query.
        top_k: Number of candidates requested from the retrieval stage.

    Returns:
        Whatever ``search_cross_encoder`` returns for the retrieved
        candidates.
    """
    return search_cross_encoder(query, search_bi_encoder(query, top_k))
def displayResults(results, container):
    """Render every search hit into ``container``.

    For each hit this writes a header with the movie name, a caption with
    language and year, an expanded "Plot:" expander, a collapsed
    "Movie result internals" expander (Wikipedia link, cast, raw JSON) and a
    divider.

    Args:
        results: Iterable of hit dicts with the keys ``name``, ``language``,
            ``year``, ``cast``, ``plot`` and ``link``.
        container: Object exposing ``header``, ``caption``, ``expander`` and
            ``divider`` (the ``st`` module or a Streamlit container).

    Returns:
        None.
    """
    for result in results:
        container.header(result['name'])
        container.caption(f"Language: {result['language']}, Released in:{result['year']}")

        # Write through the object returned by expander() instead of relying
        # on a ``with`` block: the ``with`` form only redirects calls made on
        # ``st`` itself, so content written via a container object would land
        # outside the expander.
        plot_panel = container.expander("Plot:", expanded=True)
        plot_panel.markdown(f'''{result['plot']}''')

        internals_panel = container.expander("Movie result internals")
        # Plain markdown link; the template contains no HTML, so raw-HTML
        # rendering is not enabled for this dataset-derived text.
        internals_panel.markdown(
            f"""Link: [{result['name']}](https://en.wikipedia.org{result['link']})"""
        )
        internals_panel.text(f"Cast:{result['cast']}")
        internals_panel.text("JSON Result:")
        internals_panel.json(result)

        container.divider()
# Page header and the two top-level tabs.
st.title("Indian Movie Search")
st.caption("Using semantic search to improve the search accuracy")
appTab, detailsTab = st.tabs(["App", "App Technical Details"])

# "App" tab: query input, hit-count slider and the rendered results.
with appTab:
    # The markdown body is kept flush-left on purpose: indenting it with the
    # code would change how the bullet list is rendered.
    st.markdown(
        f"""
- Search for movie names based on the plot.
- The corpus is made up of Hindi, Telugu, Tamil, Kannada, Bengali, Malayalam, Odiya, Marathi, Punjabi & Gujarathi movies released between 1950 and 2023.
- Corpus size:{len(df)}
- The app understands the context of the query and returns the results from the datastore.""")
    search_query = st.text_input("Please enter your search query here",value="",key="text_input")
    top_k = st.slider("Number of Top Hits Generated",min_value=1,max_value=100,value=15)

    ranked_hits = []
    # A whitespace-only query carries no meaning; treat it like an empty box
    # instead of running both models on it.
    cleaned_query = search_query.strip()
    if cleaned_query:
        with st.spinner(
            text="Searching for relevant movie plots for given query..."
        ):
            ranked_hits = search(cleaned_query, top_k)

        if ranked_hits:
            st.success("Matches found!!")
            st.divider()
            resultContainer = st.container()
            resultContainer.subheader("Results:")
            resultContainer.caption(f"Search Query: {search_query}")
            # Hits are rendered through ``st`` (the active tab), not through
            # resultContainer, so they appear after the container's content.
            displayResults(ranked_hits, st)
            resultContainer.markdown("\n-------------------------\n")
            st.divider()
        else:
            # Tell the user explicitly rather than showing an empty section.
            st.warning("No matches found for the given query.")
# "App Technical Details" tab: static description of the retrieval pipeline.
with detailsTab:
    st.header("App details")
    # The markdown bodies are kept flush-left on purpose: indenting them with
    # the code would change how the bullet lists are rendered.
    st.markdown(
        """
- The app supports Semantic search which seeks to improve search accuracy by understanding the content of the search query in contrast to traditional search engines which only find documents based on lexical matches.
- The corpus consists of movie plots from Hindi, Telugu, Tamil, Kannada, Bengali, Malayalam, Odiya, Marathi, Punjabi & Gujarathi languages.
- The core idea of the retrieval:
- Use Bi-Encoder (Retrieval) and Cross-encoder (Re-ranker) to retrieve the search results.
- The Bi-encoder is responsible for independently embedding the sentences and search queries into a vector space. The result is then passed to the cross-encoder for checking the relevance/similarity between the query and sentences.
- All plot entries in the corpus is embedded into a vector space. At search time, the query is embedded into the same vector space.
- Corpus embeddings and search query embedding are passed into bi-encoder and it would return the closest embeddings from the corpus.
- Cosine similarity is used to find the similar embeddings.
- The result is then passed to cross-encoder to re-rank the results based on the relevance to the search query.
"""
    )
    # Open the diagram in a context manager so the file handle is released as
    # soon as Streamlit has consumed the image, instead of leaking per rerun.
    with Image.open('semantic_search.png') as diagram:
        st.image(diagram, caption='Semantic search using Retrieval and Re-Rank')
    # Model links point at the canonical host; the previous URLs carried a
    # stray "." after "huggingface.co".
    st.markdown(
        """
Model Source:
- Bi-Encoder - [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
- Cross-Encoder - [cross-encoder/ms-marco-MiniLM-L-2-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-2-v2)""")

# Page-view counter badge.
st.markdown("![](https://komarev.com/ghpvc/?username=ashwinpatti_semantic_movie_search&label=PAGE+VIEWS)")
#![](https://komarev.com/ghpvc/?username=your-github-username&label=PROFILE+VIEWS)
|