File size: 2,038 Bytes
42cfb33
 
 
50d0f30
42cfb33
 
 
 
 
 
 
c7b34db
 
42cfb33
5bf9861
42cfb33
 
 
 
5bf9861
42cfb33
7fc771d
 
6c61156
7fc771d
 
5bf9861
6c61156
 
7fc771d
 
 
 
 
 
 
 
 
 
6c61156
 
 
 
 
7fc771d
 
 
 
42cfb33
6c61156
50d0f30
42cfb33
 
 
 
 
c7b34db
42cfb33
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
import torch

# Pretrained CLIP checkpoint used for both the image/text encoders and the
# matching preprocessing pipeline.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Vocabulary for the game: every guess the AI makes is one of these words.
words = [
    "cat", "car", "tree", "house", "dog", "cloud", "flower",
    "bicycle", "boat", "star", "bird", "fish", "sun",
]

# Encode the whole vocabulary once at start-up so that each guess only needs
# an image forward pass; gradients are never needed for inference.
text_inputs = processor(text=words, return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = model.get_text_features(**text_inputs)

# Define the function to process drawing and make a prediction
# Messages shown in the output textbox when no prediction can be made.
_INVALID_DRAWING_MSG = "Invalid drawing format. Unable to process."
_EMPTY_DRAWING_MSG = "Draw something to get a guess."


def _extract_image_array(drawing):
    """Return the drawn image from a Sketchpad value as a uint8 array, or None.

    A dict value (Gradio 4 Sketchpad/ImageEditor: ``background``, ``layers``,
    ``composite``) is read from ``composite`` -- the background with the
    user's stroke layers flattened on top.  ``background`` alone holds no
    strokes, so it is only a fallback.  A bare numpy array is accepted as-is.
    Returns None for any other value.
    """
    if isinstance(drawing, dict):
        image_data = drawing.get("composite")
        if image_data is None:
            image_data = drawing.get("background")
    elif isinstance(drawing, np.ndarray):
        image_data = drawing
    else:
        return None
    if image_data is None:
        return None
    return np.asarray(image_data, dtype=np.uint8)


def _to_rgb_image(image_array):
    """Convert a uint8 image array to an RGB PIL image, flattening alpha onto white."""
    image = Image.fromarray(image_array)
    if image.mode in ("RGBA", "LA"):
        # A plain .convert("RGB") just drops alpha, so transparent canvas pixels
        # keep whatever colour they store (possibly black, like the strokes).
        # Compositing onto white keeps strokes distinguishable from background.
        white = Image.new("RGBA", image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, image.convert("RGBA"))
    return image.convert("RGB")


def guess_drawing(drawing):
    """Classify a Sketchpad drawing as one of ``words`` using CLIP.

    Args:
        drawing: the value emitted by ``gr.Sketchpad`` -- a dict with a
            ``composite`` and/or ``background`` image, a bare image array,
            or ``None`` (presumably a cleared/empty canvas).

    Returns:
        ``"AI's guess: <word>"`` for the best-matching word, or a message
        explaining why no guess was made (empty canvas or unrecognised value).
    """
    if drawing is None:
        return _EMPTY_DRAWING_MSG

    image_array = _extract_image_array(drawing)
    if image_array is None:
        return _INVALID_DRAWING_MSG
    if image_array.size == 0:
        return _EMPTY_DRAWING_MSG

    image = _to_rgb_image(image_array)

    # A single-colour canvas contains no strokes; don't ask CLIP to guess.
    pixels = np.asarray(image)
    if pixels.min() == pixels.max():
        return _EMPTY_DRAWING_MSG

    image_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_features = model.get_image_features(**image_inputs)

    # The single image's feature row broadcasts against one text-feature row
    # per word, yielding one similarity score per entry of ``words``.
    similarity = torch.nn.functional.cosine_similarity(image_features, text_features)

    # Debug: print the similarity score for each word.
    for word, score in zip(words, similarity.tolist()):
        print(f"Similarity score for '{word}': {score}")

    best_match = words[similarity.argmax().item()]
    return f"AI's guess: {best_match}"


# Set up the Gradio interface. The description is built from ``words`` so the
# on-screen prompt can never drift out of sync with what the model can guess.
interface = gr.Interface(
    fn=guess_drawing,
    inputs=gr.Sketchpad(),
    outputs="text",
    live=True,
    description="Draw " + ", ".join(words),
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()