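# detr / app.py
# Web file-viewer page residue (not part of the program), kept as comments so
# the file parses as Python:
#   ayaanzaveri's picture
#   Update app.py
#   df6c7a9
#   raw
#   history blame contribute delete
#   No virus
#   2.11 kB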
from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from random import choice
from PIL import Image
import os
from matplotlib import rcParams, font_manager
import streamlit as st
import urllib.request
import requests
from transformers import pipeline


def _load_detector():
    """Build the DETR object-detection pipeline.

    Returns:
        A ``(extractor, model, pipe)`` tuple: the feature extractor and model
        loaded from the ``facebook/detr-resnet-50`` checkpoint, and an
        ``object-detection`` pipeline wrapping both.
    """
    extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
    model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    pipe = pipeline('object-detection', model=model, feature_extractor=extractor)
    return extractor, model, pipe


# Streamlit re-executes this script on every widget interaction, so without
# caching the checkpoint would be reloaded each time the URL changes. Use
# whichever resource cache this Streamlit version provides; on very old
# versions that have neither, fall back to loading on every run as before.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
)
if _cache_resource is not None:
    _load_detector = _cache_resource(_load_detector)

extractor, model, pipe = _load_detector()
# ---------------------------------------------------------------------------
# Main flow: fetch the image at the given URL, run object detection on it,
# draw the predicted boxes and labels, and show the annotated result.
# ---------------------------------------------------------------------------
img_url = st.text_input('Image URL', 'https://images.unsplash.com/photo-1556911220-bff31c812dba?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2468&q=80')

st.caption('Downloading Image...')
try:
    # The timeout keeps a stalled server from hanging the app, and
    # raise_for_status() stops an HTTP error page from being saved to disk
    # as if it were the image.
    response = requests.get(img_url, timeout=30)
    response.raise_for_status()
except requests.RequestException as err:
    # Also covers an empty or malformed URL typed into the text box.
    st.error(f'Could not download image: {err}')
    st.stop()
img_data = response.content
with open('detect.jpg', 'wb') as handler:
    handler.write(img_data)

try:
    img = Image.open('detect.jpg')
except OSError as err:
    # Pillow raises an OSError subclass when the bytes are not an image.
    st.error(f'Downloaded file is not a readable image: {err}')
    st.stop()

st.caption('Running Detection...')
# Run detection on the image that was just downloaded instead of handing the
# pipeline the URL again: this avoids a second download and guarantees the
# boxes are computed for the same bytes that are displayed.
output = pipe(img)

st.caption('Adding Predictions to Image...')
# The font file is looked up relative to the working directory.
fpath = "Poppins-SemiBold.ttf"
prop = font_manager.FontProperties(fname=fpath)

# Create figure and axes, then display the image.
fig, ax = plt.subplots()
ax.imshow(img)

colors = ["#ef4444", "#f97316", "#eab308", "#84cc16", "#06b6d4", "#6366f1"]

# Draw one rounded rectangle and one "label: score%" caption per prediction.
for prediction in output:
    selected_color = choice(colors)
    box = prediction['box']
    # Convert corner coordinates to the (x, y, width, height) form the patch expects.
    x, y = box['xmin'], box['ymin']
    w, h = box['xmax'] - x, box['ymax'] - y
    rect = patches.FancyBboxPatch((x, y), w, h, linewidth=1.25, edgecolor=selected_color, facecolor='none', boxstyle="round,pad=-0.0040,rounding_size=10",)
    ax.add_patch(rect)
    # Place the caption just above the box's top-left corner.
    ax.text(x, y-25, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontsize=5, color=selected_color, fontproperties=prop)

ax.axis('off')
fig.savefig('detect-bbox.jpg', dpi=1200, bbox_inches='tight')
# Release the figure: pyplot keeps every open figure alive, and this script
# runs again in the same process on each interaction.
plt.close(fig)

image = Image.open('detect-bbox.jpg')
st.image(image, caption='DETR Image')
st.caption('Done!')