Spaces:
Running
Running
File size: 2,038 Bytes
42cfb33 50d0f30 42cfb33 c7b34db 42cfb33 5bf9861 42cfb33 5bf9861 42cfb33 7fc771d 6c61156 7fc771d 5bf9861 6c61156 7fc771d 6c61156 7fc771d 42cfb33 6c61156 50d0f30 42cfb33 c7b34db 42cfb33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
import torch
# Pretrained CLIP checkpoint shared by the model and its processor, so the
# two can never be loaded from different checkpoints by accident.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Candidate labels a sketch is matched against.
words = [
    "cat", "car", "tree", "house", "dog", "cloud", "flower",
    "bicycle", "boat", "star", "bird", "fish", "sun",
]

# Tokenize and embed every candidate word once at start-up, so each
# prediction only has to embed the incoming image.
text_inputs = processor(text=words, return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = model.get_text_features(**text_inputs)
def _drawing_to_image(drawing):
    """Extract an RGB PIL image from a Sketchpad value, or return ``None``.

    Accepts either a dict with ``"composite"`` and/or ``"background"`` image
    arrays (the shape of the value ``gr.Sketchpad`` produces) or a bare
    ``numpy`` image array. Anything else, including ``None``, yields ``None``.
    """
    if isinstance(drawing, dict):
        # "composite" is the flattened canvas that includes the user's
        # strokes; "background" is only the canvas beneath them. Fall back
        # to "background" for values that carry no composite.
        pixels = drawing.get("composite")
        if pixels is None:
            pixels = drawing.get("background")
    elif isinstance(drawing, np.ndarray):
        pixels = drawing
    else:
        # None or an unrecognized value: nothing we can turn into an image.
        pixels = None
    if pixels is None:
        return None

    image = Image.fromarray(np.asarray(pixels, dtype=np.uint8))
    if image.mode == "RGBA":
        # convert("RGB") simply drops alpha, so fully transparent pixels
        # would keep whatever colour they store (often black). Flatten the
        # sketch onto a white canvas first.
        white = Image.new("RGBA", image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, image)
    return image.convert("RGB")


def guess_drawing(drawing):
    """Return CLIP's best guess for which word in ``words`` a sketch depicts.

    Args:
        drawing: Value from ``gr.Sketchpad``: a dict holding ``"composite"``
            and/or ``"background"`` image arrays, or a bare image array.

    Returns:
        ``"AI's guess: <word>"`` naming the word whose precomputed text
        embedding has the highest cosine similarity to the image embedding,
        or ``"Invalid drawing format. Unable to process."`` when no image
        can be extracted from ``drawing``.
    """
    image = _drawing_to_image(drawing)
    if image is None:
        return "Invalid drawing format. Unable to process."

    # Embed the single sketch with CLIP's image tower.
    image_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_features = model.get_image_features(**image_inputs)

    # One image row compared against one text row per word -> one score per word.
    similarity = torch.nn.functional.cosine_similarity(image_features, text_features)

    # Debug output: similarity of the sketch to every candidate word.
    for word, score in zip(words, similarity.tolist()):
        print(f"Similarity score for '{word}': {score}")

    best_match = words[similarity.argmax().item()]
    return f"AI's guess: {best_match}"
# Build and launch the drawing UI. The prompt is derived from ``words`` so
# the options shown to the player always match the labels the model picks from.
interface = gr.Interface(
    fn=guess_drawing,
    inputs=gr.Sketchpad(),
    outputs="text",
    # Re-run the guess automatically whenever the sketch changes.
    live=True,
    description="Draw " + ", ".join(words),
)
interface.launch()
|