import pre_reqs
import cv2
import numpy as np
import streamlit as st
from mmcls.apis import init_model
from mmcls.apis import inference_model_topk as inference_cls_model
from mmdet.registry import VISUALIZERS
# from mmcls.utils import register_all_modules as register_all_modules_cls
from mmdet.apis import init_detector, inference_detector
from mmdet.utils import register_all_modules as register_all_modules_det
import pandas as pd
from PIL import Image

st.set_page_config(page_title="🔥 An Image Demo", page_icon="🔥", layout='wide')
st.markdown("# 🔥 An Image Demo")
st.write(
    ":dog: Upload an image to see its most likely categories (cls) or the objects detected in it (det)."
)
st.sidebar.header("An Image Demo")
my_upload = st.sidebar.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
col1, col2, col3 = st.columns(3)
model_option = st.radio(
    "Inference model (cls = classification, det = detection)",
    ('cls', 'det'))

parent_folder = './'
topk = st.slider('Return top-k predictions', 1, 10, 3)
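

# Sketch (not part of the original app and not called anywhere):
# `inference_model_topk` is imported from mmcls.apis above, but its
# implementation is not shown in this file. The hypothetical helper
# `_topk_indices` only illustrates the common way to pick the k
# highest-scoring classes from a 1-D score vector.
import numpy as np


def _topk_indices(scores, k):
    """Return indices of the k largest scores, highest first (ties keep index order)."""
    scores = np.asarray(scores)
    k = min(k, scores.shape[0])
    # A stable sort of the negated scores gives descending order while keeping
    # equal scores in their original index order.
    return np.argsort(-scores, kind='stable')[:k]

# e.g. _topk_indices([0.1, 0.7, 0.2], 2).tolist() -> [1, 2]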

@st.cache_resource
def _init_model(model_option):
    """Build the selected model once; st.cache_resource reuses it across reruns."""
    if model_option == 'cls':
        # ImageNet-1k ResNet-50 classifier (MMClassification), run on CPU like the detector
        model = init_model(
            parent_folder + 'configs/resnet/resnet50_8xb32_in1k.py',
            'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
            device='cpu')
        visualizer = None
    elif model_option == 'det':
        # register_all_modules_det()
        # COCO RTMDet-Ins-S instance-segmentation model (MMDetection), run on CPU
        model = init_detector(
            parent_folder + 'configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py',
            'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth',
            device='cpu')
        # The visualizer needs the dataset metadata (class names, palette) to label detections
        visualizer = VISUALIZERS.build(model.cfg.visualizer)
        visualizer.dataset_meta = model.dataset_meta
    else:
        model = None
        visualizer = None
    return model, visualizer

@st.cache_data
def _get_image(my_upload):
    """Return the uploaded image, or the bundled zebra sample, as an RGB PIL image."""
    if my_upload is not None:
        img_file = my_upload
    else:
        img_file = parent_folder + "images/zebra.jpg"
    return Image.open(img_file).convert('RGB')

# @st.cache_resource
def _inference_model(img, model, visualizer, model_option):
    img = np.array(img)  # H x W x 3, RGB channel order (from PIL)
    if model_option == 'cls':
        # Keep the top 10 predictions; the slider trims them to `topk` later
        return_results = inference_cls_model(model, img, 10)
        vis_img = img
    elif model_option == 'det':
        vis_img = img.copy()
        # MMDetection models expect BGR arrays (OpenCV order), while PIL gives RGB
        results = inference_detector(model, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        # Match the visualization canvas to the predicted mask resolution
        _, h, w = results.pred_instances.masks.shape
        vis_img = cv2.resize(vis_img, (w, h))
        # The visualizer draws on the RGB image
        visualizer.add_datasample(
            name='result',
            image=vis_img,
            data_sample=results,
            draw_gt=False,
            show=False)
        vis_img = visualizer.get_image()
        cls_names = visualizer.dataset_meta['classes']
        # Keep the first 10 instances; move tensors to CPU before converting to NumPy
        pred = results.pred_instances
        return_results = {'scores': pred.scores[:10].cpu().numpy(),
                          'bboxes': pred.bboxes[:10].cpu().numpy(),
                          'labels': [cls_names[x.item()] for x in pred.labels[:10]]}
    return return_results, vis_img
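

# Sketch (not part of the original app and not called anywhere): PIL yields
# RGB arrays, while MMDetection models are usually fed OpenCV-style BGR
# arrays and the detection visualizer draws on RGB images. The hypothetical
# helper `_swap_rb` shows the channel swap involved; it performs the same
# operation as cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for H x W x 3 images.
import numpy as np


def _swap_rb(img):
    """Swap the first and last channels of an H x W x 3 image (RGB <-> BGR)."""
    img = np.asarray(img)
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"expected an H x W x 3 image, got shape {img.shape}")
    # Reverse the channel axis, then copy so the result has no negative strides,
    # which some OpenCV bindings reject.
    return np.ascontiguousarray(img[..., ::-1])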


def plot_canvas(img, vis_img, results, model_option):
    col1.write("Original Image :camera:")
    col1.image(img)

    col2.write("Visualization :wrench:")
    col3.write("Metainfo :wrench:")
    if model_option == 'cls':
        col2.image(vis_img)
        df = pd.DataFrame({
            'category': results["pred_class"][:topk],
            'probability': [f"{x:.2f}" for x in results["pred_score"][:topk]],
        })
        col3.dataframe(df)
    elif model_option == 'det':
        # vis_idx = st.slider('Show a prediction', 1, 10, 3, disabled=True)
        col2.image(vis_img)

        df = pd.DataFrame({
            'category': results["labels"][:topk],
            'score': [f"{x:.2f}" for x in results["scores"][:topk]],
            # Box corners (x1, y1, x2, y2) in pixels
            'box': [[f"{t:.2f}" for t in box] for box in results["bboxes"][:topk]],
        })
        col3.dataframe(df)
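

# Sketch (not part of the original app and not called anywhere): the
# hypothetical helper `_det_topk_table` builds the same kind of top-k table
# that plot_canvas renders for detections, assuming `labels` is a list of
# class names, `scores` a 1-D array and `bboxes` an N x 4 array of
# (x1, y1, x2, y2) corners, all already sorted by descending score.
# Keeping it free of Streamlit makes the formatting easy to test on its own.
import numpy as np
import pandas as pd


def _det_topk_table(labels, scores, bboxes, k):
    """Return a DataFrame with the first k detections, numbers as 2-decimal strings."""
    return pd.DataFrame({
        'category': list(labels[:k]),
        'score': [f"{s:.2f}" for s in scores[:k]],
        'box': [[f"{t:.2f}" for t in box] for box in bboxes[:k]],
    })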


model, visualizer = _init_model(model_option)
img = _get_image(my_upload)
results, vis_img = _inference_model(img, model, visualizer, model_option)
plot_canvas(img, vis_img, results, model_option)