"""Gradio demo: retrieve the candidate images that best match a scribble.

Candidate images are embedded once with ``image_encoder`` and kept in the
module-level ``candidates`` dataset. A scribble drawn in the UI is embedded
with ``scribble_encoder`` and ranked against the candidates by cosine
similarity.
"""

from typing import Optional

import gradio as gr
import torch
from datasets import Dataset
from PIL import Image, ImageOps
from torch.nn import CosineSimilarity
from transformers import ViTImageProcessor, ViTModel

# Maximum number of best-matching candidates shown in the gallery.
TOP_K = 15

image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval()
# NOTE(review): "scibble" looks like a typo, but the path is left untouched
# because it has to match the checkpoint directory on disk -- confirm.
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval()

# Embedded candidate images; stays None until load_candidates_in_cache() runs.
candidates: Optional[Dataset] = None
cosinesimilarity = CosineSimilarity()


def load_candidates(candidate_dir, progress=gr.Progress()):
    """Load the uploaded images and compute an embedding for each one.

    Args:
        candidate_dir: Iterable of uploaded-file objects; only their ``.name``
            attribute (a path readable by PIL) is used.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is the
            Gradio idiom for progress injection, not a shared mutable default.

    Returns:
        A ``Dataset`` with an ``image`` column (RGB images resized to 224x224)
        and an ``image_embedding`` column (``pooler_output`` of
        ``image_encoder``).
    """

    def preprocess(examples):
        # Embed one batch of images and advance the progress bar.
        images = list(examples["image"])
        pixel_values = image_processor(images, return_tensors="pt")["pixel_values"]
        examples["image_embedding"] = image_encoder(pixel_values)["pooler_output"]
        progress.update(len(images))
        return examples

    records = []
    for uploaded in progress.tqdm(candidate_dir):
        # Close every file handle right away; convert()/resize() return
        # independent in-memory images, so nothing depends on the open file.
        with Image.open(uploaded.name) as img:
            records.append(dict(image=img.convert("RGB").resize((224, 224))))

    dataset = Dataset.from_list(records)
    # Register the dataset as a tracked iterable without consuming it, so the
    # progress.update() calls in preprocess() advance a bar of the right length.
    progress.tqdm(dataset)
    with torch.no_grad():
        dataset = dataset.map(preprocess, batched=True, batch_size=1)
    return dataset


def load_candidates_in_cache(candidate_files, progress=gr.Progress()):
    """Embed the uploaded files and store them in the global ``candidates``.

    Args:
        candidate_files: Uploaded-file objects coming from the ``gr.File``
            component; may be ``None`` or empty when nothing was uploaded.
        progress: Gradio progress tracker. It is declared on this function
            because this is the event handler Gradio inspects; it is forwarded
            to ``load_candidates`` so the loading progress shows up in the UI.

    Returns:
        The list of file names, echoed back to the ``gr.File`` component.

    Raises:
        gr.Error: If no files were uploaded.
    """
    global candidates
    if not candidate_files:
        raise gr.Error("Please upload a directory of candidate images first.")
    candidates = load_candidates(candidate_files, progress)
    return [f.name for f in candidate_files]


def scribble_matching(input_img: Image.Image):
    """Rank the loaded candidates against a scribble.

    Args:
        input_img: The scribble drawn in the UI as a PIL image, or ``None``
            when nothing was drawn.

    Returns:
        A list of ``(image, caption)`` pairs for ``gr.Gallery``: first the
        inverted scribble captioned ``"preview"``, then up to ``TOP_K``
        candidates captioned with their cosine similarity (3 decimals), best
        match first.

    Raises:
        gr.Error: If there is no scribble or no candidates have been loaded.
    """
    if input_img is None:
        raise gr.Error("Please draw a scribble first.")
    if candidates is None or len(candidates) == 0:
        raise gr.Error("Please load candidate images first.")

    # ImageOps.invert does not accept every image mode; normalising to RGB
    # first is a no-op for input that is already RGB.
    input_img = ImageOps.invert(input_img.convert("RGB"))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pixel_values = image_processor(input_img, return_tensors="pt")["pixel_values"]
        scribble_embedding = scribble_encoder(pixel_values)["pooler_output"].to("cpu")

    image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
    sim = cosinesimilarity(scribble_embedding, image_embeddings)
    # torch.topk raises when k exceeds the number of candidates, so clamp it.
    predicts = torch.topk(sim, k=min(TOP_K, len(candidates)))

    output_imgs = candidates[predicts.indices.tolist()]["image"]
    labels = [f"{score:.3f}" for score in predicts.values.tolist()]
    return list(zip([input_img] + output_imgs, ["preview"] + labels))


def main():
    """Build the Gradio interface and launch it with queuing enabled."""
    with gr.Blocks() as demo:
        with gr.Row():
            input_img = gr.Image(
                type="pil",
                label="scribble",
                height=512,
                width=512,
                source="canvas",
                tool="color-sketch",
                brush_radius=10,
            )
            prediction_gallery = gr.Gallery(min_width=512, columns=4, show_label=True)
        with gr.Row():
            candidate_dir = gr.File(file_count="directory", min_width=300, height=300)
            load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
            btn = gr.Button("Scribble Matching", variant="primary")
        load_candidates_btn.click(
            fn=load_candidates_in_cache,
            inputs=[candidate_dir],
            outputs=candidate_dir,
        )
        btn.click(
            fn=scribble_matching,
            inputs=[input_img],
            outputs=[prediction_gallery],
        )
    demo.queue().launch()


if __name__ == "__main__":
    main()