Eun0 committed on
Commit d9c0937 • 1 Parent(s): df430f0

Add prefix and Optimize load index

Files changed (1)
  1. app.py +67 -35
app.py CHANGED
@@ -4,6 +4,7 @@ import pandas as pd
 import faiss
 import clip
 import torch
+from huggingface_hub import hf_hub_download, snapshot_download
 
 title = r"""
 <h1 align="center" id="space-title"> 🔍 Search Similar Text/Image in the Dataset</h1>
@@ -11,9 +12,8 @@ title = r"""
 
 description = r"""
 
-In this demo, we use subset of [danbooru22](https://huggingface.co/datasets/animelover/danbooru2022) or [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
-<br>
-This demo currently supports text search only.
+Find text or images similar to your query text with this demo. Currently, it supports text search only.<br>
+In this demo, we use a subset of [danbooru22](https://huggingface.co/datasets/animelover/danbooru2022) or [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
 <br>
 The content will be updated to include image search once LAION is available.
 
@@ -21,7 +21,6 @@ The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval)
 
 """
 
-#In this demo, we use because LAION is currently not available.
 # From local file
 # INDEX_DIR = "dataset/diffusiondb/text_index_folder"
 # IND = faiss.read_index(f"{INDEX_DIR}/text.index")
@@ -29,13 +28,9 @@ The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval)
 #     pd.read_parquet(file) for file in glob.glob(f"{INDEX_DIR}/metadata/*.parquet")
 # )['caption'].tolist()
 
-# From huggingface dataset
-from huggingface_hub import hf_hub_download, snapshot_download
-
-DATASET_NAME = {
-    "danbooru22": "booru22_000-300",
-    "DiffusionDB": "diffusiondb",
-}
+def download_all_index(dataset_dict):
+    for k in dataset_dict:
+        load_faiss_index(k)
 
 def load_faiss_index(dataset):
     index_dir = "data/faiss_index"
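Note: the body of `load_faiss_index` falls outside the hunks shown in this diff; only its signature and the new module-level import of `hf_hub_download`/`snapshot_download` are visible. The sketch below (not part of the commit) is a rough guess at what such a loader typically does; the repo id `Eun0/faiss_index`, the `text.index`/`metadata/*.parquet` layout, and the `caption` column are illustrative assumptions.

# Hypothetical sketch only; repo id and file layout are assumptions, not from this commit.
import glob

import faiss
import pandas as pd
from huggingface_hub import snapshot_download

DATASET_NAME = {"danbooru22": "booru22_000-300", "DiffusionDB": "diffusiondb"}

def load_faiss_index(dataset):
    subdir = DATASET_NAME[dataset]
    # Fetch the prebuilt index folder from the Hub; it is cached locally after the
    # first call, which is what makes the startup call to download_all_index() cheap.
    local_dir = snapshot_download(
        repo_id="Eun0/faiss_index",        # assumed repo id
        repo_type="dataset",
        allow_patterns=[f"{subdir}/*"],
    )
    index = faiss.read_index(f"{local_dir}/{subdir}/text.index")   # assumed file name
    text_list = pd.concat(
        pd.read_parquet(f) for f in glob.glob(f"{local_dir}/{subdir}/metadata/*.parquet")
    )["caption"].tolist()
    return index, text_list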
@@ -63,9 +58,13 @@ def load_faiss_index(dataset):
 
     return index, text_list
 
-# Load CLIP model
-device = "cpu"
-CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)
+def change_index(dataset):
+    global INDEX, TEXT_LIST, PREV_DATASET
+    if PREV_DATASET != dataset:
+        gr.Info("Load index...")
+        INDEX, TEXT_LIST = load_faiss_index(dataset)
+        PREV_DATASET = dataset
+        gr.Info("Done!!")
 
 @torch.inference_mode
 def get_emb(text, device="cpu"):
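Note: `get_emb` likewise appears only as context here and is unchanged by this commit. For orientation, a CLIP text embedding fed to a FAISS index is usually produced along these lines; this is a guess at the shape of the function, and the L2 normalization assumes the index stores normalized features.

# Rough sketch, not the file's actual code; assumes the index holds normalized CLIP features.
import clip
import torch

CLIP_MODEL, _ = clip.load("ViT-B/32", device="cpu")

@torch.inference_mode()
def get_emb(text, device="cpu"):
    tokens = clip.tokenize([text], truncate=True).to(device)
    text_embeddings = CLIP_MODEL.encode_text(tokens)                  # shape (1, 512) for ViT-B/32
    text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
    return text_embeddings.cpu().numpy().astype("float32")            # faiss expects float32 numpy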
@@ -76,55 +75,88 @@ def get_emb(text, device="cpu"):
     return text_embeddings
 
 @torch.inference_mode
-def search_text(dataset, top_k, show_score, query_text, device):
-
-    ind, text_list = load_faiss_index(dataset)
-
+def search_text(top_k, show_score, numbering_prefix, output_file, query_text):
     if query_text is None or query_text == "":
         raise gr.Error("Query text is missing")
 
     text_embeddings = get_emb(query_text, device)
-    scores, retrieved_texts = ind.search(text_embeddings, top_k)
+    scores, retrieved_texts = INDEX.search(text_embeddings, top_k)
     scores, retrieved_texts = scores[0], retrieved_texts[0]
 
-    result_str = ""
+    result_list = []
     for score, ind in zip(scores, retrieved_texts):
-        item_str = text_list[ind].strip()
+        item_str = TEXT_LIST[ind].strip()
         if item_str == "":
             continue
-        result_str += f"{item_str}"
+        if (item_str, score) not in result_list:
+            result_list.append((item_str, score))
+
+    # Postprocessing text
+    result_str = ""
+    for count, (item_str, score) in enumerate(result_list):
+        if numbering_prefix:
+            item_str = f"###################### {count+1} ######################\n {item_str}"
         if show_score:
-            result_str += f", {score:0.2f}"
-        result_str += "\n"
-
+            item_str += f", {score:0.2f}"
+        result_str += f"{item_str}\n"
+
     # file_name = query_text.replace(" ", "_")
     # if show_score:
     #     file_name += "_score"
-    file_name = "output"
-    output_path = f"./{file_name}.txt"
-    with open(output_path, "w") as f:
-        f.writelines(result_str)
+    output_path = None
+    if output_file:
+        file_name = "output"
+        output_path = f"./{file_name}.txt"
+        with open(output_path, "w") as f:
+            f.writelines(result_str)
 
     return result_str, output_path
 
 
+# Load CLIP model
+device = "cpu"
+CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)
+
+# Dataset
+DATASET_NAME = {
+    "danbooru22": "booru22_000-300",
+    "DiffusionDB": "diffusiondb",
+}
+
+DEFAULT_DATASET = "danbooru22"
+PREV_DATASET = "danbooru22"
+
+# Download needed index
+download_all_index(DATASET_NAME)
+
+# Load default index
+INDEX, TEXT_LIST = load_faiss_index(DEFAULT_DATASET)
+
+
 with gr.Blocks() as demo:
     gr.Markdown(title)
     gr.Markdown(description)
 
     with gr.Row():
-        dataset = gr.Dropdown(label="dataset", choices=["danbooru22", "DiffusionDB"], value="DiffusionDB")
+        dataset = gr.Dropdown(label="dataset", choices=["danbooru22", "DiffusionDB"], value=DEFAULT_DATASET)
         top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8)
-        show_score = gr.Checkbox(label="Show score", value=True)
+        with gr.Column():
+            show_score = gr.Checkbox(label="Show score", value=False)
+            numbering_prefix = gr.Checkbox(label="Add numbering prefix", value=True)
+            output_file = gr.Checkbox(label="Return text file", value=True)
     query_text = gr.Textbox(label="query text")
     btn = gr.Button()
-    with gr.Row():
-        result_text = gr.Textbox(label="retrieved text", interactive=False)
-        result_file = gr.File(label="output file")
+    result_text = gr.Textbox(label="retrieved text", interactive=False)
+    result_file = gr.File(label="output file", visible=True)
+
+    #dataset.change(change_index, dataset, None)
 
     btn.click(
+        fn=change_index,
+        inputs=[dataset],
+    ).success(
         fn=search_text,
-        inputs=[dataset, top_k, show_score, query_text],
+        inputs=[top_k, show_score, numbering_prefix, output_file, query_text],
         outputs=[result_text, result_file],
     )
 
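Note: the rewired click handler is what the commit message calls "Optimize load index". The index is no longer reloaded inside `search_text` on every query, and the commented-out `dataset.change(...)` line shows the reload was moved off the dropdown event as well. Instead, `btn.click(fn=change_index, ...)` returns a Gradio event dependency, and the chained `.success(fn=search_text, ...)` runs only if `change_index` finished without raising, so the FAISS index is swapped at most once per dataset change. A minimal standalone sketch of the same chaining pattern (all names here are illustrative, not from app.py):

import gradio as gr

def prepare(choice):
    # Load or switch some cached resource; raising gr.Error here would stop the chain.
    gr.Info(f"Preparing {choice}...")

def run(query):
    return f"result for {query!r}"

with gr.Blocks() as demo:
    choice = gr.Dropdown(choices=["a", "b"], value="a", label="dataset")
    query = gr.Textbox(label="query")
    out = gr.Textbox(label="output")
    btn = gr.Button("Search")

    # run() is invoked only after prepare() completes successfully
    btn.click(fn=prepare, inputs=[choice]).success(fn=run, inputs=[query], outputs=[out])

demo.launch()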