ai-photo-gallery / pages /1_🔥_An_Image.py
KyanChen's picture
Update pages/1_🔥_An_Image.py
3bd3201
raw
history blame
4.26 kB
import pre_reqs
import cv2
import numpy as np
import streamlit as st
from mmcls.apis import init_model
from mmcls.apis import inference_model_topk as inference_cls_model
from mmdet.registry import VISUALIZERS
# from mmcls.utils import register_all_modules as register_all_modules_cls
from mmdet.apis import init_detector, inference_detector
from mmdet.utils import register_all_modules as register_all_modules_det
import pandas as pd
from PIL import Image
# Page chrome: browser-tab title/icon, heading and a one-line description.
# The icon and headings use the fire emoji (U+1F525); the previous text was a
# mis-decoded byte sequence of that emoji, which is not a valid page icon.
st.set_page_config(page_title="🔥 An Image Demo", page_icon="🔥", layout='wide')
st.markdown("# 🔥 An Image Demo")
st.write(
    ":dog: Try uploading an image to get the possible categories, objects."
)

# Sidebar: optional image upload. `_get_image` falls back to a bundled sample
# image when this is None.
st.sidebar.header("An Image Demo")
my_upload = st.sidebar.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Three side-by-side columns filled by `plot_canvas`:
# original image / visualization / prediction table.
col1, col2, col3 = st.columns(3)

# Inference mode: 'cls' (classification) or 'det' (detection).
model_option = st.radio(
    "What's your inference model",
    ('cls', 'det'))

# Prefix prepended to the config paths and to the sample image path.
parent_folder = './'

# Number of predictions shown in the results table (1..10, default 3).
topk = st.slider('Return top-k predictions', 1, 10, 3)
# @st.cache_resource
def _init_model(model_option):
if model_option == 'cls':
# init model
model = init_model(parent_folder + 'configs/resnet/resnet50_8xb32_in1k.py',
'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth')
visualizer = None
elif model_option == 'det':
# register_all_modules_det()
model = init_detector(parent_folder + 'configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py',
'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth', device='cpu')
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta
else:
model = None
visualizer = None
return model, visualizer
# @st.cache_data
def _get_image(my_upload=my_upload):
    """Open the image to analyse and return it as an RGB PIL image.

    Args:
        my_upload: The uploaded file object, or ``None`` to use the bundled
            sample image under ``images/zebra.jpg``.

    Returns:
        The opened image converted to ``'RGB'`` mode.
    """
    source = (parent_folder + "images/zebra.jpg") if my_upload is None else my_upload
    return Image.open(source).convert('RGB')
# @st.cache_resource
def _inference_model(img, model, visualizer, model_option):
img = np.array(img)
if model_option == 'cls':
return_results = inference_cls_model(model, img, 10)
vis_img = img
elif model_option == 'det':
vis_img = img.copy()
results = inference_detector(model, img)
# import pdb
# pdb.set_trace()
b, h, w = results.pred_instances.masks.shape
vis_img = cv2.resize(vis_img, (w, h))
visualizer.add_datasample(
name='result',
image=vis_img,
data_sample=results,
draw_gt=False,
show=False)
vis_img = visualizer.get_image()
cls_names = visualizer.dataset_meta['classes']
return_results = {'scores': results.pred_instances.scores[:10].numpy(),
'bboxes': results.pred_instances.bboxes[:10].numpy(),
'labels': [cls_names[x.item()] for x in results.pred_instances.labels[:10]]
}
return return_results, vis_img
def plot_canvas(img, vis_img, results, model_option):
    """Render the input image, the visualization and the prediction table.

    Writes into the module-level columns ``col1`` (original image),
    ``col2`` (visualization) and ``col3`` (table of the first ``topk``
    predictions).

    Args:
        img: The original image shown in the first column.
        vis_img: The image shown in the second column.
        results: Prediction data. In ``'cls'`` mode the keys ``"pred_class"``
            and ``"pred_score"`` are read; in ``'det'`` mode the keys
            ``"labels"``, ``"scores"`` and ``"bboxes"``.
        model_option: ``'cls'`` or ``'det'``; for any other value only the
            column headers and the original image are drawn.
    """
    col1.write("Original Image :camera:")
    col1.image(img)
    col2.write("Visualization:wrench:")
    col3.write("Metainfo:wrench:")

    if model_option not in ('cls', 'det'):
        return

    col2.image(vis_img)

    if model_option == 'cls':
        table = {
            'category': results["pred_class"][:topk],
            'probability': [f"{score:.2f}" for score in results["pred_score"][:topk]],
        }
    else:
        table = {
            'category': results["labels"][:topk],
            'probability': [f"{score:.2f}" for score in results["scores"][:topk]],
            # Each box becomes a list of coordinates formatted to two decimals.
            'box': [[f"{coord:.2f}" for coord in box] for box in results["bboxes"][:topk]],
        }

    col3.dataframe(pd.DataFrame(table, index=None))
# Script driver: build the model for the selected mode, load the uploaded
# image (or the bundled sample), run inference and render the three columns.
model, visualizer = _init_model(model_option)
img = _get_image(my_upload)
# `results` holds the prediction data, `vis_img` the image for the middle column.
results, vis_img = _inference_model(img, model, visualizer, model_option)
plot_canvas(img, vis_img, results, model_option)