ai-photo-gallery / pages /2_πŸ“·_A_Folder.py
KyanChen's picture
Upload 2 files
f82d341
raw
history blame
No virus
3.89 kB
import pre_reqs
import glob
import os.path
import cv2
import numpy as np
import streamlit as st
from mmcls.apis import init_model
from mmcls.apis import inference_model_topk as inference_cls_model
from mmdet.registry import VISUALIZERS
# from mmcls.utils import register_all_modules as register_all_modules_cls
from mmdet.apis import init_detector, inference_detector
from mmdet.utils import register_all_modules as register_all_modules_det
import pandas as pd
from PIL import Image
st.set_page_config(page_title="πŸ“· A Folder Demo", page_icon="πŸ“·", layout='wide')
st.markdown("# πŸ“· A Folder Demo")
st.write(
":dog: Try uploading multi images to get the possible categories, objects."
)
st.sidebar.header("A Folder Demo")
my_upload = st.sidebar.file_uploader("Upload multi images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
col1, col2 = st.columns(2)
parent_folder = './'
@st.cache_resource
def _init_model_return_results(imgs):
cls_model = init_model(parent_folder + 'configs/resnet/resnet50_8xb32_in1k.py',
'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth')
imgs = [np.array(x) for x in imgs]
results = {}
for idx, img in enumerate(imgs):
return_results = inference_cls_model(cls_model, img, 5)
results[idx] = set(np.array(return_results["pred_class"])[return_results["pred_score"] > 0.35])
det_model = init_detector(parent_folder + 'configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py',
'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth',
device='cpu')
dataset_meta = det_model.dataset_meta
return_results = inference_detector(det_model, img)
cls_names = dataset_meta['classes']
scores = return_results.pred_instances.scores.numpy()[:10]
labels = np.array([cls_names[x.item()] for x in return_results.pred_instances.labels[:10]])
results[idx] |= set(labels[scores > 0.35])
class2idx = {}
for k, v in results.items():
for sub_v in v:
class2idx[sub_v] = class2idx.get(sub_v, []) + [k]
return results, class2idx
@st.cache_data
def _get_image(my_upload):
if len(my_upload):
img_files = my_upload
if isinstance(img_files, str):
img_files = [img_files]
file_names = [os.path.basename(x.name).split('.')[0][:8] for x in img_files]
else:
img_files = glob.glob(parent_folder + "/images/*.jpg") + glob.glob(parent_folder + "/images/*.png")
file_names = [os.path.basename(x).split('.')[0][:8] for x in img_files]
return [Image.open(img_file).convert('RGB') for img_file in img_files], file_names
def plot_canvas(imgs, results, file_names, class2idx):
col1.write("Original Images :camera:")
col2.write("Filtered Images :wrench:")
tabs = col1.tabs(file_names)
for idx, tab in enumerate(tabs):
tab.image(imgs[idx], width=400)
all_classes = set()
for x in results.values():
all_classes |= x
all_classes = list(all_classes)
options = st.multiselect(
'Select the classes:',
all_classes)
select_idx = set(range(len(file_names)))
for idx, op in enumerate(options):
select_idx &= set(class2idx[op])
select_idx = np.array(list(select_idx))
if len(select_idx):
names = np.array(file_names)[select_idx].tolist()
tabs = col2.tabs(names)
for idx, tab in enumerate(tabs):
tabs[idx].image(imgs[select_idx[idx]], width=400)
tabs[idx].write(', '.join(results[select_idx[idx]]))
imgs, file_names = _get_image(my_upload)
results, class2idx = _init_model_return_results(imgs)
plot_canvas(imgs, results, file_names, class2idx)