Ashish Soni committed on
Commit
dca16ec
·
1 Parent(s): 68d8fd5

streamlit app

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import clip
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
# Load CLIP model and preprocessing
device = "cuda" if torch.cuda.is_available() else "cpu"

# Streamlit re-executes this script on every widget interaction, so cache the
# loaded model across reruns instead of calling clip.load each time.
# On Streamlit versions without ``cache_resource`` fall back to a plain,
# uncached call.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_clip(device_name):
    """Return the ``(model, preprocess)`` pair from ``clip.load`` for ViT-B/32."""
    return clip.load("ViT-B/32", device=device_name)


model, preprocess = _load_clip(device)
10
+
11
+ # Function to predict descriptions and probabilities
12
def predict(image, descriptions):
    """Pick the description that CLIP scores as the best match for an image.

    Args:
        image: Image accepted by the module-level ``preprocess`` transform
            (``main`` passes a PIL image).
        descriptions: Indexable sequence of candidate text descriptions.

    Returns:
        A ``(best_description, probability)`` tuple: the entry of
        ``descriptions`` with the highest softmax probability, and that
        probability.
    """
    image_input = preprocess(image).unsqueeze(0).to(device)
    text_tokens = clip.tokenize(descriptions).to(device)

    # Inference only: run the forward pass under no_grad so no autograd graph
    # is built and the result can be converted to NumPy directly.
    with torch.no_grad():
        # The model call encodes both inputs itself, so separate
        # encode_image / encode_text calls would only repeat that work.
        logits_per_image, _ = model(image_input, text_tokens)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # A single image was encoded, so its scores are in row 0.
    scores = probs[0]
    best_index = int(np.argmax(scores))
    return descriptions[best_index], scores[best_index]
24
+
25
+ # Streamlit app
26
def main():
    """Render the app page.

    The user uploads an image, writes three descriptions of it, and on
    "Predict" the description ranked highest by ``predict`` is shown together
    with its probability.
    """
    st.title("Image understanding model test")

    # Instructions for the user
    st.markdown("---")
    st.markdown("### Upload an image to test how well the model understands it")

    # Upload image through Streamlit with a unique key
    uploaded_image = st.file_uploader(
        "Upload an image...", type=["jpg", "png", "jpeg"], key="uploaded_image"
    )

    # Nothing else to render until a file has been provided.
    if uploaded_image is None:
        return

    # Convert the uploaded image to PIL Image
    pil_image = Image.open(uploaded_image)

    # Display the image at the width of the column. The former ``width=200``
    # argument was dropped: ``use_column_width=True`` overrides it.
    st.image(pil_image, caption="Uploaded Image.", use_column_width=True)

    # Instructions for the user
    st.markdown("### 2 Lies and 1 Truth")
    st.markdown("Write 3 descriptions about the image, 1 must be true.")

    # Get user input for descriptions
    description1 = st.text_input("Description 1:", placeholder='A red apple')
    description2 = st.text_input("Description 2:", placeholder='A car parked in a garage')
    description3 = st.text_input("Description 3:", placeholder='An orange fruit on a tree')

    # Strip surrounding whitespace so blank-looking inputs count as missing.
    descriptions = [
        text.strip() for text in (description1, description2, description3)
    ]

    # Button to trigger prediction
    if st.button("Predict"):
        if not all(descriptions):
            # Tell the user why nothing happened instead of silently ignoring
            # the click.
            st.warning("Please fill in all three descriptions before predicting.")
            return

        # Make predictions
        best_description, best_prob = predict(pil_image, descriptions)

        # Display the highest probability description and its probability
        st.write(f"**Best Description:** {best_description}")
        st.write(f"**Prediction Probability:** {best_prob:.2%}")

        # Display progress bar for the highest probability
        st.progress(float(best_prob))
66
+
67
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()