CLIP-Retrieval / app.py
Eun0's picture
Add prefix and Optimize load index
d9c0937
raw
history blame
5.34 kB
import glob
import gradio as gr
import pandas as pd
import faiss
import clip
import torch
from huggingface_hub import hf_hub_download, snapshot_download
# Markdown/HTML rendered at the top of the Gradio page (see gr.Markdown below).
title = r"""
<h1 align="center" id="space-title"> 🔍 Search Similar Text/Image in the Dataset</h1>
"""
description = r"""
Find text or images similar to your query text with this demo. Currently, it supports text search only.<br>
In this demo, we use a subset of [danbooru22](https://huggingface.co/datasets/animelover/danbooru2022) or [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
<br>
The content will be updated to include image search once LAION is available.
The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval) and [autofaiss](https://github.com/criteo/autofaiss).
"""
# From local file
# INDEX_DIR = "dataset/diffusiondb/text_index_folder"
# IND = faiss.read_index(f"{INDEX_DIR}/text.index")
# TEXT_LIST = pd.concat(
# pd.read_parquet(file) for file in glob.glob(f"{INDEX_DIR}/metadata/*.parquet")
# )['caption'].tolist()
def download_all_index(
    dataset_dict,
    repo_id="Eun02/text_image_faiss_index",
    index_dir="data/faiss_index",
):
    """Pre-fetch the FAISS index file and caption metadata for every dataset.

    Only downloads files into ``index_dir``. Unlike calling
    ``load_faiss_index`` in a loop, nothing is deserialized here, so startup
    does not read every index and caption table into memory only to discard it.

    Args:
        dataset_dict: Mapping whose keys are UI dataset names. Each key is
            resolved to its Hub subfolder through the global ``DATASET_NAME``.
        repo_id: Hub dataset repository holding one subfolder per dataset.
        index_dir: Local directory the files are downloaded into.

    Raises:
        KeyError: A key of ``dataset_dict`` is missing from ``DATASET_NAME``.
    """
    for key in dataset_dict:
        subfolder = DATASET_NAME[key]
        hf_hub_download(
            repo_id=repo_id,
            subfolder=subfolder,
            filename="text.index",
            repo_type="dataset",
            local_dir=index_dir,
        )
        # NOTE: huggingface_hub matches allow_patterns with fnmatch, where "*"
        # also spans "/". This is expected to include the
        # <subfolder>/metadata/*.parquet files that load_faiss_index reads.
        snapshot_download(
            repo_id=repo_id,
            allow_patterns=f"{subfolder}/*.parquet",
            repo_type="dataset",
            local_dir=index_dir,
        )
def load_faiss_index(
    dataset,
    repo_id="Eun02/text_image_faiss_index",
    index_dir="data/faiss_index",
):
    """Download (if needed) and load the FAISS text index and captions of a dataset.

    Args:
        dataset: UI dataset name, a key of the global ``DATASET_NAME``. It is
            mapped to the subfolder of ``repo_id`` that holds that dataset.
        repo_id: Hub dataset repository holding one subfolder per dataset.
        index_dir: Local directory the files are downloaded into.

    Returns:
        ``(index, text_list)``: the deserialized FAISS index and the list of
        captions, concatenated from the metadata parquet files in sorted
        filename order. ``search_text`` looks captions up by FAISS id, so
        this order is presumed to match the order vectors were added to the
        index (TODO confirm against the index build).

    Raises:
        KeyError: ``dataset`` is not a key of ``DATASET_NAME``.
        FileNotFoundError: No caption parquet files were found after download.
    """
    subfolder = DATASET_NAME[dataset]
    # hf_hub_download returns the local path of the file it stored under
    # index_dir, so the index path does not have to be rebuilt by hand.
    index_path = hf_hub_download(
        repo_id=repo_id,
        subfolder=subfolder,
        filename="text.index",
        repo_type="dataset",
        local_dir=index_dir,
    )
    # Download the caption metadata (parquet files) for this dataset.
    snapshot_download(
        repo_id=repo_id,
        allow_patterns=f"{subfolder}/*.parquet",
        repo_type="dataset",
        local_dir=index_dir,
    )
    index = faiss.read_index(index_path)

    metadata_dir = f"{index_dir}/{subfolder}/metadata"
    metadata_files = sorted(glob.glob(f"{metadata_dir}/*.parquet"))
    if not metadata_files:
        # pd.concat([]) would otherwise fail with "No objects to concatenate".
        raise FileNotFoundError(f"No caption parquet files found in {metadata_dir}")
    # Only the 'caption' column is used, so skip reading the other columns.
    text_list = pd.concat(
        pd.read_parquet(file, columns=["caption"]) for file in metadata_files
    )["caption"].tolist()
    return index, text_list
def change_index(dataset):
    """Make ``dataset`` the active search index.

    Returns immediately when ``dataset`` is already the loaded one. Otherwise
    it loads the dataset with ``load_faiss_index``, rebinds the module-level
    ``INDEX`` / ``TEXT_LIST`` / ``PREV_DATASET`` globals, and shows a
    ``gr.Info`` notification before and after loading.

    NOTE(review): these globals are shared by the whole process. With several
    concurrent users, one user's switch could presumably affect another user's
    in-flight search. Confirm whether the Space's queue serializes events.
    """
    global INDEX, TEXT_LIST, PREV_DATASET
    if dataset == PREV_DATASET:
        return
    gr.Info("Load index...")
    INDEX, TEXT_LIST = load_faiss_index(dataset)
    PREV_DATASET = dataset
    gr.Info("Done!!")
@torch.inference_mode
def get_emb(text, device="cpu"):
    """Embed a single text with the global CLIP model.

    The text is tokenized as a batch of one (``truncate=True`` cuts it to
    CLIP's context length instead of raising), encoded by ``CLIP_MODEL``,
    and scaled to unit L2 norm along the last dimension. The result is
    returned as a float32 NumPy array on the CPU with one row for the input.

    Args:
        text: Query string to embed.
        device: Torch device the tokens are moved to before encoding.
    """
    tokens = clip.tokenize([text], truncate=True).to(device)
    embedding = CLIP_MODEL.encode_text(tokens)
    embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding.cpu().numpy().astype('float32')
@torch.inference_mode
def search_text(top_k, show_score, numbering_prefix, output_file, query_text):
    """Return the captions in the active index that are nearest to ``query_text``.

    Args:
        top_k: How many neighbours to request from FAISS. It is cast to
            ``int`` because the FAISS search call needs an integer ``k``, and
            the Gradio slider value may arrive as a float.
        show_score: Append ``", <score>"`` (two decimals) to every result.
        numbering_prefix: Put a ``### n ###`` banner line before every result.
        output_file: Also write the results to ``./output.txt`` and return
            that path.
        query_text: Text to embed with CLIP and search for.

    Returns:
        ``(result_str, output_path)``: the newline-terminated results, and the
        written file's path, or ``None`` when ``output_file`` is false.

    Raises:
        gr.Error: The query is missing, empty, or whitespace only.
    """
    if query_text is None or not query_text.strip():
        raise gr.Error("Query text is missing")

    text_embeddings = get_emb(query_text, device)
    scores, ids = INDEX.search(text_embeddings, int(top_k))

    # Collect (caption, score) pairs in rank order. FAISS pads with id -1
    # when it returns fewer than k hits. Indexing TEXT_LIST with -1 would
    # silently yield the last caption, so those slots are skipped. Empty
    # captions and exact (caption, score) repeats are dropped.
    result_list = []
    seen = set()
    for score, ind in zip(scores[0], ids[0]):
        if ind < 0:
            continue
        item_str = TEXT_LIST[ind].strip()
        if not item_str or (item_str, score) in seen:
            continue
        seen.add((item_str, score))
        result_list.append((item_str, score))

    # Post-process into one newline-terminated entry per result.
    lines = []
    for rank, (item_str, score) in enumerate(result_list, start=1):
        if numbering_prefix:
            item_str = f"###################### {rank} ######################\n {item_str}"
        if show_score:
            item_str += f", {score:0.2f}"
        lines.append(f"{item_str}\n")
    result_str = "".join(lines)

    output_path = None
    if output_file:
        # NOTE(review): every request writes to this same path, so concurrent
        # users could overwrite each other's file before Gradio serves it.
        output_path = "./output.txt"
        # Use explicit UTF-8: captions/prompts may contain non-ASCII text, and
        # the platform default encoding is not guaranteed to handle it.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(result_str)
    return result_str, output_path
# Load the CLIP ViT-B/32 model on the CPU. search_text passes this device to get_emb.
device = "cpu"
CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)
# Datasets: UI name -> subfolder of the Hub index repository (used by load_faiss_index).
DATASET_NAME = {
    "danbooru22": "booru22_000-300",
    "DiffusionDB": "diffusiondb",
}
DEFAULT_DATASET = "danbooru22"
# Name of the dataset whose index is currently loaded into INDEX/TEXT_LIST.
# change_index skips reloading when the selection equals this value, so it
# must track the index actually loaded below; derive it from DEFAULT_DATASET
# instead of repeating the literal.
PREV_DATASET = DEFAULT_DATASET
# Download needed index
download_all_index(DATASET_NAME)
# Load default index
INDEX, TEXT_LIST = load_faiss_index(DEFAULT_DATASET)
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Derive the choices from DATASET_NAME so the dropdown always lists
        # exactly the datasets load_faiss_index can resolve.
        dataset = gr.Dropdown(label="dataset", choices=list(DATASET_NAME), value=DEFAULT_DATASET)
        top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8)
        with gr.Column():
            show_score = gr.Checkbox(label="Show score", value=False)
            numbering_prefix = gr.Checkbox(label="Add numbering prefix", value=True)
            output_file = gr.Checkbox(label="Return text file", value=True)
    query_text = gr.Textbox(label="query text")
    btn = gr.Button()
    result_text = gr.Textbox(label="retrieved text", interactive=False)
    result_file = gr.File(label="output file", visible=True)
    #dataset.change(change_index, dataset, None)
    # Switch the index first. The search runs only if that step succeeded.
    btn.click(
        fn=change_index,
        inputs=[dataset],
    ).success(
        fn=search_text,
        inputs=[top_k, show_score, numbering_prefix, output_file, query_text],
        outputs=[result_text, result_file],
    )
demo.launch()