Vivien commited on
Commit
5b1c1bd
1 Parent(s): 7a848b2

Improve the composition of queries

Browse files
Files changed (1) hide show
  1. app.py +27 -27
app.py CHANGED
@@ -23,7 +23,6 @@ def load():
23
  embeddings[k] = embeddings[k] / np.linalg.norm(
24
  embeddings[k], axis=1, keepdims=True
25
  )
26
- embeddings[k] = embeddings[k] - np.mean(embeddings[k], axis=0)
27
  return model, processor, df, embeddings
28
 
29
 
@@ -46,39 +45,40 @@ def image_search(query, corpus, n_results=24):
46
  else:
47
  return np.concatenate((e1, e2), axis=0)
48
 
49
- splitted_query = query.split(" EXCLUDING ")
50
-
51
- positive_queries = splitted_query[0].split(";")
52
- for positive_query in positive_queries:
53
- match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
54
- if match:
55
- corpus2, idx, remainder = match.groups()
56
- idx, remainder = int(idx), remainder.strip()
57
- k = 0 if corpus2 == "Unsplash" else 1
58
- positive_embeddings = concatenate_embeddings(
59
- positive_embeddings, embeddings[k][idx : idx + 1, :]
60
- )
61
- if len(remainder) > 0:
62
  positive_embeddings = concatenate_embeddings(
63
- positive_embeddings, compute_text_embeddings([remainder])
64
  )
65
- else:
66
- positive_embeddings = concatenate_embeddings(
67
- positive_embeddings, compute_text_embeddings([positive_query])
68
- )
69
- k = 0 if corpus == "Unsplash" else 1
70
- dot_product = embeddings[k] @ positive_embeddings.T
71
- dot_product = dot_product - np.mean(dot_product, axis=0)
72
- dot_product = dot_product / np.linalg.norm(dot_product, axis=0)
73
- dot_product = np.min(dot_product, axis=1)
 
 
 
74
 
75
  if len(splitted_query) > 1:
76
  negative_queries = (" ".join(splitted_query[1:])).split(";")
77
  negative_embeddings = compute_text_embeddings(negative_queries)
78
  dot_product2 = embeddings[k] @ negative_embeddings.T
79
- dot_product2 = dot_product2 - np.mean(dot_product2, axis=0)
80
- dot_product2 = dot_product2 / np.linalg.norm(dot_product2, axis=0)
81
- dot_product -= np.max(dot_product2, axis=1)
82
 
83
  results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
84
  return [
 
23
  embeddings[k] = embeddings[k] / np.linalg.norm(
24
  embeddings[k], axis=1, keepdims=True
25
  )
 
26
  return model, processor, df, embeddings
27
 
28
 
 
45
  else:
46
  return np.concatenate((e1, e2), axis=0)
47
 
48
+ splitted_query = query.split("EXCLUDING ")
49
+ dot_product = 0
50
+ k = 0 if corpus == "Unsplash" else 1
51
+ if len(splitted_query[0]) > 0:
52
+ positive_queries = splitted_query[0].split(";")
53
+ for positive_query in positive_queries:
54
+ match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
55
+ if match:
56
+ corpus2, idx, remainder = match.groups()
57
+ idx, remainder = int(idx), remainder.strip()
58
+ k2 = 0 if corpus2 == "Unsplash" else 1
 
 
59
  positive_embeddings = concatenate_embeddings(
60
+ positive_embeddings, embeddings[k2][idx : idx + 1, :]
61
  )
62
+ if len(remainder) > 0:
63
+ positive_embeddings = concatenate_embeddings(
64
+ positive_embeddings, compute_text_embeddings([remainder])
65
+ )
66
+ else:
67
+ positive_embeddings = concatenate_embeddings(
68
+ positive_embeddings, compute_text_embeddings([positive_query])
69
+ )
70
+ dot_product = embeddings[k] @ positive_embeddings.T
71
+ dot_product = dot_product - np.median(dot_product, axis=0)
72
+ dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
73
+ dot_product = np.min(dot_product, axis=1)
74
 
75
  if len(splitted_query) > 1:
76
  negative_queries = (" ".join(splitted_query[1:])).split(";")
77
  negative_embeddings = compute_text_embeddings(negative_queries)
78
  dot_product2 = embeddings[k] @ negative_embeddings.T
79
+ dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
80
+ dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
81
+ dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)
82
 
83
  results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
84
  return [