Spaces:
Runtime error
Runtime error
Add prefix and Optimize load index
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import pandas as pd
|
|
4 |
import faiss
|
5 |
import clip
|
6 |
import torch
|
|
|
7 |
|
8 |
title = r"""
|
9 |
<h1 align="center" id="space-title"> π Search Similar Text/Image in the Dataset</h1>
|
@@ -11,9 +12,8 @@ title = r"""
|
|
11 |
|
12 |
description = r"""
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
This demo currently supports text search only.
|
17 |
<br>
|
18 |
The content will be updated to include image search once LAION is available.
|
19 |
|
@@ -21,7 +21,6 @@ The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval)
|
|
21 |
|
22 |
"""
|
23 |
|
24 |
-
#In this demo, we use because LAION is currently not available.
|
25 |
# From local file
|
26 |
# INDEX_DIR = "dataset/diffusiondb/text_index_folder"
|
27 |
# IND = faiss.read_index(f"{INDEX_DIR}/text.index")
|
@@ -29,13 +28,9 @@ The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval)
|
|
29 |
# pd.read_parquet(file) for file in glob.glob(f"{INDEX_DIR}/metadata/*.parquet")
|
30 |
# )['caption'].tolist()
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
DATASET_NAME = {
|
36 |
-
"danbooru22": "booru22_000-300",
|
37 |
-
"DiffusionDB": "diffusiondb",
|
38 |
-
}
|
39 |
|
40 |
def load_faiss_index(dataset):
|
41 |
index_dir = "data/faiss_index"
|
@@ -63,9 +58,13 @@ def load_faiss_index(dataset):
|
|
63 |
|
64 |
return index, text_list
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
69 |
|
70 |
@torch.inference_mode
|
71 |
def get_emb(text, device="cpu"):
|
@@ -76,55 +75,88 @@ def get_emb(text, device="cpu"):
|
|
76 |
return text_embeddings
|
77 |
|
78 |
@torch.inference_mode
|
79 |
-
def search_text(
|
80 |
-
|
81 |
-
ind, text_list = load_faiss_index(dataset)
|
82 |
-
|
83 |
if query_text is None or query_text == "":
|
84 |
raise gr.Error("Query text is missing")
|
85 |
|
86 |
text_embeddings = get_emb(query_text, device)
|
87 |
-
scores, retrieved_texts =
|
88 |
scores, retrieved_texts = scores[0], retrieved_texts[0]
|
89 |
|
90 |
-
|
91 |
for score, ind in zip(scores, retrieved_texts):
|
92 |
-
item_str =
|
93 |
if item_str == "":
|
94 |
continue
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
if show_score:
|
97 |
-
|
98 |
-
result_str += "\n"
|
99 |
-
|
100 |
# file_name = query_text.replace(" ", "_")
|
101 |
# if show_score:
|
102 |
# file_name += "_score"
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
f.
|
|
|
|
|
107 |
|
108 |
return result_str, output_path
|
109 |
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
with gr.Blocks() as demo:
|
112 |
gr.Markdown(title)
|
113 |
gr.Markdown(description)
|
114 |
|
115 |
with gr.Row():
|
116 |
-
dataset = gr.Dropdown(label="dataset", choices=["danbooru22", "DiffusionDB"], value=
|
117 |
top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8)
|
118 |
-
|
|
|
|
|
|
|
119 |
query_text = gr.Textbox(label="query text")
|
120 |
btn = gr.Button()
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
124 |
|
125 |
btn.click(
|
|
|
|
|
|
|
126 |
fn=search_text,
|
127 |
-
inputs=[
|
128 |
outputs=[result_text, result_file],
|
129 |
)
|
130 |
|
|
|
4 |
import faiss
|
5 |
import clip
|
6 |
import torch
|
7 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
8 |
|
9 |
title = r"""
|
10 |
<h1 align="center" id="space-title"> π Search Similar Text/Image in the Dataset</h1>
|
|
|
12 |
|
13 |
description = r"""
|
14 |
|
15 |
+
Find text or images similar to your query text with this demo. Currently, it supports text search only.<br>
|
16 |
+
In this demo, we use a subset of [danbooru22](https://huggingface.co/datasets/animelover/danbooru2022) or [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
|
|
|
17 |
<br>
|
18 |
The content will be updated to include image search once LAION is available.
|
19 |
|
|
|
21 |
|
22 |
"""
|
23 |
|
|
|
24 |
# From local file
|
25 |
# INDEX_DIR = "dataset/diffusiondb/text_index_folder"
|
26 |
# IND = faiss.read_index(f"{INDEX_DIR}/text.index")
|
|
|
28 |
# pd.read_parquet(file) for file in glob.glob(f"{INDEX_DIR}/metadata/*.parquet")
|
29 |
# )['caption'].tolist()
|
30 |
|
31 |
+
def download_all_index(dataset_dict):
|
32 |
+
for k in dataset_dict:
|
33 |
+
load_faiss_index(k)
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def load_faiss_index(dataset):
|
36 |
index_dir = "data/faiss_index"
|
|
|
58 |
|
59 |
return index, text_list
|
60 |
|
61 |
+
def change_index(dataset):
|
62 |
+
global INDEX, TEXT_LIST, PREV_DATASET
|
63 |
+
if PREV_DATASET != dataset:
|
64 |
+
gr.Info("Load index...")
|
65 |
+
INDEX, TEXT_LIST = load_faiss_index(dataset)
|
66 |
+
PREV_DATASET = dataset
|
67 |
+
gr.Info("Done!!")
|
68 |
|
69 |
@torch.inference_mode
|
70 |
def get_emb(text, device="cpu"):
|
|
|
75 |
return text_embeddings
|
76 |
|
77 |
@torch.inference_mode
|
78 |
+
def search_text(top_k, show_score, numbering_prefix, output_file, query_text):
|
|
|
|
|
|
|
79 |
if query_text is None or query_text == "":
|
80 |
raise gr.Error("Query text is missing")
|
81 |
|
82 |
text_embeddings = get_emb(query_text, device)
|
83 |
+
scores, retrieved_texts = INDEX.search(text_embeddings, top_k)
|
84 |
scores, retrieved_texts = scores[0], retrieved_texts[0]
|
85 |
|
86 |
+
result_list = []
|
87 |
for score, ind in zip(scores, retrieved_texts):
|
88 |
+
item_str = TEXT_LIST[ind].strip()
|
89 |
if item_str == "":
|
90 |
continue
|
91 |
+
if (item_str, score) not in result_list:
|
92 |
+
result_list.append((item_str, score))
|
93 |
+
|
94 |
+
# Postprocessing text
|
95 |
+
result_str = ""
|
96 |
+
for count, (item_str, score) in enumerate(result_list):
|
97 |
+
if numbering_prefix:
|
98 |
+
item_str = f"###################### {count+1} ######################\n {item_str}"
|
99 |
if show_score:
|
100 |
+
item_str += f", {score:0.2f}"
|
101 |
+
result_str += f"{item_str}\n"
|
102 |
+
|
103 |
# file_name = query_text.replace(" ", "_")
|
104 |
# if show_score:
|
105 |
# file_name += "_score"
|
106 |
+
output_path = None
|
107 |
+
if output_file:
|
108 |
+
file_name = "output"
|
109 |
+
output_path = f"./{file_name}.txt"
|
110 |
+
with open(output_path, "w") as f:
|
111 |
+
f.writelines(result_str)
|
112 |
|
113 |
return result_str, output_path
|
114 |
|
115 |
|
116 |
+
# Load CLIP model
|
117 |
+
device = "cpu"
|
118 |
+
CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)
|
119 |
+
|
120 |
+
# Dataset
|
121 |
+
DATASET_NAME = {
|
122 |
+
"danbooru22": "booru22_000-300",
|
123 |
+
"DiffusionDB": "diffusiondb",
|
124 |
+
}
|
125 |
+
|
126 |
+
DEFAULT_DATASET = "danbooru22"
|
127 |
+
PREV_DATASET = "danbooru22"
|
128 |
+
|
129 |
+
# Download needed index
|
130 |
+
download_all_index(DATASET_NAME)
|
131 |
+
|
132 |
+
# Load default index
|
133 |
+
INDEX, TEXT_LIST = load_faiss_index(DEFAULT_DATASET)
|
134 |
+
|
135 |
+
|
136 |
with gr.Blocks() as demo:
|
137 |
gr.Markdown(title)
|
138 |
gr.Markdown(description)
|
139 |
|
140 |
with gr.Row():
|
141 |
+
dataset = gr.Dropdown(label="dataset", choices=["danbooru22", "DiffusionDB"], value=DEFAULT_DATASET)
|
142 |
top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8)
|
143 |
+
with gr.Column():
|
144 |
+
show_score = gr.Checkbox(label="Show score", value=False)
|
145 |
+
numbering_prefix = gr.Checkbox(label="Add numbering prefix", value=True)
|
146 |
+
output_file = gr.Checkbox(label="Return text file", value=True)
|
147 |
query_text = gr.Textbox(label="query text")
|
148 |
btn = gr.Button()
|
149 |
+
result_text = gr.Textbox(label="retrieved text", interactive=False)
|
150 |
+
result_file = gr.File(label="output file", visible=True)
|
151 |
+
|
152 |
+
#dataset.change(change_index, dataset, None)
|
153 |
|
154 |
btn.click(
|
155 |
+
fn=change_index,
|
156 |
+
inputs=[dataset],
|
157 |
+
).success(
|
158 |
fn=search_text,
|
159 |
+
inputs=[top_k, show_score, numbering_prefix, output_file, query_text],
|
160 |
outputs=[result_text, result_file],
|
161 |
)
|
162 |
|