Spaces:
Sleeping
Sleeping
File size: 4,040 Bytes
ae17bb5 71bec2b ae17bb5 6a10656 a1a1640 ae17bb5 a1a1640 6a10656 ae17bb5 a1a1640 ae17bb5 b062768 ae17bb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import gradio as gr
import base64
import os
import requests
import json
import utils
from dotenv import load_dotenv, find_dotenv
# Hugging Face identifiers of the detection models offered in the UI.
list_models = [
    "facebook/detr-resnet-50",
    "facebook/detr-resnet-101",
    "hustvl/yolos-tiny",
    "hustvl/yolos-small",
]
# Same models without the organisation prefix, index-aligned with list_models.
list_models_simple = list(map(os.path.basename, list_models))
# Retrieve API URLs from env file or global settings
def retrieve_api():
    """Load the API settings and set the module-level ``AWS_API`` base URL.

    Searches for a ``config_api.env`` file with ``find_dotenv`` and, when one
    is found, loads it with ``load_dotenv``. ``AWS_API`` is then read from the
    environment, falling back to ``http://localhost:8080`` when it is unset.
    """
    global AWS_API

    dotenv_file = find_dotenv('config_api.env')
    if not dotenv_file:
        print("config_api.env file not found.")
    else:
        load_dotenv(dotenv_path=dotenv_file)
        print("config_api.env file loaded successfully.")

    # AWS endpoint when configured, local container URL otherwise
    AWS_API = os.environ.get("AWS_API", "http://localhost:8080")
#@spaces.GPU
def detect(image_path, model_id, threshold):
    """Send an image to the detection API and build the Gradio outputs.

    Args:
        image_path: Path of the input image file; it is read in binary mode.
        model_id: Index into ``list_models`` / ``list_models_simple``
            selecting the model to use.
        threshold: Detection threshold, forwarded unchanged to
            ``utils.generate_gradio_outputs``.

    Returns:
        The ``(output_json, output_pil_img)`` pair produced by
        ``utils.generate_gradio_outputs``. When the API answers with a
        non-200 status, the prediction passed to that helper is
        ``{"Error": <status code>}``.

    Raises:
        gr.Error: If no image was provided, or if the HTTP request itself
            fails (connection error, timeout, ...).
    """
    # Fail with a readable message instead of the TypeError open(None) raises.
    if not image_path:
        raise gr.Error("Please provide an input image.")

    print("\n UI - Object detection...")
    print("\t ML model:", list_models[model_id])

    with open(image_path, 'rb') as image_file:
        image_bytes = image_file.read()

    # The default local URL and any other API base URL use different routes.
    if AWS_API == "http://localhost:8080":
        API_endpoint = AWS_API + "/2015-03-31/functions/function/invocations"
    else:
        API_endpoint = AWS_API + "/dev/detect"

    # The image travels base64-encoded inside the JSON payload.
    encoded_image = base64.b64encode(image_bytes).decode('utf-8')
    model_name = list_models_simple[model_id]

    # Requests arguments
    payload = {
        'body': encoded_image,
        'isBase64Encoded': True,
        'model': model_name,
    }
    headers = {"Content-Type": "application/json"}

    # requests has no default timeout; bound the call so an unreachable
    # endpoint cannot block the UI worker forever.
    request_timeout_s = 120
    try:
        response = requests.post(
            API_endpoint,
            json=payload,
            headers=headers,
            timeout=request_timeout_s,
        )
    except requests.RequestException as exc:
        raise gr.Error(f"API request failed: {exc}") from exc

    if response.status_code == 200:
        # Process the response: the prediction is a JSON string under "body".
        response_json = response.json()
        print('\t API response', response_json)
        prediction_dict = json.loads(response_json["body"])
    else:
        prediction_dict = {"Error": response.status_code}
        # gr.Error is only displayed when raised; gr.Warning shows the alert
        # immediately and lets the error dict still reach the outputs below.
        gr.Warning(f"\t API Error: {response.status_code}")

    # Generate gradio output components: image and json
    output_json, output_pil_img = utils.generate_gradio_outputs(image_path, prediction_dict, threshold)
    return output_json, output_pil_img
def demo():
    """Build the object detection Gradio interface and launch it.

    The page holds a model selector and a threshold slider, an input image
    with its output image and JSON result, and Submit / Clear buttons.
    Submit calls ``detect``; Clear resets the two images and the JSON panel.
    The app is then queued and launched with ``debug=True``.
    """
    # NOTE(review): the nesting of the Row/Column contexts below was
    # reconstructed from a whitespace-stripped source - confirm the layout.
    intro_text = (
        "\n"
        "This web application uses transformer models to detect objects on images.\n"
        "Machine learning models were trained on the COCO dataset.\n"
        "You can load an image and see the predictions for the objects detected.\n"
        "Note: This web application uses deployed ML models, available via AWS Lambda and AWS API Gateway.\n"
    )

    with gr.Blocks(theme="base") as app:
        gr.Markdown("# Object detection task - use of AWS Lambda")
        gr.Markdown(intro_text)

        # Detection settings
        with gr.Row():
            with gr.Column():
                # type="index" makes the callback receive the position of the
                # selected entry rather than its label.
                model_id = gr.Radio(
                    list_models,
                    label="Detection models",
                    value=list_models[0],
                    type="index",
                    info="Choose your detection model",
                )
            with gr.Column():
                threshold = gr.Slider(
                    0,
                    1.0,
                    value=0.9,
                    label='Detection threshold',
                    info="Choose your detection threshold",
                )

        # Input image and detection results
        with gr.Row():
            input_image = gr.Image(label="Input image", type="filepath")
            output_image = gr.Image(label="Output image", type="pil")
            output_json = gr.JSON(label="JSON output", min_height=240, max_height=300)

        # Actions
        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_button = gr.ClearButton()

        gr.Examples(['samples/savanna.jpg', 'samples/boats.jpg'], inputs=input_image)

        submit_btn.click(
            fn=detect,
            inputs=[input_image, model_id, threshold],
            outputs=[output_json, output_image],
        )
        # One None per component listed in outputs.
        clear_button.click(
            lambda: [None] * 3,
            inputs=None,
            outputs=[input_image, output_image, output_json],
            queue=False,
        )

    app.queue().launch(debug=True)
if __name__ == "__main__":
    # retrieve_api() sets the AWS_API global that detect() reads, so it must
    # run before the UI is built and launched.
    retrieve_api()
    demo()