adorkin committed on
Commit
792ca73
1 Parent(s): 84fd25c

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +59 -0
  2. clothes_desc.safetensors +3 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from safetensors import safe_open
3
+ from datasets import load_dataset
4
+ import torch
5
+ from multilingual_clip import pt_multilingual_clip
6
+ import transformers
7
+ import gradio as gr
8
+ import clip
9
+
10
+
11
def load_embeddings(file_path, key="vectors"):
    """Read a single tensor from a safetensors file.

    Args:
        file_path: Path to the ``.safetensors`` file to open.
        key: Name of the tensor stored in the file (defaults to ``"vectors"``).

    Returns:
        The tensor stored under ``key``, loaded through the ``"numpy"``
        framework backend of ``safe_open``.
    """
    with safe_open(file_path, framework="numpy") as handle:
        return handle.get_tensor(key)
15
+
16
+
17
# Precomputed embedding matrix, one row per item.
# NOTE(review): rows are presumably aligned with the indices of `ds` below,
# since search results index `ds` by row number -- confirm against the
# script that produced the safetensors file.
image_embeddings = load_embeddings("clothes_desc.safetensors")

# Scale every row to unit L2 norm so that a dot product with a unit-norm
# query vector is the cosine similarity.
_row_norms = np.linalg.norm(image_embeddings, axis=1, keepdims=True)
image_embeddings = image_embeddings / _row_norms

# Dataset that supplies the images returned to the user.
ds = load_dataset("wbensvage/clothes_desc")["train"]

# Text encoder and its matching tokenizer, both resolved from the same name.
model_name = "M-CLIP/LABSE-Vit-L-14"
model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
30
+
31
+
32
def encode_text(texts, model, tokenizer):
    """Encode texts into L2-normalised embedding vectors.

    Args:
        texts: A list of strings, or a single string (treated as a
            one-element batch).
        model: Object exposing ``forward(texts, tokenizer)`` that returns a
            2-D ``torch.Tensor`` with one row per input text.
        tokenizer: Tokenizer handed through to ``model.forward`` unchanged.

    Returns:
        A 2-D NumPy array with one unit-norm row per text. A row whose
        embedding is entirely zero is returned as zeros rather than NaN.
    """
    # A bare string would otherwise be passed to the model as-is; make the
    # batch dimension explicit.
    if isinstance(texts, str):
        texts = [texts]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        embs = model.forward(texts, tokenizer)
    embs = embs.detach().cpu().numpy()

    norms = np.linalg.norm(embs, axis=1, keepdims=True)
    # Guard against 0/0 -> NaN, which would poison every similarity score
    # computed from the affected row.
    norms[norms == 0] = 1.0
    return embs / norms
38
+
39
+
40
def find_images(query, top_k):
    """Return the images whose embeddings best match a text query.

    Args:
        query: Free-text search string.
        top_k: Number of results to return. Accepts an int or a float
            (Gradio sliders pass their value as a float); negative values
            are treated as 0.

    Returns:
        A list of at most ``top_k`` images taken from ``ds``, ordered from
        most to least similar.
    """
    # Slice bounds must be integers; a float from the UI would raise
    # TypeError. Clamp negatives so they do not turn into "all but the
    # last k" via negative slicing.
    top_k = max(int(top_k), 0)

    query_embedding = encode_text([query], model, tokenizer)
    # Single query -> take row 0 of the (1, n_items) similarity matrix.
    scores = np.dot(query_embedding, image_embeddings.T)[0]
    # Negate so that argsort's ascending order yields highest score first.
    top_k_indices = np.argsort(-scores)[:top_k]
    # Cast NumPy integers to plain int before indexing the dataset.
    return [ds[int(i)]["image"] for i in top_k_indices]
46
+
47
+
48
# Input widgets: the free-text query and the number of results to show.
_query_box = gr.Textbox(lines=2, placeholder="Enter search text here...", label="Query")
_count_slider = gr.Slider(10, 50, step=10, value=20, label="Number of images")

# Output widget: search results laid out in a five-column gallery.
_results_gallery = gr.Gallery(label="Search Results", columns=5, height="auto")

# Wire the search function to the widgets above.
iface = gr.Interface(
    fn=find_images,
    inputs=[_query_box, _count_slider],
    outputs=_results_gallery,
    title="Multilingual CLIP Image Search",
    description="Enter a text query",
)

# Start serving the UI.
iface.launch()
clothes_desc.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2d70d7a406ceb193d93bacc36f1e9b83b8c0008ce478cce9826f3cccc702c79
3
+ size 1536088
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ multilingual-clip
4
+ safetensors
5
+ gradio
6
+ numpy
7
+ datasets