Spaces:
Running
Running
import logging

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# Pretrained CLIP checkpoint; the model weights and the matching
# pre-processor (tokenizer + image transforms) come from the same name.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Vocabulary for the game: the AI's guess is always one of these words.
words = [
    "cat", "car", "tree", "house", "dog", "cloud", "flower",
    "bicycle", "boat", "star", "bird", "fish", "sun",
]

# Encode every candidate word once at start-up, so each guess only needs the
# image side of CLIP; gradients are never needed, hence no_grad.
text_inputs = processor(text=words, return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = model.get_text_features(**text_inputs)
# Define the function to process drawing and make a prediction | |
def guess_drawing(drawing): | |
# Access the image data from the 'background' key | |
if 'background' in drawing: | |
image_array = np.array(drawing['background'], dtype=np.uint8) | |
else: | |
return "Invalid drawing format. Unable to process." | |
# Convert to RGB PIL image to ensure compatibility with CLIP | |
image = Image.fromarray(image_array).convert("RGB") | |
# Prepare the image for the model | |
image_inputs = processor(images=image, return_tensors="pt") | |
# Get image features from the model | |
with torch.no_grad(): | |
image_features = model.get_image_features(**image_inputs) | |
# Calculate cosine similarity with each word | |
similarity = torch.nn.functional.cosine_similarity(image_features, text_features) | |
# Debug: Print similarity scores for each word | |
for word, score in zip(words, similarity.tolist()): | |
print(f"Similarity score for '{word}': {score}") | |
best_match = words[similarity.argmax().item()] | |
# Return the AI's best guess | |
return f"AI's guess: {best_match}" | |
# Build and start the Gradio UI. Per Gradio's docs, live=True re-runs
# guess_drawing whenever the sketchpad input changes.
interface = gr.Interface(
    fn=guess_drawing,
    inputs=gr.Sketchpad(),
    outputs="text",
    live=True,
    # Derived from `words` so the on-screen hint always matches the
    # vocabulary the model scores against (same text as before).
    description="Draw " + ", ".join(words),
)
interface.launch()