# Hugging Face Spaces status: Sleeping
import gradio as gr | |
from PIL import Image | |
import torch | |
from transformers import CLIPProcessor, CLIPModel | |
# Checkpoint identifier shared by the model weights and the matching
# preprocessor, so the two can never drift apart.
_CHECKPOINT = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

# Loaded once at import time and reused by every request.
model = CLIPModel.from_pretrained(_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CHECKPOINT)
def calculate_similarity(image, text_prompt, similarity_type):
    """Describe how closely ``text_prompt`` matches ``image`` according to CLIP.

    Args:
        image: The uploaded picture (the UI passes a PIL image), or ``None``
            when nothing has been uploaded.
        text_prompt: The text to compare against the image.
        similarity_type: The dropdown label. ``"General Similarity (3x
            scaled)"`` selects the scaled score; any other value reports the
            raw cosine similarity.

    Returns:
        A sentence reporting the similarity as a percentage, or a short
        instruction when the image or the prompt is missing.
    """
    # Fail soft on incomplete input instead of erroring inside the model.
    if image is None:
        return "Please upload an image before calculating the similarity."
    if not text_prompt or not text_prompt.strip():
        return "Please enter a text prompt before calculating the similarity."

    # truncation=True clips over-long prompts to the tokenizer's maximum
    # length; without it such prompts fail inside the text encoder.
    inputs = processor(
        images=image,
        text=text_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # cosine_similarity divides by the vector norms itself, so a separate
    # normalisation pass over the embeddings is unnecessary.
    cosine = torch.nn.functional.cosine_similarity(
        outputs.image_embeds, outputs.text_embeds
    ).item()

    if similarity_type == "General Similarity (3x scaled)":
        # Display heuristic: triple the score and cap it just below 100%.
        # Only the upper end is capped; a negative cosine stays negative.
        adjusted_similarity = cosine * 3 * 100
        return f"According to OpenCLIP, the image and the text prompt have a general similarity of {min(adjusted_similarity, 99.99):.2f}%."

    # Cosine Similarity (raw)
    return f"According to OpenCLIP, the image and the text prompt have a cosine similarity of {cosine * 100:.2f}%."
# Options offered by the "Similarity Type" dropdown; the first entry is the
# pre-selected default.
_SIMILARITY_CHOICES = [
    "General Similarity (3x scaled)",
    "Cosine Similarity (raw)",
]

# Build the Gradio UI: image + prompt + scoring mode in, one text result out.
iface = gr.Interface(
    fn=calculate_similarity,
    inputs=[
        gr.Image(type="pil", label="Upload Image", height=512),
        gr.Textbox(label="Text Prompt"),
        gr.Dropdown(
            label="Similarity Type",
            choices=_SIMILARITY_CHOICES,
            value=_SIMILARITY_CHOICES[0],
        ),
    ],
    outputs=gr.Text(),
    allow_flagging="never",
    title="OpenClip Similarity Calculator",
    description="Upload an image and provide a text prompt to calculate the similarity.",
)

# Start the app and request a public share link.
iface.launch(share=True)