Gabolozano committed on
Commit
b518471
·
verified ·
1 Parent(s): 68f583c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -22
app.py CHANGED
@@ -5,21 +5,14 @@ import numpy as np
5
  import cv2
6
  from PIL import Image
7
 
8
- # Pre-load the base configuration and models (without setting a threshold yet)
9
- base_config = DetrConfig.from_pretrained("facebook/detr-resnet-101")
10
- base_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101", config=base_config)
11
- base_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101")
 
12
 
13
- def load_model(threshold):
14
- # Adjust the configuration for the current threshold
15
- config = DetrConfig.from_pretrained("facebook/detr-resnet-50", threshold=threshold)
16
- # Create a new model instance with the updated configuration
17
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101", config=config)
18
- # Image processor does not need to be re-loaded
19
- return pipeline(task='object-detection', model=model, image_processor=base_processor)
20
-
21
- # Initialize the pipeline with a default threshold
22
- od_pipe = load_model(0.25) # Set a default threshold here
23
 
24
  def draw_detections(image, detections):
25
  np_image = np.array(image)
@@ -31,34 +24,33 @@ def draw_detections(image, detections):
31
  box = detection['box']
32
  x_min, y_min = box['xmin'], box['ymin']
33
  x_max, y_max = box['xmax'], box['ymax']
34
- # Draw rectangles and text with a larger font
35
  cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
36
  label_text = f'{label} {score:.2f}'
37
- # Increase the font size and text thickness
38
  cv2.putText(np_image, label_text, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 4)
39
 
40
- # Convert BGR to RGB for displaying
41
  final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
42
  final_pil_image = Image.fromarray(final_image)
43
  return final_pil_image
44
 
45
- def get_pipeline_prediction(threshold, pil_image):
46
  global od_pipe
47
- od_pipe = load_model(threshold) # reload model with the specified threshold
48
  try:
49
  if not isinstance(pil_image, Image.Image):
50
  pil_image = Image.fromarray(np.array(pil_image).astype('uint8'), 'RGB')
51
  result = od_pipe(pil_image)
52
  processed_image = draw_detections(pil_image, result)
53
- return processed_image, result
 
54
  except Exception as e:
55
- return pil_image, {"error": str(e)}
56
 
57
  with gr.Blocks() as demo:
58
  with gr.Row():
59
  with gr.Column():
60
  gr.Markdown("## Object Detection")
61
  inp_image = gr.Image(label="Upload your image here")
 
62
  threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.25, label="Detection Threshold")
63
  run_button = gr.Button("Detect Objects")
64
  with gr.Column():
@@ -66,7 +58,9 @@ with gr.Blocks() as demo:
66
  output_image = gr.Image()
67
  with gr.Tab("Detection Results"):
68
  output_data = gr.JSON()
 
 
69
 
70
- run_button.click(get_pipeline_prediction, inputs=[threshold_slider, inp_image], outputs=[output_image, output_data])
71
 
72
  demo.launch()
 
5
  import cv2
6
  from PIL import Image
7
 
8
+ def load_model(model_name, threshold):
9
+ config = DetrConfig.from_pretrained(model_name, threshold=threshold)
10
+ model = DetrForObjectDetection.from_pretrained(model_name, config=config)
11
+ image_processor = DetrImageProcessor.from_pretrained(model_name)
12
+ return pipeline(task='object-detection', model=model, image_processor=image_processor)
13
 
14
+ # Load the initial model with default threshold
15
+ od_pipe = load_model("facebook/detr-resnet-101", 0.25) # Setting a default threshold
 
 
 
 
 
 
 
 
16
 
17
  def draw_detections(image, detections):
18
  np_image = np.array(image)
 
24
  box = detection['box']
25
  x_min, y_min = box['xmin'], box['ymin']
26
  x_max, y_max = box['xmax'], box['ymax']
 
27
  cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
28
  label_text = f'{label} {score:.2f}'
 
29
  cv2.putText(np_image, label_text, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 4)
30
 
 
31
  final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
32
  final_pil_image = Image.fromarray(final_image)
33
  return final_pil_image
34
 
35
+ def get_pipeline_prediction(model_name, threshold, pil_image):
36
  global od_pipe
37
+ od_pipe = load_model(model_name, threshold) # Reload model with the specified model and threshold
38
  try:
39
  if not isinstance(pil_image, Image.Image):
40
  pil_image = Image.fromarray(np.array(pil_image).astype('uint8'), 'RGB')
41
  result = od_pipe(pil_image)
42
  processed_image = draw_detections(pil_image, result)
43
+ description = f'Model used: {model_name}, Detection Threshold: {threshold}'
44
+ return processed_image, result, description
45
  except Exception as e:
46
+ return pil_image, {"error": str(e)}, "Failed to process image"
47
 
48
  with gr.Blocks() as demo:
49
  with gr.Row():
50
  with gr.Column():
51
  gr.Markdown("## Object Detection")
52
  inp_image = gr.Image(label="Upload your image here")
53
+ model_dropdown = gr.Dropdown(choices=["facebook/detr-resnet-50", "facebook/detr-resnet-50-panoptic", "facebook/detr-resnet-101", "facebook/detr-resnet-101-panoptic"], value="facebook/detr-resnet-101", label="Select Model")
54
  threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.25, label="Detection Threshold")
55
  run_button = gr.Button("Detect Objects")
56
  with gr.Column():
 
58
  output_image = gr.Image()
59
  with gr.Tab("Detection Results"):
60
  output_data = gr.JSON()
61
+ with gr.Tab("Description"):
62
+ description_output = gr.Textbox()
63
 
64
+ run_button.click(get_pipeline_prediction, inputs=[model_dropdown, threshold_slider, inp_image], outputs=[output_image, output_data, description_output])
65
 
66
  demo.launch()