Spaces:
Running
Running
File size: 3,134 Bytes
b4f7e81 a451dfa b4f7e81 3d49071 b4f7e81 3d49071 b4f7e81 3d49071 b4f7e81 3d49071 b4f7e81 3d49071 b4f7e81 3d49071 b4f7e81 0272604 b4f7e81 7e13603 b4f7e81 3d49071 d56b828 3d49071 b4f7e81 3d49071 b4f7e81 f130359 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import gradio as gr
from facenet_pytorch import InceptionResnetV1
import torch.nn as nn
import torchvision.transforms as tf
import numpy as np
import torch
import faiss
import h5py
import tqdm
import os
import random
from PIL import Image
import matplotlib.cm as cm
import matplotlib as mpl
def _read_image_names(path):
    """Return the first whitespace-separated field of every non-blank line.

    Each line of ``path`` is expected to look like ``<file name> <other>``;
    only the file name is kept. The second field (presumably the CelebA
    train/val/test partition id) is not needed here. Blank lines are skipped
    instead of raising ``ValueError`` on tuple unpacking.
    """
    names = []
    with open(path, 'r', encoding='utf-8') as fh:
        for raw in fh:
            fields = raw.split()
            if fields:  # tolerate blank lines such as a trailing empty line
                names.append(fields[0])
    return names


# img_names[i] is the file name that image_search shows for FAISS row i.
img_names = _read_image_names('list_eval_partition.txt')
# For a model pretrained on VGGFace2
print('Loading model weights ........')


class SiameseModel(nn.Module):
    """Embedding network: facenet_pytorch InceptionResnetV1 + L2 normalisation.

    The backbone is built with ``pretrained='vggface2'``. ``forward`` returns
    the backbone output with every row scaled to unit L2 norm.
    """

    def __init__(self):
        super().__init__()
        self.backbone = InceptionResnetV1(pretrained='vggface2')

    def forward(self, x):
        """Embed the batch ``x`` and normalise each row (dim 1) to unit L2 norm."""
        embeddings = self.backbone(x)
        return nn.functional.normalize(embeddings, dim=1)
model = SiameseModel()
# NOTE(review): torch.load unpickles model.pt, so only load a trusted file.
# Newer torch releases accept weights_only=True to restrict unpickling;
# consider it once the deployed torch version is confirmed to support it.
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
model.eval()

# Make FAISS index: exact (brute-force) L2 search over 512-d vectors.
print('Make index .............')
index = faiss.IndexFlatL2(512)
# NOTE(review): image_search maps result row i to img_names[i], so the order
# of hf.keys() must match list_eval_partition.txt. h5py normally iterates
# group keys in name order (unless written with track_order=True); confirm
# the keys sort the same way (e.g. zero-padded names).
with h5py.File('face_vecs_full.h5', 'r') as hf:  # closes even if add() fails
    for key in tqdm.tqdm(hf.keys()):
        # Reshape so a dataset stored as a single (512,) vector is added as a
        # one-row batch; an (n, 512) dataset is left unchanged.
        vec = np.array(hf.get(key)).reshape(-1, index.d)
        index.add(vec)
print("Finished indexing")
# Function to search image
def image_search(image, k=5):
    """Return the ``k`` indexed face images nearest to ``image``.

    Args:
        image: query as a PIL image (the Gradio input uses ``type='pil'``),
            or ``None`` when the user pressed Search without an image.
        k: number of neighbours to retrieve. Coerced to ``int`` because the
            Gradio slider may deliver a float.

    Returns:
        A list of ``(PIL.Image, '')`` tuples for ``gr.Gallery``, ordered as
        returned by the FAISS search. Empty when ``image`` is ``None``.
    """
    if image is None:
        return []

    # NOTE(review): no facenet fixed_image_standardization is applied; this
    # presumably matches how face_vecs_full.h5 was built. Keep the two in sync.
    transform = tf.Compose([
        tf.Resize((160, 160)),
        tf.ToTensor(),
    ])
    # Force 3 channels: RGBA or greyscale uploads would otherwise yield a
    # 4- or 1-channel tensor that the backbone cannot consume.
    query = transform(image.convert('RGB')).unsqueeze(0)

    model.eval()
    with torch.no_grad():  # inference only; skip building an autograd graph
        query_vec = model(query).numpy()
    _, ids = index.search(query_vec, k=int(k))

    folder = 'img_align_celeba'
    results = []
    for idx in ids[0]:
        if idx < 0:  # FAISS pads with -1 when fewer than k vectors are indexed
            continue
        hit = Image.open(os.path.join(folder, img_names[idx]))
        hit.load()  # read pixels now so Pillow can release the file handle
        results.append((hit, ''))
    return results
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown('''
# Face Image Retrieval with Content-based Image Retrieval (CBIR)
--------
''')
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', scale=1)
            slider = gr.Slider(1, 10, value=5, step=1, label='Number of retrieval image')
            with gr.Row():
                btn = gr.Button('Search')
                clear_btn = gr.ClearButton()
        gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)

    img_dir = './img_align_celeba'
    # Pick distinct examples: random.choices samples with replacement and could
    # show the same face twice. Cap k so sample() cannot raise on short lists.
    example_names = random.sample(img_names, k=min(5, len(img_names)))
    examples = [Image.open(os.path.join(img_dir, name)) for name in example_names]
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=image,
        )

    # The slider value is passed to image_search as its k argument.
    btn.click(image_search,
              inputs=[image, slider],
              outputs=[gallery])

    def clear_image():
        """Reset the query image component to empty."""
        return None

    clear_btn.click(
        fn=clear_image,
        inputs=[],
        outputs=[image],
    )
if __name__ == "__main__":
    # Serve the Gradio app bound to all interfaces ("0.0.0.0") on port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )