Spaces:
Running
Running
Refactor
Browse files
app.py
CHANGED
@@ -59,7 +59,7 @@ def get_image_features(model, image_dir):
|
|
59 |
|
60 |
loader = torch.utils.data.DataLoader(
|
61 |
dataset,
|
62 |
-
batch_size=
|
63 |
shuffle=False,
|
64 |
num_workers=4,
|
65 |
drop_last=False,
|
@@ -103,7 +103,8 @@ def text_encoder(text, tokenizer):
|
|
103 |
return jnp.expand_dims(embedding, axis=0)
|
104 |
|
105 |
|
106 |
-
|
|
|
107 |
image_features = []
|
108 |
for i, (images) in enumerate(tqdm(loader)):
|
109 |
images = images.permute(0, 2, 3, 1).numpy()
|
@@ -145,8 +146,32 @@ if query:
|
|
145 |
"dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
|
146 |
)
|
147 |
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
-
image_paths = find_image(query, dataset, tokenizer, image_features, n=
|
151 |
|
152 |
st.image(image_paths)
|
|
|
59 |
|
60 |
loader = torch.utils.data.DataLoader(
|
61 |
dataset,
|
62 |
+
batch_size=16,
|
63 |
shuffle=False,
|
64 |
num_workers=4,
|
65 |
drop_last=False,
|
|
|
103 |
return jnp.expand_dims(embedding, axis=0)
|
104 |
|
105 |
|
106 |
+
@st.cache
|
107 |
+
def precompute_image_features(model, loader):
|
108 |
image_features = []
|
109 |
for i, (images) in enumerate(tqdm(loader)):
|
110 |
images = images.permute(0, 2, 3, 1).numpy()
|
|
|
146 |
"dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
|
147 |
)
|
148 |
|
149 |
+
image_size = model.config.vision_config.image_size
|
150 |
+
|
151 |
+
val_preprocess = transforms.Compose(
|
152 |
+
[
|
153 |
+
Resize([image_size], interpolation=InterpolationMode.BICUBIC),
|
154 |
+
CenterCrop(image_size),
|
155 |
+
ToTensor(),
|
156 |
+
Normalize(
|
157 |
+
(0.48145466, 0.4578275, 0.40821073),
|
158 |
+
(0.26862954, 0.26130258, 0.27577711),
|
159 |
+
),
|
160 |
+
]
|
161 |
+
)
|
162 |
+
|
163 |
+
dataset = CustomDataSet("photos/", transform=val_preprocess)
|
164 |
+
|
165 |
+
loader = torch.utils.data.DataLoader(
|
166 |
+
dataset,
|
167 |
+
batch_size=16,
|
168 |
+
shuffle=False,
|
169 |
+
num_workers=2,
|
170 |
+
drop_last=False,
|
171 |
+
)
|
172 |
+
|
173 |
+
image_features = precompute_image_features(model, loader)
|
174 |
|
175 |
+
image_paths = find_image(query, dataset, tokenizer, image_features, n=2)
|
176 |
|
177 |
st.image(image_paths)
|