# app.py — commit 4845fdd ("Update app.py", verified) by mrdbourke
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
from PIL import Image
# Hub repo id of the fine-tuned DETR checkpoint used by this demo.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_synthetic_data_only"

# Pre/post-processing pipeline and the detection model, both loaded from the
# same checkpoint so their configurations agree.
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

# Prefer the GPU when one is present; otherwise run inference on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Mapping from predicted class id to human-readable label, read from the
# model's own config.
id2label = model.config.id2label

# Outline colour for each label's bounding boxes.
color_dict = dict(
    not_trash="red",
    bin="green",
    trash="blue",
    hand="purple",
)
def predict_on_image(image, conf_threshold=0.25):
    """Run object detection on *image* and draw the detections onto it.

    Args:
        image: PIL image to analyse, as delivered by the
            ``gr.Image(type="pil")`` input. ``None`` when the user submits
            without uploading an image.
        conf_threshold: Minimum score a detection must reach to be kept;
            forwarded as ``threshold`` to
            ``image_processor.post_process_object_detection``.

    Returns:
        The input PIL image with a rectangle and a ``"label (score)"``
        caption drawn for every kept detection (the image is modified in
        place), or ``None`` if no image was supplied.
    """
    # Nothing to analyse or draw on when no image was provided.
    if image is None:
        return None

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        # PIL's .size is (width, height); the post-processor expects
        # (height, width) per image.
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])
        results = image_processor.post_process_object_detection(
            outputs,
            threshold=conf_threshold,
            target_sizes=target_sizes,
        )[0]

    # Bring every tensor in the results back to the CPU so values can be
    # converted to plain Python numbers for drawing; leave anything else as-is.
    results = {
        key: value.cpu() if torch.is_tensor(value) else value
        for key, value in results.items()
    }

    draw = ImageDraw.Draw(image)
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        # Classes without an entry in color_dict still get drawn rather than
        # raising KeyError.
        targ_color = color_dict.get(label_name, "yellow")

        draw.rectangle(xy=(x, y, x2, y2), outline=targ_color, width=3)

        # Caption each box with its label and rounded confidence score.
        text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
        draw.text(xy=(x, y), text=text_string_to_show, fill="white")

    # Release the drawing context before returning the annotated image.
    del draw
    return image
# Gradio UI: an uploaded image and a confidence slider are passed to
# predict_on_image, and the annotated image it returns is displayed.
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Upload Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="pil"),
    title="🚮 Trashify Object Detection Demo",
    description="Upload an image to detect whether there's a bin, a hand or trash in it. Model trained on images synthetically generated by Flux, with labels created by GroundingDINO.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()