Spaces:
Runtime error
Runtime error
import pre_reqs | |
import cv2 | |
import numpy as np | |
import streamlit as st | |
from mmcls.apis import init_model | |
from mmcls.apis import inference_model_topk as inference_cls_model | |
from mmdet.registry import VISUALIZERS | |
# from mmcls.utils import register_all_modules as register_all_modules_cls | |
from mmdet.apis import init_detector, inference_detector | |
from mmdet.utils import register_all_modules as register_all_modules_det | |
import pandas as pd | |
from PIL import Image | |
# --- Page layout -------------------------------------------------------------
# Streamlit renders elements in the order they are created, so the sequence of
# calls below *is* the page layout; set_page_config must remain the first
# Streamlit command of the script.
# NOTE(review): "π₯" looks like a mis-decoded emoji (UTF-8 bytes shown through
# a single-byte codepage), possibly introduced when this file was extracted —
# confirm the intended emoji and restore it.
st.set_page_config(page_title="π₯ An Image Demo", page_icon="π₯", layout='wide')
st.markdown("# π₯ An Image Demo")
st.write(":dog: Try uploading an image to get the possible categories, objects.")

# Sidebar: optional user image; the uploader yields None until a file is given.
st.sidebar.header("An Image Demo")
my_upload = st.sidebar.file_uploader(label="Upload an image", type=["png", "jpg", "jpeg"])

# Three side-by-side result columns, filled in later by plot_canvas.
col1, col2, col3 = st.columns(3)

# Task selector: 'cls' (classification) or 'det' (detection).
model_option = st.radio(label="What\'s your inference model", options=('cls', 'det'))

# Root directory that config files and the sample image are resolved against.
parent_folder = './'

# Number of prediction rows shown in the metadata table.
topk = st.slider('Return top-k predictions', min_value=1, max_value=10, value=3)
# @st.cache_resource | |
def _init_model(model_option): | |
if model_option == 'cls': | |
# init model | |
model = init_model(parent_folder + 'configs/resnet/resnet50_8xb32_in1k.py', | |
'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth') | |
visualizer = None | |
elif model_option == 'det': | |
# register_all_modules_det() | |
model = init_detector(parent_folder + 'configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py', | |
'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth', device='cpu') | |
visualizer = VISUALIZERS.build(model.cfg.visualizer) | |
visualizer.dataset_meta = model.dataset_meta | |
else: | |
model = None | |
visualizer = None | |
return model, visualizer | |
# Decoded images are not cached; re-enable the decorator below to memoise them.
# @st.cache_data
def _get_image(my_upload=my_upload):
    """Return the image to analyse as an RGB ``PIL.Image``.

    Args:
        my_upload: File object from the sidebar uploader, or ``None`` when
            nothing was uploaded. The default is the module-level
            ``my_upload`` as bound when this function was defined.

    Returns:
        The uploaded image, or the bundled ``images/zebra.jpg`` sample
        (relative to ``parent_folder``) when no upload is present, converted
        to RGB.
    """
    if my_upload is None:
        return Image.open(parent_folder + "images/zebra.jpg").convert('RGB')
    return Image.open(my_upload).convert('RGB')
# @st.cache_resource | |
def _inference_model(img, model, visualizer, model_option): | |
img = np.array(img) | |
if model_option == 'cls': | |
return_results = inference_cls_model(model, img, 10) | |
vis_img = img | |
elif model_option == 'det': | |
vis_img = img.copy() | |
results = inference_detector(model, img) | |
# import pdb | |
# pdb.set_trace() | |
b, h, w = results.pred_instances.masks.shape | |
vis_img = cv2.resize(vis_img, (w, h)) | |
visualizer.add_datasample( | |
name='result', | |
image=vis_img, | |
data_sample=results, | |
draw_gt=False, | |
show=False) | |
vis_img = visualizer.get_image() | |
cls_names = visualizer.dataset_meta['classes'] | |
return_results = {'scores': results.pred_instances.scores[:10].numpy(), | |
'bboxes': results.pred_instances.bboxes[:10].numpy(), | |
'labels': [cls_names[x.item()] for x in results.pred_instances.labels[:10]] | |
} | |
return return_results, vis_img | |
def plot_canvas(img, vis_img, results, model_option):
    """Fill the three page columns: input image, visualization, top-k table.

    Uses the module-level ``col1``/``col2``/``col3`` containers and the
    ``topk`` slider value.

    Args:
        img: Original image, shown in the first column.
        vis_img: Image shown in the second column.
        results: Prediction dict from ``_inference_model``; read via the
            ``'pred_class'``/``'pred_score'`` keys for ``'cls'`` and the
            ``'labels'``/``'scores'``/``'bboxes'`` keys for ``'det'``.
        model_option: ``'cls'`` or ``'det'``; any other value only renders
            the input image and the column headers.
    """
    col1.write("Original Image :camera:")
    col1.image(img)
    col2.write("Visualization:wrench:")
    col3.write("Metainfo:wrench:")

    if model_option not in ('cls', 'det'):
        return

    col2.image(vis_img)

    if model_option == 'cls':
        table = {
            'category': results["pred_class"][:topk],
            'probability': [f"{score:.2f}" for score in results["pred_score"]][:topk],
        }
    else:
        table = {
            'category': results["labels"][:topk],
            'probability': [f"{score:.2f}" for score in results["scores"]][:topk],
            # Each box cell is a list of its coordinates formatted to 2 d.p.
            'box': [[f"{coord:.2f}" for coord in box] for box in results["bboxes"][:topk]],
        }

    col3.dataframe(pd.DataFrame(table, index=None))
# Main flow, executed top to bottom on every Streamlit rerun: build the model
# for the selected task, load the image, run inference, then render results.
model, visualizer = _init_model(model_option)
img = _get_image(my_upload=my_upload)
results, vis_img = _inference_model(
    img=img,
    model=model,
    visualizer=visualizer,
    model_option=model_option,
)
plot_canvas(img=img, vis_img=vis_img, results=results, model_option=model_option)