File size: 3,891 Bytes
f82d341
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import pre_reqs
import glob
import os.path

import cv2
import numpy as np
import streamlit as st
from mmcls.apis import init_model
from mmcls.apis import inference_model_topk as inference_cls_model
from mmdet.registry import VISUALIZERS
# from mmcls.utils import register_all_modules as register_all_modules_cls
from mmdet.apis import init_detector, inference_detector
from mmdet.utils import register_all_modules as register_all_modules_det
import pandas as pd
from PIL import Image

# Page chrome: title, intro text and the sidebar uploader.
# The camera emoji is written as the real UTF-8 character; the previous text
# was a mis-decoded byte sequence, which is not a usable emoji for page_icon.
st.set_page_config(page_title="📷 A Folder Demo", page_icon="📷", layout='wide')
st.markdown("# 📷 A Folder Demo")
st.write(
    ":dog: Try uploading multi images to get the possible categories, objects."
)
st.sidebar.header("A Folder Demo")
my_upload = st.sidebar.file_uploader("Upload multi images",  type=["png", "jpg", "jpeg"], accept_multiple_files=True)

# col1 shows the original images, col2 the subset matching the selected classes.
col1, col2 = st.columns(2)
# Root against which the model config paths and the sample "images" folder are resolved.
parent_folder = './'

@st.cache_resource
def _init_model_return_results(imgs):
    """Classify and detect objects in each image, then index images by label.

    Runs the ResNet-50 ImageNet classifier (top-5) and the RTMDet-ins COCO
    detector (first 10 instances) on every image and keeps each label whose
    score is above 0.35.

    Args:
        imgs: Sequence of RGB PIL images, as returned by ``_get_image``.

    Returns:
        ``(results, class2idx)``. ``results`` maps an image's position in
        ``imgs`` to the set of labels kept for it. ``class2idx`` maps each
        label to the ascending list of image positions containing it.
    """
    score_thr = 0.35

    # NOTE(review): no ``device`` is passed here, unlike the detector below;
    # confirm the library default device exists on the deployment host.
    cls_model = init_model(parent_folder + 'configs/resnet/resnet50_8xb32_in1k.py',
                           'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth')
    # The detector is the same for every image, so build it once instead of
    # reloading the checkpoint inside the loop.
    det_model = init_detector(parent_folder + 'configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py',
                              'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth',
                              device='cpu')
    cls_names = det_model.dataset_meta['classes']

    # OpenMMLab treats ndarray inputs as BGR (OpenCV channel order), while the
    # PIL images are RGB, so swap the channels before inference.
    arrays = [cv2.cvtColor(np.array(x), cv2.COLOR_RGB2BGR) for x in imgs]

    results = {}
    for idx, img in enumerate(arrays):
        cls_out = inference_cls_model(cls_model, img, 5)
        cls_scores = np.asarray(cls_out["pred_score"])
        labels_found = set(np.array(cls_out["pred_class"])[cls_scores > score_thr])

        det_out = inference_detector(det_model, img)
        # Only the first 10 predicted instances are considered.
        det_scores = det_out.pred_instances.scores.numpy()[:10]
        det_labels = np.array([cls_names[x.item()] for x in det_out.pred_instances.labels[:10]])
        labels_found |= set(det_labels[det_scores > score_thr])

        results[idx] = labels_found

    # Invert results: label -> positions of the images that contain it.
    class2idx = {}
    for idx, labels in results.items():
        for label in labels:
            class2idx.setdefault(label, []).append(idx)
    return results, class2idx


@st.cache_data
def _get_image(my_upload):
    """Load the images to analyse, together with a short label for each.

    Uploaded files take priority. When nothing is uploaded, the ``*.jpg``,
    ``*.jpeg`` and ``*.png`` files under ``<parent_folder>/images`` are used,
    in sorted path order so the tab order is stable between runs.

    Args:
        my_upload: Value of the sidebar ``file_uploader``. ``None`` or an
            empty list means "no upload". A plain path string is also
            accepted.

    Returns:
        ``(images, file_names)``. ``images`` holds the files as RGB PIL
        images. ``file_names`` holds, for each file, its name without
        directory or extension, truncated to 8 characters.
    """
    if my_upload:
        img_files = [my_upload] if isinstance(my_upload, str) else list(my_upload)
        # Uploaded files carry their original name in ``.name``; a plain path
        # string is its own name.
        sources = [x if isinstance(x, str) else x.name for x in img_files]
    else:
        img_dir = os.path.join(parent_folder, "images")
        img_files = sorted(
            path
            for pattern in ("*.jpg", "*.jpeg", "*.png")
            for path in glob.glob(os.path.join(img_dir, pattern))
        )
        sources = img_files

    # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), unlike split('.').
    file_names = [os.path.splitext(os.path.basename(src))[0][:8] for src in sources]

    images = []
    for img_file in img_files:
        # ``with`` closes files that PIL opened from a path once converted.
        with Image.open(img_file) as img:
            images.append(img.convert('RGB'))
    return images, file_names


def plot_canvas(imgs, results, file_names, class2idx):
    """Render the original images and those matching the selected classes.

    The left column gets one tab per image. The right column gets one tab per
    image that contains every class picked in the multiselect (all images when
    nothing is picked), captioned with that image's labels.

    Args:
        imgs: Images to show, aligned by position with ``file_names``.
        results: Mapping from image position to its set of labels.
        file_names: Tab label for each image.
        class2idx: Mapping from label to the positions of images containing it.
    """
    col1.write("Original Images :camera:")
    col2.write("Filtered Images :wrench:")

    # st.tabs refuses an empty list of labels, so stop when there is nothing to show.
    if not file_names:
        col1.info("No images to display.")
        return

    for tab, img in zip(col1.tabs(file_names), imgs):
        tab.image(img, width=400)

    # Sorted so the option list does not depend on set iteration order.
    all_classes = sorted(set().union(*results.values()))
    options = st.multiselect(
        'Select the classes:',
        all_classes)

    # Keep only the images that contain every selected class.
    select_idx = set(range(len(file_names)))
    for op in options:
        select_idx &= set(class2idx.get(op, ()))
    select_idx = sorted(select_idx)

    if select_idx:
        names = [file_names[i] for i in select_idx]
        for tab, i in zip(col2.tabs(names), select_idx):
            tab.image(imgs[i], width=400)
            tab.write(', '.join(sorted(results[i])))


# Script body: load the images, run both models over them, then draw the UI.
imgs, file_names = _get_image(my_upload)
results, class2idx = _init_model_return_results(imgs)
plot_canvas(
    imgs=imgs,
    results=results,
    file_names=file_names,
    class2idx=class2idx,
)