SRDdev committed on
Commit
2288165
1 Parent(s): 96bae18

Upload 6 files

Browse files
Files changed (6) hide show
  1. app.py +167 -0
  2. data.csv +0 -0
  3. data2.csv +0 -0
  4. embeddings.npy +3 -0
  5. embeddings2.npy +3 -0
  6. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from html import escape
2
+ import re
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import numpy as np
6
+ from transformers import CLIPProcessor, CLIPModel
7
+ from st_clickable_images import clickable_images
8
+
9
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the CLIP model/processor, per-corpus metadata tables and
    unit-normalized image embeddings.

    Cached by Streamlit so the heavy model download and file reads happen
    once per process; hash_funcs skip hashing of the unhashable objects.

    Returns:
        (model, processor, df, embeddings) where df and embeddings are dicts
        keyed by corpus index (0 = Unsplash, 1 = Movies).
    """
    checkpoint = "openai/clip-vit-large-patch14"
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    raw = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
    # L2-normalize every row so downstream dot products are cosine similarities.
    embeddings = {
        key: vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
        for key, vectors in raw.items()
    }
    return model, processor, df, embeddings
27
+
28
# Load heavy resources once at import time (memoized by st.cache on load()).
model, processor, df, embeddings = load()
# Attribution suffix appended to each image tooltip, keyed by corpus index.
source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
30
+
31
def compute_text_embeddings(list_of_strings):
    """Embed strings with CLIP's text tower and L2-normalize each row.

    Returns a (len(list_of_strings), dim) numpy array of unit vectors, so
    dot products with the stored image embeddings are cosine similarities.
    """
    tokenized = processor(text=list_of_strings, return_tensors="pt", padding=True)
    features = model.get_text_features(**tokenized).detach().numpy()
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / norms
35
+
36
def image_search(query, corpus, n_results=24):
    """Rank images of the selected corpus against a structured text query.

    Query syntax (mirrors the on-page help):
      * ";" separates multiple positive sub-queries; their normalized scores
        are combined with a per-image minimum (an AND-like combination).
      * "[Movies:123]" / "[Unsplash:123]" references a stored image embedding
        by row index, optionally followed by extra free text.
      * "EXCLUDING " splits off negative sub-queries whose (positive part of
        the) normalized similarity is subtracted from the score.

    Returns:
        List of (image_path, tooltip_with_source, row_index) tuples for the
        top `n_results` rows, best first.
    """
    positive_embeddings = None

    # Stack (n, dim) embedding rows, tolerating the initial None accumulator.
    def concatenate_embeddings(e1, e2):
        if e1 is None:
            return e2
        else:
            return np.concatenate((e1, e2), axis=0)

    splitted_query = query.split("EXCLUDING ")
    dot_product = 0
    k = 0 if corpus == "Unsplash" else 1
    if len(splitted_query[0]) > 0:
        positive_queries = splitted_query[0].split(";")
        for positive_query in positive_queries:
            # Match "[Corpus:index]" plus optional trailing free text.
            match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
            if match:
                corpus2, idx, remainder = match.groups()
                idx, remainder = int(idx), remainder.strip()
                k2 = 0 if corpus2 == "Unsplash" else 1
                # Use the referenced image's stored embedding directly
                # (slice keeps a 2-D (1, dim) shape for concatenation).
                positive_embeddings = concatenate_embeddings(
                    positive_embeddings, embeddings[k2][idx : idx + 1, :]
                )
                if len(remainder) > 0:
                    positive_embeddings = concatenate_embeddings(
                        positive_embeddings, compute_text_embeddings([remainder])
                    )
            else:
                positive_embeddings = concatenate_embeddings(
                    positive_embeddings, compute_text_embeddings([positive_query])
                )
        # Cosine similarities (rows are unit-norm on both sides), then a
        # per-sub-query normalization — center on the median, scale by the
        # max — so heterogeneous sub-queries are comparable before the
        # min-combination below.
        dot_product = embeddings[k] @ positive_embeddings.T
        dot_product = dot_product - np.median(dot_product, axis=0)
        dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
        dot_product = np.min(dot_product, axis=1)

    if len(splitted_query) > 1:
        # Everything after the first "EXCLUDING " is treated as ";"-separated
        # negative sub-queries.
        negative_queries = (" ".join(splitted_query[1:])).split(";")
        negative_embeddings = compute_text_embeddings(negative_queries)
        dot_product2 = embeddings[k] @ negative_embeddings.T
        dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
        dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
        # Penalize only above-median matches on any negative query.
        dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)

    # Indices of the n_results highest scores, best first.
    results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
    return [
        (
            df[k].iloc[i]["path"],
            df[k].iloc[i]["tooltip"] + source[k],
            i,
        )
        for i in results
    ]
89
+
90
# Markdown shown above the search box.
description = """
# Semantic image search
**Enter your query and hit enter**
"""

# Markdown usage hints; fixed a stray ")" and a misplaced comma in the
# original user-facing text.
howto = """
- Click image to find similar images
- Use "**;**" to combine multiple queries
- Use "**EXCLUDING**" to exclude a query
"""
100
+
101
def main():
    """Render the Streamlit page: styling, search box, corpus radio and a
    clickable grid of results.

    Clicking a result stores a "[corpus:index]" query in session state and
    reruns the app, so the next search becomes a similarity lookup on the
    clicked image's stored embedding.
    """
    # Inject CSS: widen the content area, lay radio options out horizontally
    # and centered, trim default paddings, and hide Streamlit's built-in
    # menu and footer chrome.
    st.markdown(
        """
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>""",
        unsafe_allow_html=True,
    )
    st.markdown(description)
    st.markdown(howto)

    # Pre-fill the search box with the query generated by a previous image
    # click, if any; otherwise use a demo query.
    if "query" in st.session_state:
        query = st.text_input("", value=st.session_state["query"])
    else:
        query = st.text_input("", value="lighthouse")
    corpus = st.radio("", ["Unsplash"])
    if len(query) > 0:
        results = image_search(query, corpus)
        # Flex-wrapped grid of thumbnails; `clicked` is the index of the
        # clicked image within `results`, or -1 when nothing was clicked.
        clicked = clickable_images(
            [result[0] for result in results],
            titles=[result[1] for result in results],
            div_style={
                "display": "flex",
                "justify-content": "center",
                "flex-wrap": "wrap",
            },
            img_style={"margin": "2px", "height": "200px"},
        )
        if clicked >= 0:
            # Intended to rerun only when a *different* image is clicked.
            # NOTE(review): "last_clicked" is never written anywhere in this
            # file, so change_query appears to always end up True — confirm
            # whether st.session_state["last_clicked"] should be set here.
            change_query = False
            if "last_clicked" not in st.session_state:
                change_query = True
            else:
                if clicked != st.session_state["last_clicked"]:
                    change_query = True
            if change_query:
                # "[corpus:row]" is parsed by image_search as a stored
                # embedding lookup rather than a plain text query.
                st.session_state["query"] = f"[{corpus}:{results[clicked][2]}]"
                st.experimental_rerun()
165
+
166
# Script entry point: render the app when executed directly.
if __name__ == "__main__":
    main()
data.csv ADDED
The diff for this file is too large to render. See raw diff
 
data2.csv ADDED
The diff for this file is too large to render. See raw diff
 
embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64515f7d3d71137e2944f2c3d72c8df3e684b5d6a6ff7dcebb92370f7326ccfd
3
+ size 76800128
embeddings2.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d730b33e758c2648419a96ac86d39516c59795e613c35700d3a64079e5a9a27
3
+ size 25098368
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ numpy
4
+ pandas
5
+ st-clickable-images
6
+ altair<5