"""Gradio demo for the Trashify object detector.

Loads a fine-tuned object-detection checkpoint from the Hugging Face Hub,
runs it on an uploaded image and returns the image with the predicted
boxes and "label (score)" captions drawn on it.
"""

import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection

# Hub repo id of the fine-tuned detector (fetched by from_pretrained).
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_synthetic_data_only"

image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path)

# Maps the integer class ids returned by post-processing to label names.
id2label = model.config.id2label

# Outline colour per label name. Any label not listed here is drawn with
# DEFAULT_BOX_COLOR rather than raising KeyError.
color_dict = {
    "not_trash": "red",
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
}
DEFAULT_BOX_COLOR = "yellow"

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)


def _draw_detections(image, results):
    """Draw every box in *results* onto *image* in place.

    Args:
        image: PIL image to annotate (modified in place).
        results: Mapping with "boxes", "scores" and "labels" tensors, as
            returned by ``post_process_object_detection`` and already
            moved to the CPU.
    """
    draw = ImageDraw.Draw(image)
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        # Boxes come back in absolute (x_min, y_min, x_max, y_max) pixels
        # because target_sizes was passed to post-processing.
        x, y, x2, y2 = box.tolist()

        label_name = id2label[label.item()]
        box_color = color_dict.get(label_name, DEFAULT_BOX_COLOR)

        draw.rectangle(xy=(x, y, x2, y2), outline=box_color, width=3)
        draw.text(
            xy=(x, y),
            text=f"{label_name} ({round(score.item(), 3)})",
            fill="white",
        )


def predict_on_image(image: Image.Image, conf_threshold: float = 0.25) -> Image.Image:
    """Run the detector on *image* and return an annotated copy.

    Args:
        image: Input PIL image, as supplied by ``gr.Image(type="pil")``.
        conf_threshold: Minimum score a detection needs in order to be kept.

    Returns:
        A copy of *image* with one rectangle and one "label (score)" caption
        per kept detection. The input image itself is left unmodified.

    Raises:
        gr.Error: If no image was supplied (Gradio passes ``None`` when the
            user submits without uploading).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        # post-processing wants (height, width); PIL's .size is (width, height).
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])
        results = image_processor.post_process_object_detection(
            outputs, threshold=conf_threshold, target_sizes=target_sizes
        )[0]

    # Bring every result tensor back to the CPU so .tolist()/.item() are cheap
    # and device-independent.
    results = {key: value.cpu() for key, value in results.items()}

    # Draw on a copy so the caller's image object is not mutated.
    annotated = image.copy()
    _draw_detections(annotated, results)
    return annotated


demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Upload Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="pil"),
    title="🚮 Trashify Object Detection Demo",
    description="Upload an image to detect whether there's a bin, a hand or trash in it.",
)

if __name__ == "__main__":
    demo.launch()