|
from typing import List, Tuple

import clip
import numpy as np
import streamlit as st
import torch
from PIL import Image
|
|
|
|
|
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Streamlit re-executes this script from the top on every widget interaction,
# so an unguarded clip.load() would reload the model weights on each rerun.
# Use Streamlit's resource cache when the installed version provides one
# (`cache_resource`, or `experimental_singleton` on older releases).
_cache_resource = getattr(st, "cache_resource", None) or getattr(
    st, "experimental_singleton", None
)


def _load_clip(device_name: str):
    """Load the CLIP ViT-B/32 model and its image preprocessing transform."""
    return clip.load("ViT-B/32", device=device_name)


if _cache_resource is not None:
    _load_clip = _cache_resource(_load_clip)

model, preprocess = _load_clip(device)
|
|
|
|
|
def predict(image: Image.Image, descriptions: List[str]) -> Tuple[str, float]:
    """
    Pick the description that best matches the image according to CLIP.

    The image and all descriptions are passed through the CLIP model once;
    the image-to-text logits are turned into probabilities with a softmax
    over the candidate descriptions.

    Args:
        image (Image.Image): The input image to evaluate.
        descriptions (List[str]): Candidate textual descriptions to compare
            against the image.

    Returns:
        Tuple[str, float]: The best-matching description and its softmax
        probability among the supplied descriptions.

    Raises:
        ValueError: If ``descriptions`` is empty.
    """
    if not descriptions:
        raise ValueError("descriptions must contain at least one entry")

    image_input = preprocess(image).unsqueeze(0).to(device)
    text_input = clip.tokenize(descriptions).to(device)

    with torch.no_grad():
        # The model's forward pass encodes both inputs itself, so separate
        # encode_image/encode_text calls would only repeat that work.
        logits_per_image, _ = model(image_input, text_input)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # probs holds one row per image (a single row here); the flattened
    # argmax therefore indexes directly into `descriptions`.
    best_index = int(np.argmax(probs))
    return descriptions[best_index], float(probs.flat[best_index])
|
|
|
|
|
def main():
    """
    Render the Streamlit page.

    Lets the user upload an image, enter three descriptions of it, and, on
    pressing "Predict", shows which description `predict` ranks highest
    together with its probability.
    """
    st.title("Image understanding model")

    st.markdown("---")
    st.markdown("### Upload an image to test how well the model understands it")

    uploaded_image = st.file_uploader(
        "Upload an image...", type=["jpg", "png", "jpeg"], key="uploaded_image"
    )
    if uploaded_image is None:
        # Nothing to describe until an image has been provided.
        return

    try:
        pil_image = Image.open(uploaded_image)
        # Image.open is lazy; decode now so a corrupt file is reported here
        # rather than later inside the model call.
        pil_image.load()
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        return

    # NOTE(review): use_column_width=True is documented to take precedence
    # over `width`, so width=200 likely has no effect — confirm the intent.
    st.image(pil_image, caption="Uploaded Image.", use_column_width=True, width=200)

    st.markdown("### 2 Lies and 1 Truth")
    st.markdown("Write 3 descriptions about the image, 1 must be true.")

    description1 = st.text_input("Description 1:", placeholder='A red apple')
    description2 = st.text_input("Description 2:", placeholder='A car parked in a garage')
    description3 = st.text_input("Description 3:", placeholder='An orange fruit on a tree')

    # Strip whitespace so a field containing only spaces counts as empty.
    descriptions = [
        text.strip() for text in (description1, description2, description3)
    ]

    if not st.button("Predict"):
        return

    if not all(descriptions):
        # Tell the user why nothing happened instead of silently ignoring
        # the button press.
        st.warning("Please fill in all three descriptions before predicting.")
        return

    best_description, best_prob = predict(pil_image, descriptions)

    st.write(f"**Best Description:** {best_description}")
    st.write(f"**Prediction Probability:** {best_prob:.2%}")

    st.progress(float(best_prob))
|
|
|
# Build the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
|