Clement Vachet committed on
Commit 67f4974 · 1 Parent(s): f764d1c

Improve code based on pylint and black suggestions
detection/ml_detection.py CHANGED
@@ -1,15 +1,21 @@
+"""
+Object Detection module
+"""
+
+import io
+import torch
 from transformers import DetrImageProcessor, DetrForObjectDetection
 from transformers import YolosImageProcessor, YolosForObjectDetection
-import torch
 from PIL import Image
-import io
 
 
 # Load transformer-based model (Yolos or DETR)
 def load_model(model_uri: str):
-    """Load Transformer model"""
-    """ - Doc DETR: https://huggingface.co/docs/transformers/en/model_doc/detr"""
-    """ - Doc Yolos: https://huggingface.co/docs/transformers/en/model_doc/yolos"""
+    """
+    Load Transformer model
+    - Doc DETR: https://huggingface.co/docs/transformers/en/model_doc/detr
+    - Doc Yolos: https://huggingface.co/docs/transformers/en/model_doc/yolos
+    """
 
     if "detr" in model_uri:
         # you can specify the revision tag if you don't want the timm dependency
@@ -27,9 +33,9 @@ def load_model(model_uri: str):
 def object_detection(processor, model, image_bytes):
     """Perform object detection task"""
 
-    print('Object detection prediction...')
-    #url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-    #image = Image.open(requests.get(url, stream=True).raw)
+    print("Object detection prediction...")
+    # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    # image = Image.open(requests.get(url, stream=True).raw)
 
     img = Image.open(io.BytesIO(image_bytes))
     inputs = processor(images=img, return_tensors="pt")
@@ -39,5 +45,7 @@ def object_detection(processor, model, image_bytes):
     # convert outputs (bounding boxes and class logits) to COCO API
     # let's only keep detections with score > 0.9
     target_sizes = torch.tensor([img.size[::-1]])
-    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
+    results = processor.post_process_object_detection(
+        outputs, target_sizes=target_sizes, threshold=0.9
+    )[0]
     return results
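For context, a minimal usage sketch of the refactored module; the model URI and sample path are illustrative (borrowed from inference_direct.py elsewhere in this commit), not part of the change itself.

from detection import ml_detection

# Load a supported DETR/YOLOS checkpoint
processor, model = ml_detection.load_model("facebook/detr-resnet-50")

# Read an image as raw bytes, the same way the API path does
with open("./samples/boats.jpg", "rb") as f:  # placeholder image path
    image_bytes = f.read()

# Returns post-processed detections (scores, labels, boxes) above the 0.9 threshold
results = ml_detection.object_detection(processor, model, image_bytes)
print(results)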
detection/ml_utils.py CHANGED
@@ -1,17 +1,24 @@
+"""
+Utility functions for detection module
+"""
+
 import torch
-import json
 
 
 def convert_tensor_dict_to_json(tensor_dict):
     """Convert a dictionary of tensors to a JSON-serializable dictionary."""
 
-    json_dict = {key: value.detach().cpu().numpy().tolist() if isinstance(value, torch.Tensor) else value
-                 for key, value in tensor_dict.items()}
+    json_dict = {
+        key: (
+            value.detach().cpu().numpy().tolist()
+            if isinstance(value, torch.Tensor)
+            else value
+        )
+        for key, value in tensor_dict.items()
+    }
 
     # Convert to JSON string
     # json_str = json.dumps(json_dict)
     # print(json_str)
 
     return json_dict
-
-
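To illustrate the reformatted comprehension's behavior, a small sketch mirroring the unit test (the extra non-tensor value is illustrative):

import torch
from detection import ml_utils

# Tensors are detached and converted to plain lists; other values pass through unchanged
tensor_dict = {"scores": torch.tensor([1, 2, 3]), "note": "kept as-is"}
print(ml_utils.convert_tensor_dict_to_json(tensor_dict))
# {'scores': [1, 2, 3], 'note': 'kept as-is'}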
 
 
inference_api.py CHANGED
@@ -1,7 +1,11 @@
-import requests
-import argparse
+"""
+Object detection - command line inference via API
+"""
+
 import sys
 import base64
+import argparse
+import requests
 
 # Default examples
 # api = "http://localhost:8080/2015-03-31/functions/function/invocations"
@@ -12,15 +16,26 @@ def arg_parser():
     """Parse arguments"""
 
     # Create an ArgumentParser object
-    parser = argparse.ArgumentParser(description='Object detection inference via API call')
+    parser = argparse.ArgumentParser(
+        description="Object detection inference via API call"
+    )
     # Add arguments
-    parser.add_argument('--api', type=str, help='URL to server API (with endpoint)', required=True)
-    parser.add_argument('--file', type=str, help='Path to the input image file', required=True)
-    parser.add_argument('--model', type=str, \
-                        choices=['detr-resnet-50', 'detr-resnet-101', 'yolos-tiny', 'yolos-small'], \
-                        help='Model type', \
-                        required=False)
-    parser.add_argument('-v', '--verbose', action='store_true', help='Increase output verbosity')
+    parser.add_argument(
+        "--api", type=str, help="URL to server API (with endpoint)", required=True
+    )
+    parser.add_argument(
+        "--file", type=str, help="Path to the input image file", required=True
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        choices=["detr-resnet-50", "detr-resnet-101", "yolos-tiny", "yolos-small"],
+        help="Model type",
+        required=False,
+    )
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Increase output verbosity"
+    )
     return parser
 
 
@@ -30,7 +45,7 @@ def main(args=None):
     args = arg_parser().parse_args(args)
     # Use the arguments
    if args.verbose:
-        print(f'Input file: {args.file}')
+        print(f"Input file: {args.file}")
 
     # Retrieve model type
     if args.model:
@@ -39,27 +54,27 @@ def main(args=None):
         model_name = ""
 
     # Load image
-    with open(args.file, 'rb') as image_file:
+    with open(args.file, "rb") as image_file:
         image_data = image_file.read()
 
     # Encode the image data in base64
-    encoded_image = base64.b64encode(image_data).decode('utf-8')
+    encoded_image = base64.b64encode(image_data).decode("utf-8")
 
     # Prepare the payload
     payload = {
-        'body': encoded_image,
-        'isBase64Encoded': True,
-        'model': model_name,
+        "body": encoded_image,
+        "isBase64Encoded": True,
+        "model": model_name,
    }
 
     # Send request to API
     # Option 'files': A dictionary of files to send to the specified url
     # response = requests.post(args.api, files={'image': image_data})
     # Option 'json': A JSON object to send to the specified url
-    response = requests.post(args.api, json = payload)
+    response = requests.post(args.api, json=payload, timeout=60)
 
     if response.status_code == 200:
-        print('Detection Results:')
+        print("Detection Results:")
         # Process the response
         # processed_data = json.loads(response.content)
         # print('processed_data', processed_data)
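Because the entry point is main(args=None), the script can also be exercised from Python rather than the shell; a hedged sketch, where the API URL is the local Lambda example from the script's own comments and the image path is a placeholder:

import inference_api

# Equivalent to: python inference_api.py --api <url> --file <image> --model yolos-tiny -v
inference_api.main(
    [
        "--api", "http://localhost:8080/2015-03-31/functions/function/invocations",
        "--file", "./samples/boats.jpg",  # placeholder image path
        "--model", "yolos-tiny",
        "-v",
    ]
)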
inference_direct.py CHANGED
@@ -1,10 +1,13 @@
+"""
+Direct inference with hard-coded data
+"""
+
 from detection import ml_detection, ml_utils
-import json
-from PIL import Image
 
 
 # Run detection pipeline: load ML model, perform object detection and return json object
 def detection_pipeline(model_type, image_bytes):
+    """Detection pipeline: load ML model, perform object detection and return json object"""
     # Load correct ML model
     detr_processor, detr_model = ml_detection.load_model(model_type)
 
@@ -18,15 +21,16 @@ def detection_pipeline(model_type, image_bytes):
 
 
 def main():
-    print('Main function')
+    """Main function"""
+    print("Main function")
 
     model_type = "facebook/detr-resnet-50"
-    image_path = './samples/boats.jpg'
+    image_path = "./samples/boats.jpg"
 
     # Reading image file as image_bytes (similar to API request)
-    print('Reading image file...')
-    with open(image_path, 'rb') as image_file:
-        image_bytes = image_file.read()
+    print("Reading image file...")
+    with open(image_path, "rb") as image_file:
+        image_bytes = image_file.read()
 
     result_json = detection_pipeline(model_type, image_bytes)
     print("result_json:", result_json)
@@ -34,4 +38,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
lambda_function.py CHANGED
@@ -1,7 +1,11 @@
-from detection import ml_detection, ml_utils
+"""
+AWS Lambda function
+"""
+
 import base64
 import json
 import logging
+from detection import ml_detection, ml_utils
 
 
 logger = logging.getLogger()
@@ -10,6 +14,7 @@ logger.setLevel(logging.INFO)
 
 # Find ML model type based on string request
 def get_model_type(query_string):
+    """Find ML model type based on string request"""
     # Default ml model type
     if query_string == "":
         model_type = "facebook/detr-resnet-50"
@@ -19,12 +24,13 @@ def get_model_type(query_string):
     elif "yolos" in query_string:
         model_type = "hustvl/" + query_string
     else:
-        raise Exception('Incorrect model type.')
+        raise Exception("Incorrect model type.")
     return model_type
 
 
 # Run detection pipeline: load ML model, perform object detection and return json object
 def detection_pipeline(model_type, image_bytes):
+    """detection pipeline: load ML model, perform object detection and return json object"""
     # Load correct ML model
     processor, model = ml_detection.load_model(model_type)
 
@@ -38,36 +44,43 @@ def detection_pipeline(model_type, image_bytes):
 
 
 def lambda_handler(event, context):
+    """
+    Lambda handler (proxy integration option unchecked on AWS API Gateway)
+
+    Args:
+        event (dict): The event that triggered the Lambda function.
+        context (LambdaContext): Information about the execution environment.
+
+    Returns:
+        dict: The response to be returned from the Lambda function.
+    """
+
     # logger.info(f"API event: {event}")
     try:
         # Retrieve model type
-        model_query = event.get('model', '')
+        model_query = event.get("model", "")
         model_type = get_model_type(model_query)
-        logger.info(f"Model query: {model_query}")
-        logger.info(f"Model type: {model_type}")
+        logger.info("Model query: %s", model_query)
+        logger.info("Model type: %s", model_type)
 
         # Decode the base64-encoded image data from the event
-        image_data = event['body']
-        if event['isBase64Encoded']:
+        image_data = event["body"]
+        if event["isBase64Encoded"]:
             image_data = base64.b64decode(image_data)
 
         # Run detection pipeline
         result_dict = detection_pipeline(model_type, image_data)
-        logger.info(f"API Results: {result_dict}")
+        logger.info("API Results: %s", str(result_dict))
 
         return {
-            'statusCode': 200,
-            'headers': {
-                'Content-Type': 'application/json'
-            },
-            'body': json.dumps(result_dict),
+            "statusCode": 200,
+            "headers": {"Content-Type": "application/json"},
+            "body": json.dumps(result_dict),
         }
     except Exception as e:
-        logger.info(f"API Error: {str(e)}")
+        logger.info("API Error: %s", str(e))
         return {
-            'statusCode': 500,
-            'headers': {
-                'Content-Type': 'application/json'
-            },
-            'body': json.dumps({'error': str(e)}),
-        }
+            "statusCode": 500,
+            "headers": {"Content-Type": "application/json"},
+            "body": json.dumps({"error": str(e)}),
+        }
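A minimal local sanity check of the handler, mirroring the payload built by inference_api.py and the test below; the image path is illustrative, and the "model" key is omitted so the default facebook/detr-resnet-50 is used:

import base64
import json
from lambda_function import lambda_handler

# Build an event like the one sent by inference_api.py (no "model" key -> default model)
with open("tests/data/savanna.jpg", "rb") as f:  # sample image used by the tests
    encoded_image = base64.b64encode(f.read()).decode("utf-8")
event = {"body": encoded_image, "isBase64Encoded": True}

response = lambda_handler(event, None)      # context is not used by the handler
print(response["statusCode"])               # 200 on success
print(json.loads(response["body"]).keys())  # expected keys: scores, labels, boxes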
 
 
tests/test_lambda.py CHANGED
@@ -1,35 +1,41 @@
+"""
+Testing Lambda handler
+"""
+
 import os
 import sys
-import pytest
 import json
 import base64
+import pytest
+
+from lambda_function import lambda_handler
 
 
 current_dir = os.path.dirname(os.path.abspath(__file__))
 parent_dir = os.path.dirname(current_dir)
 sys.path.insert(0, os.path.dirname(parent_dir))
 
-from lambda_function import lambda_handler
-
 
 @pytest.fixture
 def event():
+    """Example json event"""
+
     # Get the directory of the current test file
     test_dir = os.path.dirname(os.path.abspath(__file__))
     # Construct the image path relative to the test directory
-    image_path = os.path.join(test_dir, 'data', 'savanna.jpg')
+    image_path = os.path.join(test_dir, "data", "savanna.jpg")
 
     # Read image data
-    with open(image_path, 'rb') as image_file:
+    with open(image_path, "rb") as image_file:
         image_data = image_file.read()
 
     # Encode the image data in base64
-    encoded_image = base64.b64encode(image_data).decode('utf-8')
+    encoded_image = base64.b64encode(image_data).decode("utf-8")
 
     # Prepare the payload
     json_event = {
-        'body': encoded_image,
-        'isBase64Encoded': True,
+        "body": encoded_image,
+        "isBase64Encoded": True,
     }
 
     return json_event
@@ -37,22 +43,25 @@ def event():
 
 @pytest.fixture
 def context():
+    """Example context"""
     return None
 
 
 def test_lambda_handler(event, context):
+    """Tests lambda handler"""
+
     lambda_response = lambda_handler(event, context)
     response_data = json.loads(lambda_response["body"])
 
-    print("lambda_response - type",type(lambda_response))
+    print("lambda_response - type", type(lambda_response))
     print("lambda_response", lambda_response)
     print("response_data - type", type(response_data))
     print("response_data", response_data)
 
     response_keys = list(response_data.keys())
-    gt_keys = ['scores', 'labels', 'boxes']
+    gt_keys = ["scores", "labels", "boxes"]
 
     assert lambda_response["statusCode"] == 200
     assert set(response_keys) == set(gt_keys), "Response keys do not match ground truth"
-    assert len(response_data['scores']) == 5
-    assert len(response_data['labels']) == 5
+    assert len(response_data["scores"]) == 5
+    assert len(response_data["labels"]) == 5
tests/test_ml_detection.py CHANGED
@@ -1,31 +1,46 @@
-from detection import ml_detection
+"""
+Testing Detection module
+"""
+
 import os
 import pytest
 
+from detection import ml_detection
+
 
 # Test model loading
-@pytest.mark.parametrize("test_model_uri",[
-    ("facebook/detr-resnet-50"),
-    ("facebook/detr-resnet-101"),
-    ])
+@pytest.mark.parametrize(
+    "test_model_uri",
+    [
+        ("facebook/detr-resnet-50"),
+        ("facebook/detr-resnet-101"),
+    ],
+)
 def test_load_model(test_model_uri):
+    """Testing model loading"""
+
     processor, model = ml_detection.load_model(test_model_uri)
     assert processor is not None
     assert model is not None
 
 
 # Test image detection
-@pytest.mark.parametrize("test_model_uri", [
-    ("facebook/detr-resnet-50"),
-    ("facebook/detr-resnet-101"),
-    ])
+@pytest.mark.parametrize(
+    "test_model_uri",
+    [
+        ("facebook/detr-resnet-50"),
+        ("facebook/detr-resnet-101"),
+    ],
+)
 def test_object_detection(test_model_uri):
+    """Testing object detection function"""
+
     processor, model = ml_detection.load_model(test_model_uri)
 
     # Get the directory of the current test file
     test_dir = os.path.dirname(os.path.abspath(__file__))
     # Construct the image path relative to the test directory
-    image_path = os.path.join(test_dir, 'data', 'savanna.jpg')
+    image_path = os.path.join(test_dir, "data", "savanna.jpg")
 
     with open(image_path, "rb") as f:
         image_bytes = f.read()
tests/test_ml_utils.py CHANGED
@@ -1,11 +1,16 @@
-from detection import ml_utils
+"""
+Testing Detection utility functions
+"""
+
 import torch
-import json
+from detection import ml_utils
 
 
 # Test dictionary conversion
 def test_convert_tensor_dict_to_json():
-    my_dict = {'scores': torch.tensor([1, 2, 3])}
-    my_list_gt = {'scores': [1, 2, 3]}
+    """Testing tensor to dict conversions"""
+
+    my_dict = {"scores": torch.tensor([1, 2, 3])}
+    my_list_gt = {"scores": [1, 2, 3]}
     my_list = ml_utils.convert_tensor_dict_to_json(my_dict)
     assert my_list == my_list_gt
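For completeness, the updated test modules can also be launched programmatically rather than from the command line; a small sketch, assuming it is run from the repository root:

import sys
import pytest

# Equivalent to running `pytest -v tests` in a shell
sys.exit(pytest.main(["-v", "tests"]))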