Clement Vachet committed
Commit ae17bb5 · 1 Parent(s): b41850c

Add user interface via Gradio

Files changed (4)
  1. app.py +121 -0
  2. samples/boats.jpg +0 -0
  3. samples/savanna.jpg +0 -0
  4. utils.py +86 -0
app.py ADDED
@@ -0,0 +1,121 @@
+ import gradio as gr
+ import base64
+ import os
+ import requests
+ import json
+ import utils
+
+ from dotenv import load_dotenv, find_dotenv
+
+
+ # List of ML models
+ list_models = ["facebook/detr-resnet-50", "facebook/detr-resnet-101", "hustvl/yolos-tiny", "hustvl/yolos-small"]
+ list_models_simple = [os.path.basename(model) for model in list_models]
+
+
+ # Retrieve API URL from env file or global settings
+ def retrieve_api():
+
+     env_path = find_dotenv('config_api.env')
+     if env_path:
+         load_dotenv(dotenv_path=env_path)
+         print("config_api.env file loaded successfully.")
+     else:
+         print("config_api.env file not found.")
+
+     # Use the AWS endpoint, or a local container by default
+     global AWS_API
+     AWS_API = os.getenv("AWS_API", default="http://localhost:8080")
+
+
+ #@spaces.GPU
+ def detect(image_path, model_id, threshold):
+     print("\n Object detection...")
+     print("\t ML model:", list_models[model_id])
+
+     with open(image_path, 'rb') as image_file:
+         image_bytes = image_file.read()
+
+     # API call for object prediction, with the model type as a query parameter
+     if AWS_API == "http://localhost:8080":
+         API_endpoint = AWS_API + "/2015-03-31/functions/function/invocations"
+     else:
+         API_endpoint = AWS_API + "/dev/detect"
+     print("\t API_endpoint: ", API_endpoint)
+
+     # Encode the image data in base64
+     encoded_image = base64.b64encode(image_bytes).decode('utf-8')
+
+     # Prepare the payload
+     payload = {
+         'body': encoded_image
+     }
+
+     # Prepare the query string parameters
+     model_name = list_models_simple[model_id]
+     params = {
+         'model': model_name
+     }
+
+     response = requests.post(API_endpoint, json=payload, params=params)
+
+     if response.status_code == 200:
+         # Process the response
+         response_json = response.json()
+         print('\t API response', response_json)
+         print('\t API response - type', type(response_json))
+         prediction_dict = json.loads(response_json["body"])
+         print('\t API body prediction_dict', prediction_dict)
+         print('\t API body prediction_dict - type', type(prediction_dict))
+     else:
+         # gr.Error must be raised (not just instantiated) to surface in the UI
+         raise gr.Error(f"API Error: {response.status_code}")
+
+     # Generate Gradio output components: JSON and image
+     output_json, output_pil_img = utils.generate_gradio_outputs(image_path, prediction_dict, threshold)
+
+     return output_json, output_pil_img
+
+
+ def demo():
+     with gr.Blocks(theme="base") as demo:
+         gr.Markdown("# Object detection task - use of AWS Lambda")
+         gr.Markdown(
+             """
+             This web application uses transformer models to detect objects in images.
+             The machine learning models were trained on the COCO dataset.
+             You can load an image and see the predictions for the detected objects.
+
+             Note: this web application uses deployed ML models, available via AWS Lambda and AWS API Gateway.
+             """
+         )
+
+         with gr.Row():
+             with gr.Column():
+                 model_id = gr.Radio(list_models,
+                     label="Detection models", value=list_models[0], type="index", info="Choose your detection model")
+             with gr.Column():
+                 threshold = gr.Slider(0, 1.0, value=0.9, label='Detection threshold', info="Choose your detection threshold")
+
+         with gr.Row():
+             input_image = gr.Image(label="Input image", type="filepath")
+             output_image = gr.Image(label="Output image", type="pil")
+             output_json = gr.JSON(label="JSON output", min_height=240, max_height=300)
+
+         with gr.Row():
+             submit_btn = gr.Button("Submit")
+             clear_button = gr.ClearButton()
+
+         gr.Examples(['samples/savanna.jpg', 'samples/boats.jpg'], inputs=input_image)
+
+         submit_btn.click(fn=detect, inputs=[input_image, model_id, threshold], outputs=[output_json, output_image])
+         clear_button.click(lambda: [None, None, None],
+                            inputs=None,
+                            outputs=[input_image, output_image, output_json],
+                            queue=False)
+
+     demo.queue().launch(debug=True)
+
+ if __name__ == "__main__":
+     retrieve_api()
+     demo()
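Note on the serving contract: app.py assumes the backend accepts a JSON payload with the base64-encoded image under 'body' and the model name as a 'model' query-string parameter, and that the response's "body" field is a JSON string containing 'scores', 'labels' (integer COCO ids), and 'boxes' ([xmin, ymin, xmax, ymax]). A minimal sketch of a Lambda handler matching that contract (not part of this commit; inference is stubbed out, and the local-invocation event shape is assumed, where the posted JSON arrives directly as the event):

import base64
import json

def lambda_handler(event, context):
    # Hypothetical handler sketch, not the deployed implementation.
    # Locally, query parameters are not forwarded, so fall back to a default.
    params = event.get("queryStringParameters") or {}
    model_name = params.get("model", "detr-resnet-50")

    # app.py sends the image base64-encoded under "body"
    image_bytes = base64.b64decode(event["body"])

    # ... run the selected detection model on image_bytes (omitted) ...
    predictions = {
        "scores": [0.9876],
        "labels": [23],  # integer COCO ids; utils.CLASSES[23] == 'bear'
        "boxes": [[10.0, 20.0, 110.0, 220.0]],
    }

    # app.py calls json.loads(response.json()["body"]),
    # so "body" must be a JSON string, not a dict
    return {"statusCode": 200, "body": json.dumps(predictions)}

When deployed behind AWS API Gateway, AWS_API in config_api.env points at the stage URL (for example AWS_API=https://<api-id>.execute-api.<region>.amazonaws.com) and app.py appends /dev/detect; with no env file, it falls back to a local Lambda container on http://localhost:8080.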
samples/boats.jpg ADDED
samples/savanna.jpg ADDED
utils.py ADDED
@@ -0,0 +1,86 @@
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import io
+
+
+ # COCO classes
+ CLASSES = [
+     'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+     'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
+     'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
+     'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
+     'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
+     'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
+     'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
+     'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
+     'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
+     'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
+     'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
+     'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
+     'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
+     'toothbrush'
+ ]
+ COLORS = [
+     [0.000, 0.447, 0.741],
+     [0.850, 0.325, 0.098],
+     [0.929, 0.694, 0.125],
+     [0.494, 0.184, 0.556],
+     [0.466, 0.674, 0.188],
+     [0.301, 0.745, 0.933],
+ ]
+
+
+ # Update JSON dictionary with rounded values and class names
+ def generate_output_json(json_dict):
+     json_dict['scores'] = [round(score, 3) for score in json_dict['scores']]
+     json_dict['boxes'] = [[round(coord, 3) for coord in box] for box in json_dict['boxes']]
+     json_dict['labels'] = [CLASSES[label] for label in json_dict['labels']]
+     return json_dict
+
+
+ # Generate matplotlib figure from prediction scores and boxes
+ def generate_output_figure(image_path, predictions, threshold):
+     pil_img = Image.open(image_path)
+
+     plt.figure(figsize=(16, 10))
+     plt.imshow(pil_img)
+     ax = plt.gca()
+     colors = COLORS * 100
+
+     print("\t Detailed information...")
+     for score, label, box in zip(predictions["scores"], predictions["labels"], predictions["boxes"]):
+         #box = [round(i, 2) for i in box]
+         print(
+             f"\t\t Detected {label} with confidence "
+             f"{score} at location {box}"
+         )
+
+         if score > threshold:
+             c = COLORS[hash(label) % len(COLORS)]
+             ax.add_patch(
+                 plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], fill=False, color=c, linewidth=3)
+             )
+             text = f"{label}: {score:0.2f}"
+             ax.text(box[0], box[1], text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
+     plt.axis("off")
+
+     return plt.gcf()
+
+
+ # Generate PIL image from matplotlib figure
+ def generate_output_image(output_figure):
+     # Convert matplotlib figure to PIL image
+     #output_figure = plt.gcf()
+     buf = io.BytesIO()
+     output_figure.savefig(buf, bbox_inches="tight")
+     buf.seek(0)
+     output_pil_img = Image.open(buf)
+
+     return output_pil_img
+
+
+ def generate_gradio_outputs(image_path, prediction_dict, threshold):
+     output_json = generate_output_json(prediction_dict)
+     output_figure = generate_output_figure(image_path, output_json, threshold)
+     output_pil_img = generate_output_image(output_figure)
+     return output_json, output_pil_img
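As a quick sanity check, the utils pipeline can be exercised on its own with a hand-built prediction dict; the values below are fabricated for illustration:

import utils

# Predictions in the format app.py receives from the API:
# raw scores, integer COCO label ids, and [xmin, ymin, xmax, ymax] boxes
prediction_dict = {
    "scores": [0.9876, 0.4321],
    "labels": [24, 1],  # 'zebra' and 'person' in utils.CLASSES
    "boxes": [[10.0, 20.0, 110.0, 220.0], [50.0, 60.0, 90.0, 160.0]],
}

output_json, output_img = utils.generate_gradio_outputs(
    "samples/savanna.jpg", prediction_dict, threshold=0.9
)
print(output_json)  # rounded scores, class names instead of integer ids
output_img.show()   # annotated PIL image

Only the zebra detection clears the 0.9 threshold, so a single box is drawn on the image, while both detections remain visible in the JSON output.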