Eun0 committed on
Commit
1a03e08
1 Parent(s): 6a497ce

Add app file

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import faiss
5
+ import clip
6
+ import torch
7
+ from datasets import load_dataset
8
+
9
# Page heading, rendered through gr.Markdown as raw HTML.
title = r"""
<h1 align="center" id="space-title"> 🔍 Search Similar Text/Image in the Dataset</h1>
"""

# Introductory blurb shown under the heading (Markdown with inline <br> tags).
description = r"""

In this demo, we use [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
<br>
This demo currently supports text search only.
<br>
The content will be updated to include image search once LAION is available.

The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval) and [autofaiss](https://github.com/criteo/autofaiss)

"""
24
+
25
# Start-up resources: the FAISS text index, the caption list and the CLIP model.
#
# Everything is fetched from the Hugging Face Hub into the working directory.
# To use a local copy instead (e.g. "dataset/diffusiondb/text_index_folder"),
# point `root_path` at the folder holding `text.index` and `metadata/*.parquet`
# and skip the two downloads below.
from huggingface_hub import hf_hub_download, snapshot_download

# Download index file
hf_hub_download(
    repo_type="dataset",
    repo_id="Eun02/diffusiondb_faiss_text_index",
    filename="text.index",
    local_dir="./",
)

# Download text file(s): every parquet file in the same dataset repo
snapshot_download(
    repo_type="dataset",
    repo_id="Eun02/diffusiondb_faiss_text_index",
    allow_patterns="*.parquet",
    local_dir="./",
)

# Load index and text data
root_path = "."
IND = faiss.read_index(f"{root_path}/text.index")

# Files are read in sorted order so the caption list is deterministic.
# NOTE(review): presumably caption i must line up with vector i of the index;
# confirm against how the index was built.
parquet_files = sorted(glob.glob(f"{root_path}/metadata/*.parquet"))
caption_frames = [pd.read_parquet(parquet_file) for parquet_file in parquet_files]
TEXT_LIST = pd.concat(caption_frames)['caption'].tolist()

# Load CLIP model (the second return value of clip.load is not used here)
device = "cpu"
CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)
62
+
63
# Called form of the decorator: the bare `@torch.inference_mode` spelling is
# only accepted by newer PyTorch releases, while `inference_mode()` works on
# every release that provides it.
@torch.inference_mode()
def get_emb(text, device="cpu"):
    """Encode a single text with CLIP and return its unit-length embedding.

    Args:
        text: Query string. It is tokenized with ``truncate=True``.
        device: Torch device the token tensor is moved to before encoding.
            NOTE(review): this should match the device ``CLIP_MODEL`` was
            loaded on; the function does not move the model itself.

    Returns:
        A ``float32`` NumPy array holding the embedding of ``text`` (one row,
        since exactly one text is encoded), divided by its L2 norm along the
        last dimension.
    """
    text_tokens = clip.tokenize([text], truncate=True).to(device)
    text_features = CLIP_MODEL.encode_text(text_tokens)
    # Normalise to unit length.
    # NOTE(review): presumably the index was built from normalised vectors; confirm.
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy().astype('float32')
70
+
71
def _safe_file_stem(query_text, max_length=60):
    """Turn free-form query text into a string that is safe as a file name.

    Every character that is not alphanumeric, ``-`` or ``_`` (spaces, path
    separators, dots, ...) becomes ``_``, and the result is capped at
    ``max_length`` characters so the final name stays within common
    file-name length limits. Falls back to ``"query"`` if nothing is left.
    """
    stem = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in query_text.strip()
    )
    return stem[:max_length] or "query"


@torch.inference_mode()
def search_text(dataset, top_k, show_score, query_text, device="cpu"):
    """Look up the captions most similar to ``query_text`` and save them to a file.

    Args:
        dataset: Value of the dataset dropdown. It is not used by this
            function; only the single module-level index ``IND`` is searched.
        top_k: Number of neighbours to request from the index. Converted to
            ``int`` because a UI slider may deliver a float.
        show_score: When true, each result line is followed by ``", <score>"``
            formatted with two decimals, and the file name gets a ``_score``
            suffix.
        query_text: Text to search for.
        device: Device forwarded to ``get_emb``. Defaults to ``"cpu"`` so the
            function can be called with only the four UI inputs.

    Returns:
        ``(result_str, output_path)``: the newline-separated results and the
        path of the text file they were written to.

    Raises:
        gr.Error: If ``query_text`` is ``None``, empty or only whitespace.
    """
    if query_text is None or not query_text.strip():
        raise gr.Error("Query text is missing")

    text_embeddings = get_emb(query_text, device)
    # FAISS requires an integer k.
    scores, retrieved_texts = IND.search(text_embeddings, int(top_k))

    result_lines = []
    # Only one query was embedded, so the first row holds all results.
    for score, ind in zip(scores[0], retrieved_texts[0]):
        if ind < 0:
            # FAISS pads with label -1 when fewer than k neighbours exist;
            # indexing TEXT_LIST with it would silently return the last caption.
            continue
        item_str = TEXT_LIST[ind].strip()
        if not item_str:
            continue
        if show_score:
            item_str += f", {score:0.2f}"
        result_lines.append(item_str)
    result_str = "".join(f"{line}\n" for line in result_lines)

    # The query is user input: sanitise it before using it in a path.
    file_name = _safe_file_stem(query_text)
    if show_score:
        file_name += "_score"
    output_path = f"./{file_name}.txt"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result_str)

    return result_str, output_path
99
+
100
+
101
# Gradio UI: search controls on top, results underneath.
# NOTE(review): the source this was recovered from had its indentation
# flattened, so which components sit inside each gr.Row() is reconstructed;
# confirm the layout against the running app.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        dataset = gr.Dropdown(label="dataset", choices=["DiffusionDB"], value="DiffusionDB")
        # step=1 keeps the slider on whole numbers; without an explicit step
        # Gradio derives a fractional one from the range, and a fractional
        # neighbour count is not a valid `k` for the index search.
        top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8, step=1)
        show_score = gr.Checkbox(label="Show score", value=True)
    query_text = gr.Textbox(label="query text")
    btn = gr.Button()
    with gr.Row():
        result_text = gr.Textbox(label="retrieved text", interactive=False)
        result_file = gr.File(label="output file")

    # The four input values are passed to search_text positionally, in this order.
    btn.click(
        fn=search_text,
        inputs=[dataset, top_k, show_score, query_text],
        outputs=[result_text, result_file],
    )

demo.launch()