Spaces:
Runtime error
Runtime error
Samuel Schmidt
committed on
Commit
·
a329e3c
1
Parent(s):
b56a778
Removed print statements, increased batch size
Browse files
- src/CLIP.py +0 -6
- src/app.py +1 -2
src/CLIP.py
CHANGED
@@ -16,13 +16,7 @@ class CLIPImageEncoder:
|
|
16 |
|
17 |
def encode_images(self, batch):
|
18 |
images = batch["image"]
|
19 |
-
print(images)
|
20 |
input = self.processor(images=images, return_tensors="pt")
|
21 |
-
print(input)
|
22 |
with torch.no_grad():
|
23 |
image_features = self.model.get_image_features(**input)
|
24 |
-
#image_features = self.model(**input).last_hidden_state[:,0].cpu()
|
25 |
-
print(image_features)
|
26 |
-
print("--------------------")
|
27 |
-
print(self.model.get_image_features(**input).cpu().detach().numpy())
|
28 |
return {"clip_embeddings": image_features.cpu().detach().numpy()}
|
|
|
16 |
|
17 |
def encode_images(self, batch):
|
18 |
images = batch["image"]
|
|
|
19 |
input = self.processor(images=images, return_tensors="pt")
|
|
|
20 |
with torch.no_grad():
|
21 |
image_features = self.model.get_image_features(**input)
|
|
|
|
|
|
|
|
|
22 |
return {"clip_embeddings": image_features.cpu().detach().numpy()}
|
src/app.py
CHANGED
@@ -19,7 +19,7 @@ def emb_dataset(dataset):
|
|
19 |
|
20 |
## CLIP Embeddings
|
21 |
clip_model = CLIPImageEncoder()
|
22 |
-
dataset_with_embeddings = dataset_with_embeddings.map(clip_model.encode_images, batched=True, batch_size=
|
23 |
|
24 |
# Add index
|
25 |
dataset_with_embeddings.add_faiss_index(column='color_embeddings')
|
@@ -32,7 +32,6 @@ def emb_dataset(dataset):
|
|
32 |
dataset_with_embeddings = emb_dataset(candidate_subset)
|
33 |
|
34 |
# Main function, to find similar images
|
35 |
-
# TODO: allow different descriptor/embedding functions
|
36 |
# TODO: implement different distance measures
|
37 |
|
38 |
def get_neighbors(query_image, selected_descriptor, top_k=5):
|
|
|
19 |
|
20 |
## CLIP Embeddings
|
21 |
clip_model = CLIPImageEncoder()
|
22 |
+
dataset_with_embeddings = dataset_with_embeddings.map(clip_model.encode_images, batched=True, batch_size=16)
|
23 |
|
24 |
# Add index
|
25 |
dataset_with_embeddings.add_faiss_index(column='color_embeddings')
|
|
|
32 |
dataset_with_embeddings = emb_dataset(candidate_subset)
|
33 |
|
34 |
# Main function, to find similar images
|
|
|
35 |
# TODO: implement different distance measures
|
36 |
|
37 |
def get_neighbors(query_image, selected_descriptor, top_k=5):
|