neko941's picture
Update app.py
688e29a
raw
history blame contribute delete
No virus
4.03 kB
import os
import torch
import urllib
from PIL import Image
import streamlit as st
from pathlib import Path
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
# Check file(s) for acceptable suffix
if file and suffix:
if isinstance(suffix, str):
suffix = [suffix]
for f in file if isinstance(file, (list, tuple)) else [file]:
s = Path(f).suffix.lower() # file suffix
if len(s):
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
def check_file(file, suffix=''):
# Search/download file (if necessary) and return path
check_suffix(file, suffix) # optional
file = str(file) # convert to str()
if os.path.isfile(file) or not file: # exists
return file
elif file.startswith(('http:/', 'https:/')): # download
url = file # warning: Pathlib turns :// -> :/
# '%2F' to '/', split https://url.com/file.txt?auth
file = Path(urllib.parse.unquote(file).split('?')[0]).name
if os.path.isfile(file):
print(f'Found {url} locally at {file}') # file already exists
else:
print(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, file)
assert Path(file).exists() and Path(file).stat(
).st_size > 0, f'File download failed: {url}' # check
return file
def read_pretrain(path):
return torch.hub.load('ultralytics/yolov5', 'custom', path=path)
# default_pretrained = '2022.11.04-YOLOv5x6_1280-Hololive_Waifu_Classification.pt'
st.title("Hololive Waifu Classification")
image = st.text_input('Image URL', '')
st.info(
'Images for quick tesing:\n \n \n'
' - https://i.imgur.com/tFZwWYw.jpg'
'\n \n \n'
' - https://static.wikia.nocookie.net/omniversal-battlefield/images/b/bd/Council.jpg'
'\n \n \n'
' - https://rare-gallery.com/uploads/posts/951368-anime-anime-girls-digital-art-artwork-2D-portrait.jpg'
'\n \n \n'
' - https://megapx-assets.dcard.tw/images/65993ab1-fe08-43be-87cd-2ecd201cacbd/1280.jpeg'
'\n \n \n'
' - https://img.esportsku.com/wp-content/uploads//2021/07/hololive-en.png')
pretrained = st.selectbox('Select pre-trained', ('2022.11.04-YOLOv5x6_1280-Hololive_Waifu_Classification.pt', '2022.11.01-YOLOv5x6_1280-Hololive_Waifu_Classification.pt'))
imgsz = st.number_input(label='Image Size', min_value=None, max_value=None, value=1280, step=1)
conf = st.slider(label='Confidence threshold', min_value=0.0, max_value=1.0, value=0.25, step=0.01)
iou = st.slider(label='IoU threshold', min_value=0.0, max_value=1.0, value=0.45, step=0.01)
multi_label = st.selectbox('Multiple labels per box', (False, True))
agnostic = st.selectbox('Class-agnostic', (False, True))
amp = st.selectbox('Automatic Mixed Precision inference', (False, True))
max_det = st.number_input(label='Maximum number of detections per image', min_value=None, max_value=None, value=1000, step=1)
clicked = st.button('Excute')
# with st.spinner('Loading the model...'):
# model = read_pretrain(default_pretrained)
if clicked:
with st.spinner('Loading the image...'):
image_path = check_file(image)
input_image = Image.open(image_path)
# if default_pretrained != pretrained:
with st.spinner('Loading the model...'):
# model = torch.hub.load('ultralytics/yolov5', 'custom', path=os.path.join('pretrained', pretrained))
# model = torch.hub.load('ultralytics/yolov5', 'custom', path=pretrained)
model = read_pretrain(pretrained)
with st.spinner('Updating configuration...'):
model.conf = float(conf)
model.max_det = int(max_det)
model.iou = float(iou)
model.agnostic = agnostic
model.multi_label = multi_label
model.amp = amp
with st.spinner('Predicting...'):
results = model(input_image, size=int(imgsz))
for img in results.render():
st.image(img)
st.write(results.pandas().xyxy[0])
os.remove(image_path)