Clement Vachet commited on
Commit
375e6f0
·
1 Parent(s): a9c685c

Add Lambda handler

Browse files
Files changed (1) hide show
  1. lambda_function.py +72 -0
lambda_function.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from detection import ml_detection, ml_utils
2
+ import base64
3
+ import json
4
+ import logging
5
+
6
+
7
+ logger = logging.getLogger()
8
+ logger.setLevel(logging.INFO)
9
+
10
+
11
+ # Find ML model type based on string request
12
+ def get_model_type(query_string):
13
+ # Default ml model type
14
+ if query_string == "":
15
+ model_type = "facebook/detr-resnet-50"
16
+ # Assess query string value
17
+ elif "detr" in query_string:
18
+ model_type = "facebook/" + query_string
19
+ elif "yolos" in query_string:
20
+ model_type = "hustvl/" + query_string
21
+ else:
22
+ raise Exception('Incorrect model type.')
23
+ return model_type
24
+
25
+
26
+ # Run detection pipeline: load ML model, perform object detection and return json object
27
+ def detection_pipeline(model_type, image_bytes):
28
+ # Load correct ML model
29
+ processor, model = ml_detection.load_model(model_type)
30
+
31
+ # Perform object detection
32
+ results = ml_detection.object_detection(processor, model, image_bytes)
33
+
34
+ # Convert dictionary of tensors to JSON object
35
+ result_json_dict = ml_utils.convert_tensor_dict_to_json(results)
36
+
37
+ return result_json_dict
38
+
39
+
40
+ def lambda_handler(event, context):
41
+ try:
42
+ # Get the model name from the query string parameters
43
+ # Condition for local testing
44
+ is_querystringparam = event.get('queryStringParameters')
45
+ if is_querystringparam is not None:
46
+ model_query = event['queryStringParameters'].get('model', '').lower()
47
+ else:
48
+ model_query = ""
49
+ model_type = get_model_type(model_query)
50
+ logger.info(f"Model query: {model_query}")
51
+ logger.info(f"Model type: {model_type}")
52
+
53
+ # Decode the base64-encoded image data from the event
54
+ image_data = base64.b64decode(event['body'])
55
+ result_dict = detection_pipeline(model_type, image_data)
56
+ logger.info(f"API Results: {result_dict}")
57
+ return {
58
+ 'statusCode': 200,
59
+ 'headers': {
60
+ 'Content-Type': 'application/json'
61
+ },
62
+ 'body': json.dumps(result_dict),
63
+ }
64
+ except Exception as e:
65
+ logger.info(f"API Error: {str(e)}")
66
+ return {
67
+ 'statusCode': 500,
68
+ 'headers': {
69
+ 'Content-Type': 'application/json'
70
+ },
71
+ 'body': json.dumps({'error': str(e)}),
72
+ }