tasmiachow commited on
Commit
50d0f30
1 Parent(s): 42cfb33
Files changed (1) hide show
  1. app.py +14 -2
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from transformers import CLIPProcessor, CLIPModel
3
  from PIL import Image
 
4
  import torch
5
 
6
  # Load CLIP model and processor
@@ -17,17 +18,28 @@ with torch.no_grad():
17
 
18
 
19
  def guess_drawing(drawing):
20
- image = Image.fromarray(drawing) # Convert drawing to PIL image
 
 
 
 
 
 
 
21
  image_inputs = processor(images=image, return_tensors="pt")
 
 
22
  with torch.no_grad():
23
  image_features = model.get_image_features(**image_inputs)
24
 
25
  # Calculate cosine similarity with each word
26
  similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
27
  best_match = words[similarity.argmax().item()]
 
 
28
  return f"AI's guess: {best_match}"
29
 
30
-
31
  interface = gr.Interface(
32
  fn=guess_drawing,
33
  inputs=gr.Sketchpad(),
 
1
  import gradio as gr
2
  from transformers import CLIPProcessor, CLIPModel
3
  from PIL import Image
4
+ import numpy as np
5
  import torch
6
 
7
  # Load CLIP model and processor
 
18
 
19
 
20
  def guess_drawing(drawing):
21
+
22
+ drawing_data = drawing['data']
23
+ image_array = np.array(drawing_data, dtype=np.uint8)
24
+
25
+
26
+ image = Image.fromarray(image_array)
27
+
28
+
29
  image_inputs = processor(images=image, return_tensors="pt")
30
+
31
+
32
  with torch.no_grad():
33
  image_features = model.get_image_features(**image_inputs)
34
 
35
  # Calculate cosine similarity with each word
36
  similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
37
  best_match = words[similarity.argmax().item()]
38
+
39
+ # Return the AI's best guess
40
  return f"AI's guess: {best_match}"
41
 
42
+ # Set up Gradio interface
43
  interface = gr.Interface(
44
  fn=guess_drawing,
45
  inputs=gr.Sketchpad(),