File size: 2,433 Bytes
c209d46
 
 
 
 
 
092b75f
 
 
 
c209d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbf1542
c209d46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
092b75f
c209d46
092b75f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image
from module import myModule

# Index -> deity-name lookup: position i holds the label of model output
# class i (despite its name, this list maps an index to a class name).
CLASS_TO_IDX = [
    'AMMAN', 'AYYAPPA', 'BHAIRAV', 'BRAHMA', 'BUDDHA',
    'DURGA', 'GANESHA', 'HANUMAN', 'KAALI', 'KRISHNA',
    'KURMA', 'LAKSHMI', 'LINGA', 'MATSYA', 'MURUGA',
    'NARASIMHA', 'NATARAJA', 'PARASURAMA', 'RAMA', 'SARASWATI',
    'SHIVA', 'THIRTHANKARA', 'VAMANA', 'VARAHA', 'VISHNU',
]

# (height, width) every image is resized to before augmentation.
IMG_SIZE = (224, 224)

# Per-channel normalization statistics (the standard ImageNet mean/std).
STATS = {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

# Test-time-augmentation pipeline: resize, apply a randomly sampled
# AutoAugment policy, convert to a tensor, then normalize. AutoAugment is
# random, so repeated calls on the same image yield different views.
TTA_TRANSFORM = T.Compose(
    [
        T.Resize(IMG_SIZE),
        T.AutoAugment(),
        T.ToTensor(),
        T.Normalize(mean=STATS["mean"], std=STATS["std"]),
    ]
)


st.set_page_config(
    page_title="Identify the deity using Computer Vision.",
    layout="centered",
    initial_sidebar_state="collapsed",
    # NOTE(review): these entries look like the placeholder values from the
    # Streamlit `set_page_config` documentation example - replace them with
    # this app's real help/bug-report links and description.
    menu_items={
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': "# This is an *extremely* cool app!"
    }
)


st.title(":sparkles: I:orange[deity]fy")
st.header("Discover the deity with a snap.")


@st.cache_resource
def load_model(checkpoint_path="vit_base_clip_rank4.ckpt"):
    """Load the classifier checkpoint once and reuse it across reruns.

    Streamlit re-executes this entire script on every widget interaction, so
    loading the checkpoint at module level would re-read it from disk each
    time. ``st.cache_resource`` keeps one shared instance instead.

    Args:
        checkpoint_path: Path of the checkpoint passed to
            ``myModule.load_from_checkpoint``.

    Returns:
        The loaded model, moved to CPU and switched to evaluation mode.
    """
    # NOTE(review): if myModule is a LightningModule and the checkpoint was
    # saved on a GPU, consider passing map_location="cpu" here so loading
    # works on CPU-only hosts - confirm against myModule's signature.
    loaded = myModule.load_from_checkpoint(checkpoint_path)
    loaded.to("cpu")
    # Inference behaviour for layers such as dropout / batch-norm.
    loaded.eval()
    return loaded


model = load_model()


# Function to make predictions
def predict(image, k=3, n_tta=10):
    """Classify an image with test-time augmentation (TTA).

    ``TTA_TRANSFORM`` is applied ``n_tta`` times to produce augmented views.
    Every view is run through ``model``, the per-view class probabilities are
    averaged, and the ``k`` most likely classes are returned.

    Args:
        image: A path or binary file-like object accepted by
            ``PIL.Image.open`` (e.g. the object returned by
            ``st.file_uploader``).
        k: Number of top predictions to return; clamped to the number of
            classes the model outputs.
        n_tta: Number of augmented views to average over.

    Returns:
        ``(values, indices)``: two lists (always lists, even when ``k == 1``)
        holding the averaged probabilities in descending order and their
        class indices.

    Raises:
        ValueError: If ``k`` or ``n_tta`` is smaller than 1.
    """
    if k < 1 or n_tta < 1:
        raise ValueError("k and n_tta must both be >= 1")

    # Let the context manager close the image opened from the file;
    # convert() returns an independent in-memory copy.
    with Image.open(image) as raw:
        rgb = raw.convert("RGB")

    # TTA_TRANSFORM contains a random AutoAugment step, so every call yields
    # a different augmented view; stack them into one batch.
    batch = torch.stack([TTA_TRANSFORM(rgb) for _ in range(n_tta)])

    with torch.no_grad():
        logits = model(batch)  # assumes model returns (batch, num_classes)

    # Average *predictions*, not pixels: averaging the augmented inputs
    # themselves yields an out-of-distribution composite image.
    probs = torch.softmax(logits, dim=1).mean(dim=0, keepdim=True)

    top = torch.topk(probs, k=min(k, probs.shape[1]), dim=1)
    # Index the single batch row instead of squeeze() so k == 1 still
    # produces lists rather than bare scalars.
    values = top.values[0].cpu().tolist()
    indices = top.indices[0].cpu().tolist()
    return values, indices


# Upload image through Streamlit
img = st.file_uploader(
    label='choose a file',
    type=['png', 'jpg', 'jpeg'],
    label_visibility="hidden",
)

# The "Predict" button is only rendered once a file has been uploaded
# (short-circuit `and`); clicking it runs inference and plots the top
# predictions as a bar chart keyed by deity name.
if img is not None and st.button("Predict"):
    values, indices = predict(img)
    classes = [CLASS_TO_IDX[index] for index in indices]
    chart_data = dict(zip(classes, values))
    st.bar_chart(chart_data, color="#FFC101")