File size: 1,672 Bytes
792ca73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
from safetensors import safe_open
from datasets import load_dataset
import torch
from multilingual_clip import pt_multilingual_clip
import transformers
import gradio as gr


def load_embeddings(file_path, key="vectors"):
    """Read a single tensor from a safetensors file as a NumPy array.

    Args:
        file_path: Path of the safetensors file to open.
        key: Name of the tensor to fetch from the file.

    Returns:
        The tensor stored under ``key``, loaded via the NumPy framework.
    """
    with safe_open(file_path, framework="numpy") as tensor_file:
        return tensor_file.get_tensor(key)


# Precomputed image embeddings. find_images uses a row's position as the
# index into the dataset, so row i is treated as the embedding of item i.
image_embeddings = load_embeddings("clothes_desc.safetensors")


# Scale every row to unit L2 norm so that a plain dot product against a
# unit-length query ranks rows by cosine similarity.
_row_norms = np.linalg.norm(image_embeddings, axis=1, keepdims=True)
image_embeddings = image_embeddings / _row_norms


# Only the "train" split is used to look up the result images.
ds = load_dataset("wbensvage/clothes_desc")["train"]

# The text encoder and its tokenizer are both loaded from the same
# checkpoint name.
model_name = "M-CLIP/LABSE-Vit-L-14"
model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)


def encode_text(texts, model, tokenizer):
    """Embed a batch of texts and return unit-length NumPy row vectors.

    Args:
        texts: Batch of input texts handed to ``model.forward``.
        model: Object whose ``forward(texts, tokenizer)`` returns a torch
            tensor with one row per text.
        tokenizer: Tokenizer passed through to ``model.forward`` unchanged.

    Returns:
        A 2-D NumPy array in which each row has been divided by its L2 norm.
    """
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        raw_output = model.forward(texts, tokenizer)

    vectors = raw_output.detach().cpu().numpy()
    row_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / row_norms


def find_images(query, top_k):
    """Return the dataset images whose embeddings best match a text query.

    The query is embedded with the module-level ``model`` and ``tokenizer``
    and compared against every row of the module-level ``image_embeddings``.

    Args:
        query: Search text to embed.
        top_k: Maximum number of images to return. Any numeric value is
            accepted; it is truncated to an int and negative values are
            treated as zero.

    Returns:
        A list of the ``"image"`` fields of the best-matching rows of ``ds``,
        ordered from most to least similar.
    """
    # The value comes from a Gradio Slider, which is documented to pass a
    # float; a float slice bound raises TypeError, so coerce it first. The
    # clamp keeps a negative value from slicing from the end instead.
    limit = max(int(top_k), 0)

    query_embedding = encode_text([query], model, tokenizer)
    # Query and image rows are both unit-normalised in this file, so the dot
    # product is their cosine similarity.
    similarity = np.dot(query_embedding, image_embeddings.T)
    # Negate so that argsort's ascending order puts the best match first.
    top_k_indices = np.argsort(-similarity[0])[:limit]
    return [ds[int(i)]["image"] for i in top_k_indices]


# Input widgets: the free-text query and the number of results to show.
_query_box = gr.Textbox(
    lines=2,
    placeholder="Enter search text here...",
    label="Query",
)
_count_slider = gr.Slider(
    minimum=10,
    maximum=50,
    step=10,
    value=20,
    label="Number of images",
)

# Output widget: the matching images laid out five to a row.
_results_gallery = gr.Gallery(label="Search Results", columns=5, height="auto")

# Wire find_images to the widgets; the inputs are passed to it positionally
# as (query, top_k).
iface = gr.Interface(
    fn=find_images,
    inputs=[_query_box, _count_slider],
    outputs=_results_gallery,
    title="Multilingual CLIP Image Search",
    description="Enter a text query",
)

# Start the web server for the interface.
iface.launch()