import streamlit as st
import pandas as pd, numpy as np
import os
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
@st.cache(
    show_spinner=False,
    # The model objects and the returned dicts are given a constant hash
    # (None) so Streamlit does not try to hash their contents.
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPTextModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the CLIP components, the metadata tables and the embeddings.

    Returns:
        A 5-tuple ``(model, text_model, processor, df, embeddings)`` where
        ``df`` and ``embeddings`` are dicts keyed by corpus index (0 and 1).
        ``df`` holds the DataFrames read from ``data.csv`` / ``data2.csv``;
        ``embeddings`` holds the arrays read from ``embeddings.npy`` /
        ``embeddings2.npy``, each divided by its Euclidean norm taken
        along axis 1.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip = CLIPModel.from_pretrained(checkpoint)
    text_encoder = CLIPTextModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)

    tables = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}

    unit_vectors = {}
    for key, path in ((0, 'embeddings.npy'), (1, 'embeddings2.npy')):
        raw = np.load(path)
        # keepdims=True keeps the norms broadcastable against the rows.
        row_norms = np.sqrt(np.sum(raw**2, axis=1, keepdims=True))
        unit_vectors[key] = raw / row_norms

    return clip, text_encoder, clip_processor, tables, unit_vectors
# Module-level handles used by the search helpers below; load() is wrapped
# in st.cache.
model, text_model, processor, df, embeddings = load()

# Attribution text appended to each result's tooltip, keyed by corpus index
# (the same 0/1 keys used by `df` and `embeddings`).
source = {
    0: '\nSource: Unsplash',
    1: '\nSource: The Movie Database (TMDB)',
}
def get_html(url_list, height=200):
    """Build the HTML for a gallery of result images.

    Args:
        url_list: Iterable of ``(url, title, link)`` triples. ``url`` is the
            image source, ``title`` is used as the tooltip, and ``link`` is
            the page the image should link to (an empty or non-string value
            means the image is not wrapped in a link).
        height: Display height of every image, in pixels.

    Returns:
        A single HTML string: one container ``<div>`` holding an ``<img>``
        per entry, each optionally wrapped in an ``<a>``.
    """
    # Imported locally under a distinct name so nothing else in the module
    # needs to change.
    from html import escape

    # NOTE(review): the exact container/image styling is a best-effort
    # choice (wrapping flex row, small margin) — adjust to taste.
    pieces = [
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for url, title, link in url_list:
        # Values are escaped because the result is rendered as raw HTML.
        tag = (
            f"<img title='{escape(str(title))}' "
            f"style='height: {height}px; margin: 5px' "
            f"src='{escape(str(url))}'>"
        )
        # A missing link may arrive as '' or as a non-string (e.g. NaN).
        if isinstance(link, str) and link:
            tag = f"<a href='{escape(link)}' target='_blank'>" + tag + "</a>"
        pieces.append(tag)
    pieces.append("</div>")
    return "".join(pieces)
def compute_text_embeddings(list_of_strings):
    """Embed a list of strings with the CLIP text encoder.

    The strings are tokenised by the module-level ``processor`` (padded,
    returned as PyTorch tensors), passed through ``text_model``, and the
    encoder's ``pooler_output`` is mapped through ``model.text_projection``.

    Args:
        list_of_strings: The texts to embed.

    Returns:
        The projected text features as returned by ``model.text_projection``
        (presumably one row per input string — confirm against the model).
    """
    encoded = processor(text=list_of_strings, return_tensors="pt", padding=True)
    pooled = text_model(**encoded).pooler_output
    return model.text_projection(pooled)
# The original line here called st.cache(...) without '@', so the function
# was never actually cached. The hash_funcs mirror those on load(): the CLIP
# objects and data dicts this function references get a constant hash so
# Streamlit does not try to hash their contents.
@st.cache(show_spinner=False,
          hash_funcs={CLIPModel: lambda _: None,
                      CLIPTextModel: lambda _: None,
                      CLIPProcessor: lambda _: None,
                      dict: lambda _: None})
def image_search(query, corpus, n_results=24):
    """Return the best-matching images for a text query.

    Args:
        query: Free-text search string.
        corpus: ``'Unsplash'`` selects corpus 0; any other value selects
            corpus 1.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, link)`` tuples ordered from highest to
        lowest score, where ``tooltip`` has the corpus attribution from
        ``source`` appended.
    """
    query_vec = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == 'Unsplash' else 1

    # Dot product of every stored embedding with the single query vector.
    scores = (embeddings[k] @ query_vec.T)[:, 0]
    # argsort is ascending, so reverse it before taking the top n_results.
    best = np.argsort(scores)[::-1][:n_results]

    table = df[k]
    suffix = source[k]
    hits = []
    for i in best:
        row = table.iloc[i]  # fetch the row once instead of once per column
        hits.append((row['path'], row['tooltip'] + suffix, row['link']))
    return hits
# Markdown rendered in the sidebar by main(): page title and credits.
# NOTE(review): the transformers link below points at huggingface.co rather
# than huggingface.co — confirm that host is intended.
description = '''
# Semantic image search
Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co./transformers/), [Streamlit](https://streamlit.io/) and images from [Unsplash](https://unsplash.com/) and [The Movie Database (TMDB)](https://www.themoviedb.org/)
'''
def main():
    """Render the page: sidebar text, query box, corpus picker and results.

    Nothing is searched or rendered below the inputs until the query box
    contains text.
    """
    # NOTE(review): the markup injected here is only a newline — presumably
    # a <style> block was meant to go here; confirm against the deployed app.
    st.markdown('\n', unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # The first and last columns are discarded, leaving them empty.
    _, query_col, corpus_col, _ = st.columns([2, 10, 2, 2])
    query = query_col.text_input('')
    corpus = corpus_col.radio('', ["Unsplash", "Movies"])

    if len(query) == 0:
        return

    gallery = get_html(image_search(query, corpus))
    st.markdown(gallery, unsafe_allow_html=True)
# Build the page only when this file is run as the main module.
if __name__ == "__main__":
    main()