Ashish Soni
committed on
Commit
·
dca16ec
1
Parent(s):
68d8fd5
streamlit app
Browse files
app.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import clip
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
# Load CLIP model and preprocessing
|
8 |
+
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load hands back the model together with the image transform that
# predict() applies to uploaded images before encoding them.
model, preprocess = clip.load("ViT-B/32", device=device)
|
10 |
+
|
11 |
+
# Function to predict descriptions and probabilities
|
12 |
+
def predict(image, descriptions):
    """Return the description CLIP scores as the best match for an image.

    Args:
        image: A PIL image. It is run through the module-level
            ``preprocess`` transform before being fed to the model.
        descriptions: A non-empty list of candidate text descriptions.

    Returns:
        A ``(best_description, best_probability)`` tuple. The probability
        is the softmax over the image-to-text logits, so it is relative to
        the supplied descriptions only.

    Raises:
        ValueError: If ``descriptions`` is empty.
    """
    if not descriptions:
        raise ValueError("descriptions must contain at least one entry")

    image_input = preprocess(image).unsqueeze(0).to(device)
    text_input = clip.tokenize(descriptions).to(device)

    # Inference only: no_grad skips autograd bookkeeping and allows the
    # result to be converted to NumPy directly.
    with torch.no_grad():
        # The model call encodes both inputs itself, so separate
        # encode_image / encode_text passes would only repeat that work.
        logits_per_image, _ = model(image_input, text_input)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # A single image was submitted, so its scores are in row 0.
    scores = probs[0]
    best_index = int(np.argmax(scores))
    return descriptions[best_index], scores[best_index]
|
24 |
+
|
25 |
+
# Streamlit app
|
26 |
+
def main():
    """Render the Streamlit page for the "2 Lies and 1 Truth" CLIP demo.

    The user uploads an image and enters three descriptions. Pressing
    "Predict" shows the description the model ranks highest along with its
    probability. Nothing is predicted until an image is uploaded and all
    three descriptions are filled in.
    """
    st.title("Image understanding model test")

    # Instructions for the user
    st.markdown("---")
    st.markdown("### Upload an image to test how well the model understands it")

    # Upload image through Streamlit with a unique key
    uploaded_image = st.file_uploader(
        "Upload an image...", type=["jpg", "png", "jpeg"], key="uploaded_image"
    )
    if uploaded_image is None:
        return

    # Convert the uploaded image to PIL Image
    pil_image = Image.open(uploaded_image)

    # NOTE: per the st.image documentation, use_column_width takes precedence
    # over width, so the image is shown at column width and width=200 has no
    # effect here. No height limit is applied.
    st.image(pil_image, caption="Uploaded Image.", use_column_width=True, width=200)

    # Instructions for the user
    st.markdown("### 2 Lies and 1 Truth")
    st.markdown("Write 3 descriptions about the image, 1 must be true.")

    # Get user input for descriptions
    description1 = st.text_input("Description 1:", placeholder='A red apple')
    description2 = st.text_input("Description 2:", placeholder='A car parked in a garage')
    description3 = st.text_input("Description 3:", placeholder='An orange fruit on a tree')

    # Strip so that whitespace-only input counts as missing.
    descriptions = [
        text.strip() for text in (description1, description2, description3)
    ]

    # Button to trigger prediction
    if not st.button("Predict"):
        return

    if not all(descriptions):
        # Tell the user why nothing happened instead of silently ignoring
        # the click.
        st.warning("Please fill in all 3 descriptions before predicting.")
        return

    # Make predictions
    best_description, best_prob = predict(pil_image, descriptions)

    # Display the highest probability description and its probability
    st.write(f"**Best Description:** {best_description}")
    st.write(f"**Prediction Probability:** {best_prob:.2%}")

    # Display progress bar for the highest probability; clamp because
    # st.progress only accepts floats within [0.0, 1.0].
    st.progress(min(max(float(best_prob), 0.0), 1.0))
|
66 |
+
|
67 |
+
# Build the page only when this file is executed as the entry script.
if __name__ == "__main__":
    main()
|