tasmiachow commited on
Commit
6c61156
1 Parent(s): 7fc771d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -20,12 +20,12 @@ with torch.no_grad():
20
  def guess_drawing(drawing):
21
  # Access the image data from the 'background' key
22
  if 'background' in drawing:
23
- image_array = np.array(drawing['background'], dtype=np.uint8) # Extract the background data as an array
24
  else:
25
  return "Invalid drawing format. Unable to process."
26
 
27
- # Convert to PIL image
28
- image = Image.fromarray(image_array)
29
 
30
  # Prepare the image for the model
31
  image_inputs = processor(images=image, return_tensors="pt")
@@ -36,11 +36,17 @@ def guess_drawing(drawing):
36
 
37
  # Calculate cosine similarity with each word
38
  similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
 
 
 
 
 
39
  best_match = words[similarity.argmax().item()]
40
 
41
  # Return the AI's best guess
42
  return f"AI's guess: {best_match}"
43
 
 
44
  # Set up Gradio interface
45
  interface = gr.Interface(
46
  fn=guess_drawing,
 
20
  def guess_drawing(drawing):
21
  # Access the image data from the 'background' key
22
  if 'background' in drawing:
23
+ image_array = np.array(drawing['background'], dtype=np.uint8)
24
  else:
25
  return "Invalid drawing format. Unable to process."
26
 
27
+ # Convert to RGB PIL image to ensure compatibility with CLIP
28
+ image = Image.fromarray(image_array).convert("RGB")
29
 
30
  # Prepare the image for the model
31
  image_inputs = processor(images=image, return_tensors="pt")
 
36
 
37
  # Calculate cosine similarity with each word
38
  similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
39
+
40
+ # Debug: Print similarity scores for each word
41
+ for word, score in zip(words, similarity.tolist()):
42
+ print(f"Similarity score for '{word}': {score}")
43
+
44
  best_match = words[similarity.argmax().item()]
45
 
46
  # Return the AI's best guess
47
  return f"AI's guess: {best_match}"
48
 
49
+
50
  # Set up Gradio interface
51
  interface = gr.Interface(
52
  fn=guess_drawing,