Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -13,40 +13,40 @@ from PIL import Image
|
|
13 |
import matplotlib.cm as cm
|
14 |
import matplotlib as mpl
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
|
22 |
|
23 |
# For a model pretrained on VGGFace2
|
24 |
print('Loading model weights ........')
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
|
39 |
|
40 |
# Make FAISS index
|
41 |
-
|
42 |
-
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
|
49 |
-
|
50 |
|
51 |
print("Finished indexing")
|
52 |
|
@@ -98,16 +98,16 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
|
98 |
|
99 |
gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
|
112 |
|
113 |
btn.click(image_search,
|
|
|
13 |
import matplotlib.cm as cm
|
14 |
import matplotlib as mpl
|
15 |
|
16 |
+
img_names = []
|
17 |
+
with open('list_eval_partition.txt', 'r') as f:
|
18 |
+
for line in f:
|
19 |
+
img_name, dtype = line.rstrip().split(' ')
|
20 |
+
img_names.append(img_name)
|
21 |
|
22 |
|
23 |
# For a model pretrained on VGGFace2
|
24 |
print('Loading model weights ........')
|
25 |
|
26 |
+
class SiameseModel(nn.Module):
|
27 |
+
def __init__(self):
|
28 |
+
super().__init__()
|
29 |
+
self.backbone = InceptionResnetV1(pretrained='vggface2')
|
30 |
+
def forward(self, x):
|
31 |
+
x = self.backbone(x)
|
32 |
+
x = torch.nn.functional.normalize(x, dim=1)
|
33 |
+
return x
|
34 |
|
35 |
+
model = SiameseModel()
|
36 |
+
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
|
37 |
+
model.eval()
|
38 |
|
39 |
|
40 |
# Make FAISS index
|
41 |
+
print('Make index .............')
|
42 |
+
index = faiss.IndexFlatL2(512)
|
43 |
|
44 |
+
hf = h5py.File('face_vecs_full.h5', 'r')
|
45 |
+
for key in tqdm.tqdm(hf.keys()):
|
46 |
+
vec = np.array(hf.get(key))
|
47 |
+
index.add(vec)
|
48 |
|
49 |
+
hf.close()
|
50 |
|
51 |
print("Finished indexing")
|
52 |
|
|
|
98 |
|
99 |
gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)
|
100 |
|
101 |
+
img_dir = './img_align_celeba'
|
102 |
+
examples = random.choices(img_names, k=6)
|
103 |
+
examples = [os.path.join(img_dir, ex) for ex in examples]
|
104 |
+
examples = [Image.open(img) for img in examples]
|
105 |
|
106 |
+
with gr.Row():
|
107 |
+
gr.Examples(
|
108 |
+
examples = examples,
|
109 |
+
inputs = image
|
110 |
+
)
|
111 |
|
112 |
|
113 |
btn.click(image_search,
|