"""Sketch-guessing demo: CLIP picks the word that best matches a drawing.

The user draws on a Gradio sketchpad; the drawing is embedded with CLIP and
compared (cosine similarity) against precomputed text embeddings of a fixed
word list. The closest word is reported as the AI's guess.
"""

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Load CLIP model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Define a list of target words for the game
words = ["cat", "car", "tree", "house", "dog"]  # Add more words as needed

# Precompute text embeddings once at import time for faster comparisons
text_inputs = processor(text=words, return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = model.get_text_features(**text_inputs)

# Message returned whenever the sketchpad payload cannot be turned into an image
_INVALID_DRAWING_MESSAGE = "Invalid drawing format. Unable to process."


def _extract_pixels(drawing):
    """Return the pixel data carried by a sketchpad payload, or None."""
    if drawing is None:
        return None
    if isinstance(drawing, dict):
        # Per the Gradio ImageEditor docs the payload has the keys
        # "background", "layers" and "composite"; the strokes are only in
        # the layers/composite, so prefer "composite" and fall back to
        # "background" for payloads that lack it.
        for key in ("composite", "background"):
            pixels = drawing.get(key)
            if pixels is not None:
                return pixels
        return None
    # Anything else (e.g. a bare array or PIL image) is passed through and
    # validated by the conversion step.
    return drawing


def _to_rgb_image(pixels):
    """Convert raw pixel data to an RGB PIL image on a white background."""
    image = Image.fromarray(np.asarray(pixels, dtype=np.uint8))
    if image.mode == "RGBA":
        # A plain convert("RGB") drops the alpha channel without blending,
        # which can turn a transparent canvas black and hide dark strokes.
        white = Image.new("RGBA", image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, image)
    return image.convert("RGB")


def guess_drawing(drawing) -> str:
    """Guess which entry of ``words`` the drawing most resembles.

    Args:
        drawing: Value delivered by the Gradio sketchpad. Expected to be a
            dict with a "composite" (preferred) or "background" entry
            holding image data; a bare image array is accepted as well.
            ``None`` or an unusable payload is reported, not raised.

    Returns:
        ``"AI's guess: <word>"`` for the best-matching word, or an
        explanatory message when the drawing cannot be processed.
    """
    pixels = _extract_pixels(drawing)
    if pixels is None:
        return _INVALID_DRAWING_MESSAGE

    try:
        image = _to_rgb_image(pixels)
    except (TypeError, ValueError):
        # Payload was present but not convertible to a uint8 image.
        return _INVALID_DRAWING_MESSAGE

    # Prepare the image for the model
    image_inputs = processor(images=image, return_tensors="pt")

    # Get image features from the model
    with torch.no_grad():
        image_features = model.get_image_features(**image_inputs)

    # The single image embedding broadcasts against one row per word,
    # giving one cosine-similarity score per entry of ``words``.
    similarity = torch.nn.functional.cosine_similarity(
        image_features, text_features, dim=-1
    )

    # Debug: Print similarity scores for each word
    for word, score in zip(words, similarity.tolist()):
        print(f"Similarity score for '{word}': {score}")

    best_match = words[similarity.argmax().item()]

    # Return the AI's best guess
    return f"AI's guess: {best_match}"


# Set up Gradio interface
interface = gr.Interface(
    fn=guess_drawing,
    inputs=gr.Sketchpad(),
    outputs="text",
    live=True,
    description="Draw cat, car, tree, house, dog!",
)

if __name__ == "__main__":
    # Only start the web server when run as a script, not on import.
    interface.launch()