Samuel Schmidt commited on
Commit
a329e3c
·
1 Parent(s): b56a778

Removed print statements, increased batch size

Browse files
Files changed (2) hide show
  1. src/CLIP.py +0 -6
  2. src/app.py +1 -2
src/CLIP.py CHANGED
@@ -16,13 +16,7 @@ class CLIPImageEncoder:
16
 
17
  def encode_images(self, batch):
18
  images = batch["image"]
19
- print(images)
20
  input = self.processor(images=images, return_tensors="pt")
21
- print(input)
22
  with torch.no_grad():
23
  image_features = self.model.get_image_features(**input)
24
- #image_features = self.model(**input).last_hidden_state[:,0].cpu()
25
- print(image_features)
26
- print("--------------------")
27
- print(self.model.get_image_features(**input).cpu().detach().numpy())
28
  return {"clip_embeddings": image_features.cpu().detach().numpy()}
 
16
 
17
  def encode_images(self, batch):
18
  images = batch["image"]
 
19
  input = self.processor(images=images, return_tensors="pt")
 
20
  with torch.no_grad():
21
  image_features = self.model.get_image_features(**input)
 
 
 
 
22
  return {"clip_embeddings": image_features.cpu().detach().numpy()}
src/app.py CHANGED
@@ -19,7 +19,7 @@ def emb_dataset(dataset):
19
 
20
  ## CLIP Embeddings
21
  clip_model = CLIPImageEncoder()
22
- dataset_with_embeddings = dataset_with_embeddings.map(clip_model.encode_images, batched=True, batch_size=8)
23
 
24
  # Add index
25
  dataset_with_embeddings.add_faiss_index(column='color_embeddings')
@@ -32,7 +32,6 @@ def emb_dataset(dataset):
32
  dataset_with_embeddings = emb_dataset(candidate_subset)
33
 
34
  # Main function, to find similar images
35
- # TODO: allow different descriptor/embedding functions
36
  # TODO: implement different distance measures
37
 
38
  def get_neighbors(query_image, selected_descriptor, top_k=5):
 
19
 
20
  ## CLIP Embeddings
21
  clip_model = CLIPImageEncoder()
22
+ dataset_with_embeddings = dataset_with_embeddings.map(clip_model.encode_images, batched=True, batch_size=16)
23
 
24
  # Add index
25
  dataset_with_embeddings.add_faiss_index(column='color_embeddings')
 
32
  dataset_with_embeddings = emb_dataset(candidate_subset)
33
 
34
  # Main function, to find similar images
 
35
  # TODO: implement different distance measures
36
 
37
  def get_neighbors(query_image, selected_descriptor, top_k=5):