Eun0 committed on
Commit
ec82f37
1 Parent(s): f9767c2

Change Hugging Face dataset

Browse files
Files changed (1) hide show
  1. app.py +27 -25
app.py CHANGED
@@ -31,29 +31,29 @@ The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval)
31
  # From huggingface dataset
32
  from huggingface_hub import hf_hub_download, snapshot_download
33
 
34
- # Download index file
35
- hf_hub_download(
36
- repo_id="Eun02/diffusiondb_faiss_text_index",
37
- filename="text.index",
38
- repo_type="dataset",
39
- local_dir="./",
40
- )
41
-
42
- # Download text file
43
- snapshot_download(
44
- repo_id="Eun02/diffusiondb_faiss_text_index",
45
- allow_patterns="*.parquet",
46
- repo_type="dataset",
47
- local_dir="./",
48
- )
49
-
50
- # Load index and text data
51
- #root_path = "dataset/diffusiondb/text_index_folder"
52
- root_path = "."
53
- IND = faiss.read_index(f"{root_path}/text.index")
54
- TEXT_LIST = pd.concat(
55
- pd.read_parquet(file) for file in sorted(glob.glob(f"{root_path}/metadata/*.parquet"))
56
- )['caption'].tolist()
57
 
58
  # Load CLIP model
59
  device = "cpu"
@@ -70,16 +70,18 @@ def get_emb(text, device="cpu"):
70
  @torch.inference_mode
71
  def search_text(dataset, top_k, show_score, query_text, device):
72
 
 
 
73
  if query_text is None or query_text == "":
74
  raise gr.Error("Query text is missing")
75
 
76
  text_embeddings = get_emb(query_text, device)
77
- scores, retrieved_texts = IND.search(text_embeddings, top_k)
78
  scores, retrieved_texts = scores[0], retrieved_texts[0]
79
 
80
  result_str = ""
81
  for score, ind in zip(scores, retrieved_texts):
82
- item_str = TEXT_LIST[ind].strip()
83
  if item_str == "":
84
  continue
85
  result_str += f"{item_str}"
 
31
  # From huggingface dataset
32
  from huggingface_hub import hf_hub_download, snapshot_download
33
 
34
# Per-process cache of already loaded (index, captions) pairs, keyed by
# dataset name. Without it every search would repeat the Hub download calls
# and re-read the index and parquet shards from disk.
_INDEX_CACHE = {}


def load_faiss_index(dataset):
    """Fetch and load the FAISS text index and captions for *dataset*.

    The files come from the ``Eun02/text_image_faiss_index`` dataset repo on
    the Hugging Face Hub and are stored under ``data/faiss_index/<dataset>/``.
    The loaded pair is memoized, so later calls with the same *dataset*
    return the same objects without touching the network or the disk.

    Args:
        dataset: Name of the sub-folder inside the Hub repo to load.

    Returns:
        A tuple ``(index, text_list)``: the object returned by
        ``faiss.read_index`` and the list of values from the ``caption``
        column of all metadata parquet shards, concatenated in sorted
        file-name order. The list is shared between calls; treat it as
        read-only.

    Raises:
        ValueError: If no metadata parquet file is found for *dataset*.
    """
    cached = _INDEX_CACHE.get(dataset)
    if cached is not None:
        return cached

    repo_id = "Eun02/text_image_faiss_index"
    index_dir = "data/faiss_index"

    # Download index file; use the path the Hub client reports instead of
    # rebuilding it by hand.
    index_path = hf_hub_download(
        repo_id=repo_id,
        subfolder=dataset,
        filename="text.index",
        repo_type="dataset",
        local_dir=index_dir,
    )

    # Download text file(s).
    # NOTE(review): the shards are read from "<dataset>/metadata/" below, so
    # this relies on the pattern's "*" also matching across "/" — confirm
    # against the huggingface_hub pattern-matching rules.
    snapshot_download(
        repo_id=repo_id,
        allow_patterns=f"{dataset}/*.parquet",
        repo_type="dataset",
        local_dir=index_dir,
    )

    index = faiss.read_index(index_path)

    parquet_files = sorted(glob.glob(f"{index_dir}/{dataset}/metadata/*.parquet"))
    if not parquet_files:
        # pd.concat would fail with an opaque "No objects to concatenate".
        raise ValueError(
            f"No metadata parquet files found for dataset {dataset!r} "
            f"under {index_dir}/{dataset}/metadata"
        )

    # Only the captions are needed, so skip loading the other columns.
    text_list = pd.concat(
        pd.read_parquet(file, columns=["caption"]) for file in parquet_files
    )["caption"].tolist()

    loaded = (index, text_list)
    _INDEX_CACHE[dataset] = loaded
    return loaded
57
 
58
  # Load CLIP model
59
  device = "cpu"
 
70
  @torch.inference_mode
71
  def search_text(dataset, top_k, show_score, query_text, device):
72
 
73
+ ind, text_list = load_faiss_index(dataset)
74
+
75
  if query_text is None or query_text == "":
76
  raise gr.Error("Query text is missing")
77
 
78
  text_embeddings = get_emb(query_text, device)
79
+ scores, retrieved_texts = ind.search(text_embeddings, top_k)
80
  scores, retrieved_texts = scores[0], retrieved_texts[0]
81
 
82
  result_str = ""
83
  for score, ind in zip(scores, retrieved_texts):
84
+ item_str = text_list[ind].strip()
85
  if item_str == "":
86
  continue
87
  result_str += f"{item_str}"