Ashish Soni
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
|
|
2 |
import torch
|
3 |
import clip
|
4 |
from PIL import Image
|
|
|
5 |
import numpy as np
|
6 |
|
7 |
# Load CLIP model and preprocessing
|
@@ -9,17 +10,34 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
9 |
model, preprocess = clip.load("ViT-B/32", device=device)
|
10 |
|
11 |
# Function to predict descriptions and probabilities
|
12 |
-
def predict(image, descriptions):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
image = preprocess(image).unsqueeze(0).to(device)
|
|
|
14 |
text = clip.tokenize(descriptions).to(device)
|
15 |
|
16 |
with torch.no_grad():
|
|
|
17 |
image_features = model.encode_image(image)
|
18 |
text_features = model.encode_text(text)
|
19 |
|
|
|
20 |
logits_per_image, logits_per_text = model(image, text)
|
|
|
21 |
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
22 |
|
|
|
23 |
return descriptions[np.argmax(probs)], np.max(probs)
|
24 |
|
25 |
# Streamlit app
|
|
|
2 |
import torch
|
3 |
import clip
|
4 |
from PIL import Image
|
5 |
+
from typing import List, Tuple
|
6 |
import numpy as np
|
7 |
|
8 |
# Load CLIP model and preprocessing
|
|
|
10 |
model, preprocess = clip.load("ViT-B/32", device=device)
|
11 |
|
12 |
# Function to predict descriptions and probabilities
|
13 |
+
def predict(image: Image.Image, descriptions: List[str]) -> Tuple[str, float]:
|
14 |
+
"""
|
15 |
+
Predict the best matching description for the provided image based on the given descriptions.
|
16 |
+
Uses the CLIP model to compute similarities between the image and text descriptions.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
image (Image.Image): The input image for which the descriptions are being evaluated.
|
20 |
+
descriptions (List[str]): A list of textual descriptions to compare against the image.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
Tuple[str, float]: A tuple containing the best-matching description and the corresponding probability.
|
24 |
+
"""
|
25 |
+
# Preprocess the image and move it to the appropriate device
|
26 |
image = preprocess(image).unsqueeze(0).to(device)
|
27 |
+
# Tokenize the descriptions and move them to the appropriate device
|
28 |
text = clip.tokenize(descriptions).to(device)
|
29 |
|
30 |
with torch.no_grad():
|
31 |
+
# Encode image and text features using the CLIP model
|
32 |
image_features = model.encode_image(image)
|
33 |
text_features = model.encode_text(text)
|
34 |
|
35 |
+
# Compute the similarity scores (logits) between image and text
|
36 |
logits_per_image, logits_per_text = model(image, text)
|
37 |
+
# Convert logits to probabilities
|
38 |
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
39 |
|
40 |
+
# Return the description with the highest probability and the corresponding probability
|
41 |
return descriptions[np.argmax(probs)], np.max(probs)
|
42 |
|
43 |
# Streamlit app
|