"""Gradio demo: retrieve similar images from a FAISS-indexed dataset.

The query may be an image, a text string, or both.  The encoder and the
dataset are loaded lazily on the first call to ``search`` and then cached in
the module-level ``pipe`` and ``sample`` globals.
"""
import gradio as gr
import faiss
from datasets import load_dataset
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import numpy as np
import os
import torch

# Lazily initialised singletons (populated on the first call to ``search``).
pipe = None    # SentenceTransformer encoder built from ``model_repo``
sample = None  # 'train' split of the dataset, with a FAISS index attached

model_repo = 'Solenya-ai/dill'

# Weight given to the (normalized) text embedding relative to the image
# embedding when both inputs are supplied.
TEXT_WEIGHT = 2
# Number of nearest neighbours requested from the index.
TOP_K = 10


def _unit(vec):
    """Return ``vec`` scaled to unit L2 norm (unchanged if its norm is 0)."""
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec


def _load_resources():
    """Load the encoder and the indexed dataset once, caching them globally.

    Raises:
        KeyError: if the ``data_id`` environment variable is not set.
    """
    global pipe
    global sample
    if pipe is None:
        pipe = SentenceTransformer(model_repo)
    if sample is None:
        # Build the index on a local first so that a failure while indexing
        # does not leave a half-initialised (un-indexed) dataset cached.
        dataset = load_dataset(os.environ['data_id'])['train']
        dataset.add_faiss_index(column='embeddings',
                                metric_type=faiss.METRIC_INNER_PRODUCT)
        sample = dataset


def search(image=None, query=None):
    """Return the images whose stored embeddings best match the query.

    Args:
        image: optional image passed to the encoder (the UI below supplies a
            PIL image).
        query: optional text query.  A blank / whitespace-only string is
            treated the same as ``None``.

    Returns:
        The ``'image'`` column of the ``TOP_K`` nearest examples, or an empty
        list when neither an image nor a non-blank query was provided.

    Raises:
        KeyError: on first use, if the ``data_id`` environment variable is
            not set.
    """
    # An empty textbox arrives as "" rather than None; without this check a
    # blank string would be embedded and (with an image present) blended in
    # at TEXT_WEIGHT, drowning out the image.
    if isinstance(query, str) and not query.strip():
        query = None

    # Nothing to search for: the original code crashed here with an
    # UnboundLocalError because no embedding had been computed.
    if image is None and query is None:
        return []

    _load_resources()

    text_embed = np.array(pipe.encode(query)) if query is not None else None
    image_embed = np.array(pipe.encode(image)) if image is not None else None

    if text_embed is not None and image_embed is not None:
        # Normalize each modality before mixing so the weighting is
        # meaningful, then renormalize the blend.
        embed = _unit(TEXT_WEIGHT * _unit(text_embed) + _unit(image_embed))
    elif image_embed is not None:
        embed = image_embed
    else:
        embed = text_embed

    # FAISS operates on float32 vectors; make the dtype explicit.
    embed = np.asarray(embed, dtype=np.float32)

    scores, retrieved_examples = sample.get_nearest_examples(
        'embeddings', embed, k=TOP_K)
    return retrieved_examples['image']


# ``gr.Image`` / ``gr.Textbox`` replace the deprecated ``gr.inputs.*``
# namespace, which is no longer available in Gradio 4.
iface = gr.Interface(
    search,
    [
        gr.Image(type="pil"),
        gr.Textbox(lines=1,
                   placeholder="Type here to search for similar images..."),
    ],
    gr.Gallery(columns=3),
    title="Search for similar images",
)

iface.launch(share=False)