Spaces:
Sleeping
Sleeping
Clement Vachet
commited on
Commit
·
67f4974
1
Parent(s):
f764d1c
Improve code based on pylint and black suggestions
Browse files- detection/ml_detection.py +17 -9
- detection/ml_utils.py +12 -5
- inference_api.py +33 -18
- inference_direct.py +11 -8
- lambda_function.py +33 -20
- tests/test_lambda.py +21 -12
- tests/test_ml_detection.py +25 -10
- tests/test_ml_utils.py +9 -4
detection/ml_detection.py
CHANGED
@@ -1,15 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import DetrImageProcessor, DetrForObjectDetection
|
2 |
from transformers import YolosImageProcessor, YolosForObjectDetection
|
3 |
-
import torch
|
4 |
from PIL import Image
|
5 |
-
import io
|
6 |
|
7 |
|
8 |
# Load transformer-based model (Yolos or DETR)
|
9 |
def load_model(model_uri: str):
|
10 |
-
"""
|
11 |
-
|
12 |
-
|
|
|
|
|
13 |
|
14 |
if "detr" in model_uri:
|
15 |
# you can specify the revision tag if you don't want the timm dependency
|
@@ -27,9 +33,9 @@ def load_model(model_uri: str):
|
|
27 |
def object_detection(processor, model, image_bytes):
|
28 |
"""Perform object detection task"""
|
29 |
|
30 |
-
print(
|
31 |
-
#url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
32 |
-
#image = Image.open(requests.get(url, stream=True).raw)
|
33 |
|
34 |
img = Image.open(io.BytesIO(image_bytes))
|
35 |
inputs = processor(images=img, return_tensors="pt")
|
@@ -39,5 +45,7 @@ def object_detection(processor, model, image_bytes):
|
|
39 |
# convert outputs (bounding boxes and class logits) to COCO API
|
40 |
# let's only keep detections with score > 0.9
|
41 |
target_sizes = torch.tensor([img.size[::-1]])
|
42 |
-
results = processor.post_process_object_detection(
|
|
|
|
|
43 |
return results
|
|
|
1 |
+
"""
|
2 |
+
Object Detection module
|
3 |
+
"""
|
4 |
+
|
5 |
+
import io
|
6 |
+
import torch
|
7 |
from transformers import DetrImageProcessor, DetrForObjectDetection
|
8 |
from transformers import YolosImageProcessor, YolosForObjectDetection
|
|
|
9 |
from PIL import Image
|
|
|
10 |
|
11 |
|
12 |
# Load transformer-based model (Yolos or DETR)
|
13 |
def load_model(model_uri: str):
|
14 |
+
"""
|
15 |
+
Load Transformer model
|
16 |
+
- Doc DETR: https://huggingface.co/docs/transformers/en/model_doc/detr
|
17 |
+
- Doc Yolos: https://huggingface.co/docs/transformers/en/model_doc/yolos
|
18 |
+
"""
|
19 |
|
20 |
if "detr" in model_uri:
|
21 |
# you can specify the revision tag if you don't want the timm dependency
|
|
|
33 |
def object_detection(processor, model, image_bytes):
|
34 |
"""Perform object detection task"""
|
35 |
|
36 |
+
print("Object detection prediction...")
|
37 |
+
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
38 |
+
# image = Image.open(requests.get(url, stream=True).raw)
|
39 |
|
40 |
img = Image.open(io.BytesIO(image_bytes))
|
41 |
inputs = processor(images=img, return_tensors="pt")
|
|
|
45 |
# convert outputs (bounding boxes and class logits) to COCO API
|
46 |
# let's only keep detections with score > 0.9
|
47 |
target_sizes = torch.tensor([img.size[::-1]])
|
48 |
+
results = processor.post_process_object_detection(
|
49 |
+
outputs, target_sizes=target_sizes, threshold=0.9
|
50 |
+
)[0]
|
51 |
return results
|
detection/ml_utils.py
CHANGED
@@ -1,17 +1,24 @@
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
-
import json
|
3 |
|
4 |
|
5 |
def convert_tensor_dict_to_json(tensor_dict):
|
6 |
"""Convert a dictionary of tensors to a JSON-serializable dictionary."""
|
7 |
|
8 |
-
json_dict = {
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Convert to JSON string
|
12 |
# json_str = json.dumps(json_dict)
|
13 |
# print(json_str)
|
14 |
|
15 |
return json_dict
|
16 |
-
|
17 |
-
|
|
|
1 |
+
"""
|
2 |
+
Utility functions for detection module
|
3 |
+
"""
|
4 |
+
|
5 |
import torch
|
|
|
6 |
|
7 |
|
8 |
def convert_tensor_dict_to_json(tensor_dict):
|
9 |
"""Convert a dictionary of tensors to a JSON-serializable dictionary."""
|
10 |
|
11 |
+
json_dict = {
|
12 |
+
key: (
|
13 |
+
value.detach().cpu().numpy().tolist()
|
14 |
+
if isinstance(value, torch.Tensor)
|
15 |
+
else value
|
16 |
+
)
|
17 |
+
for key, value in tensor_dict.items()
|
18 |
+
}
|
19 |
|
20 |
# Convert to JSON string
|
21 |
# json_str = json.dumps(json_dict)
|
22 |
# print(json_str)
|
23 |
|
24 |
return json_dict
|
|
|
|
inference_api.py
CHANGED
@@ -1,7 +1,11 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
3 |
import sys
|
4 |
import base64
|
|
|
|
|
5 |
|
6 |
# Default examples
|
7 |
# api = "http://localhost:8080/2015-03-31/functions/function/invocations"
|
@@ -12,15 +16,26 @@ def arg_parser():
|
|
12 |
"""Parse arguments"""
|
13 |
|
14 |
# Create an ArgumentParser object
|
15 |
-
parser = argparse.ArgumentParser(
|
|
|
|
|
16 |
# Add arguments
|
17 |
-
parser.add_argument(
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
return parser
|
25 |
|
26 |
|
@@ -30,7 +45,7 @@ def main(args=None):
|
|
30 |
args = arg_parser().parse_args(args)
|
31 |
# Use the arguments
|
32 |
if args.verbose:
|
33 |
-
print(f
|
34 |
|
35 |
# Retrieve model type
|
36 |
if args.model:
|
@@ -39,27 +54,27 @@ def main(args=None):
|
|
39 |
model_name = ""
|
40 |
|
41 |
# Load image
|
42 |
-
with open(args.file,
|
43 |
image_data = image_file.read()
|
44 |
|
45 |
# Encode the image data in base64
|
46 |
-
encoded_image = base64.b64encode(image_data).decode(
|
47 |
|
48 |
# Prepare the payload
|
49 |
payload = {
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
}
|
54 |
|
55 |
# Send request to API
|
56 |
# Option 'files': A dictionary of files to send to the specified url
|
57 |
# response = requests.post(args.api, files={'image': image_data})
|
58 |
# Option 'json': A JSON object to send to the specified url
|
59 |
-
response = requests.post(args.api, json =
|
60 |
|
61 |
if response.status_code == 200:
|
62 |
-
print(
|
63 |
# Process the response
|
64 |
# processed_data = json.loads(response.content)
|
65 |
# print('processed_data', processed_data)
|
|
|
1 |
+
"""
|
2 |
+
Object detection - command line inference via API
|
3 |
+
"""
|
4 |
+
|
5 |
import sys
|
6 |
import base64
|
7 |
+
import argparse
|
8 |
+
import requests
|
9 |
|
10 |
# Default examples
|
11 |
# api = "http://localhost:8080/2015-03-31/functions/function/invocations"
|
|
|
16 |
"""Parse arguments"""
|
17 |
|
18 |
# Create an ArgumentParser object
|
19 |
+
parser = argparse.ArgumentParser(
|
20 |
+
description="Object detection inference via API call"
|
21 |
+
)
|
22 |
# Add arguments
|
23 |
+
parser.add_argument(
|
24 |
+
"--api", type=str, help="URL to server API (with endpoint)", required=True
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--file", type=str, help="Path to the input image file", required=True
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--model",
|
31 |
+
type=str,
|
32 |
+
choices=["detr-resnet-50", "detr-resnet-101", "yolos-tiny", "yolos-small"],
|
33 |
+
help="Model type",
|
34 |
+
required=False,
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"-v", "--verbose", action="store_true", help="Increase output verbosity"
|
38 |
+
)
|
39 |
return parser
|
40 |
|
41 |
|
|
|
45 |
args = arg_parser().parse_args(args)
|
46 |
# Use the arguments
|
47 |
if args.verbose:
|
48 |
+
print(f"Input file: {args.file}")
|
49 |
|
50 |
# Retrieve model type
|
51 |
if args.model:
|
|
|
54 |
model_name = ""
|
55 |
|
56 |
# Load image
|
57 |
+
with open(args.file, "rb") as image_file:
|
58 |
image_data = image_file.read()
|
59 |
|
60 |
# Encode the image data in base64
|
61 |
+
encoded_image = base64.b64encode(image_data).decode("utf-8")
|
62 |
|
63 |
# Prepare the payload
|
64 |
payload = {
|
65 |
+
"body": encoded_image,
|
66 |
+
"isBase64Encoded": True,
|
67 |
+
"model": model_name,
|
68 |
}
|
69 |
|
70 |
# Send request to API
|
71 |
# Option 'files': A dictionary of files to send to the specified url
|
72 |
# response = requests.post(args.api, files={'image': image_data})
|
73 |
# Option 'json': A JSON object to send to the specified url
|
74 |
+
response = requests.post(args.api, json=payload, timeout=60)
|
75 |
|
76 |
if response.status_code == 200:
|
77 |
+
print("Detection Results:")
|
78 |
# Process the response
|
79 |
# processed_data = json.loads(response.content)
|
80 |
# print('processed_data', processed_data)
|
inference_direct.py
CHANGED
@@ -1,10 +1,13 @@
|
|
|
|
|
|
|
|
|
|
1 |
from detection import ml_detection, ml_utils
|
2 |
-
import json
|
3 |
-
from PIL import Image
|
4 |
|
5 |
|
6 |
# Run detection pipeline: load ML model, perform object detection and return json object
|
7 |
def detection_pipeline(model_type, image_bytes):
|
|
|
8 |
# Load correct ML model
|
9 |
detr_processor, detr_model = ml_detection.load_model(model_type)
|
10 |
|
@@ -18,15 +21,16 @@ def detection_pipeline(model_type, image_bytes):
|
|
18 |
|
19 |
|
20 |
def main():
|
21 |
-
|
|
|
22 |
|
23 |
model_type = "facebook/detr-resnet-50"
|
24 |
-
image_path =
|
25 |
|
26 |
# Reading image file as image_bytes (similar to API request)
|
27 |
-
print(
|
28 |
-
with open(image_path,
|
29 |
-
|
30 |
|
31 |
result_json = detection_pipeline(model_type, image_bytes)
|
32 |
print("result_json:", result_json)
|
@@ -34,4 +38,3 @@ def main():
|
|
34 |
|
35 |
if __name__ == "__main__":
|
36 |
main()
|
37 |
-
|
|
|
1 |
+
"""
|
2 |
+
Direct inference with hard-coded data
|
3 |
+
"""
|
4 |
+
|
5 |
from detection import ml_detection, ml_utils
|
|
|
|
|
6 |
|
7 |
|
8 |
# Run detection pipeline: load ML model, perform object detection and return json object
|
9 |
def detection_pipeline(model_type, image_bytes):
|
10 |
+
"""Detection pipeline: load ML model, perform object detection and return json object"""
|
11 |
# Load correct ML model
|
12 |
detr_processor, detr_model = ml_detection.load_model(model_type)
|
13 |
|
|
|
21 |
|
22 |
|
23 |
def main():
|
24 |
+
"""Main function"""
|
25 |
+
print("Main function")
|
26 |
|
27 |
model_type = "facebook/detr-resnet-50"
|
28 |
+
image_path = "./samples/boats.jpg"
|
29 |
|
30 |
# Reading image file as image_bytes (similar to API request)
|
31 |
+
print("Reading image file...")
|
32 |
+
with open(image_path, "rb") as image_file:
|
33 |
+
image_bytes = image_file.read()
|
34 |
|
35 |
result_json = detection_pipeline(model_type, image_bytes)
|
36 |
print("result_json:", result_json)
|
|
|
38 |
|
39 |
if __name__ == "__main__":
|
40 |
main()
|
|
lambda_function.py
CHANGED
@@ -1,7 +1,11 @@
|
|
1 |
-
|
|
|
|
|
|
|
2 |
import base64
|
3 |
import json
|
4 |
import logging
|
|
|
5 |
|
6 |
|
7 |
logger = logging.getLogger()
|
@@ -10,6 +14,7 @@ logger.setLevel(logging.INFO)
|
|
10 |
|
11 |
# Find ML model type based on string request
|
12 |
def get_model_type(query_string):
|
|
|
13 |
# Default ml model type
|
14 |
if query_string == "":
|
15 |
model_type = "facebook/detr-resnet-50"
|
@@ -19,12 +24,13 @@ def get_model_type(query_string):
|
|
19 |
elif "yolos" in query_string:
|
20 |
model_type = "hustvl/" + query_string
|
21 |
else:
|
22 |
-
raise Exception(
|
23 |
return model_type
|
24 |
|
25 |
|
26 |
# Run detection pipeline: load ML model, perform object detection and return json object
|
27 |
def detection_pipeline(model_type, image_bytes):
|
|
|
28 |
# Load correct ML model
|
29 |
processor, model = ml_detection.load_model(model_type)
|
30 |
|
@@ -38,36 +44,43 @@ def detection_pipeline(model_type, image_bytes):
|
|
38 |
|
39 |
|
40 |
def lambda_handler(event, context):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
# logger.info(f"API event: {event}")
|
42 |
try:
|
43 |
# Retrieve model type
|
44 |
-
model_query = event.get(
|
45 |
model_type = get_model_type(model_query)
|
46 |
-
logger.info(
|
47 |
-
logger.info(
|
48 |
|
49 |
# Decode the base64-encoded image data from the event
|
50 |
-
image_data = event[
|
51 |
-
if event[
|
52 |
image_data = base64.b64decode(image_data)
|
53 |
|
54 |
# Run detection pipeline
|
55 |
result_dict = detection_pipeline(model_type, image_data)
|
56 |
-
logger.info(
|
57 |
|
58 |
return {
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
},
|
63 |
-
'body': json.dumps(result_dict),
|
64 |
}
|
65 |
except Exception as e:
|
66 |
-
logger.info(
|
67 |
return {
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
'body': json.dumps({'error': str(e)}),
|
73 |
-
}
|
|
|
1 |
+
"""
|
2 |
+
AWS Lambda function
|
3 |
+
"""
|
4 |
+
|
5 |
import base64
|
6 |
import json
|
7 |
import logging
|
8 |
+
from detection import ml_detection, ml_utils
|
9 |
|
10 |
|
11 |
logger = logging.getLogger()
|
|
|
14 |
|
15 |
# Find ML model type based on string request
|
16 |
def get_model_type(query_string):
|
17 |
+
"""Find ML model type based on string request"""
|
18 |
# Default ml model type
|
19 |
if query_string == "":
|
20 |
model_type = "facebook/detr-resnet-50"
|
|
|
24 |
elif "yolos" in query_string:
|
25 |
model_type = "hustvl/" + query_string
|
26 |
else:
|
27 |
+
raise Exception("Incorrect model type.")
|
28 |
return model_type
|
29 |
|
30 |
|
31 |
# Run detection pipeline: load ML model, perform object detection and return json object
|
32 |
def detection_pipeline(model_type, image_bytes):
|
33 |
+
"""detection pipeline: load ML model, perform object detection and return json object"""
|
34 |
# Load correct ML model
|
35 |
processor, model = ml_detection.load_model(model_type)
|
36 |
|
|
|
44 |
|
45 |
|
46 |
def lambda_handler(event, context):
|
47 |
+
"""
|
48 |
+
Lambda handler (proxy integration option unchecked on AWS API Gateway)
|
49 |
+
|
50 |
+
Args:
|
51 |
+
event (dict): The event that triggered the Lambda function.
|
52 |
+
context (LambdaContext): Information about the execution environment.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
dict: The response to be returned from the Lambda function.
|
56 |
+
"""
|
57 |
+
|
58 |
# logger.info(f"API event: {event}")
|
59 |
try:
|
60 |
# Retrieve model type
|
61 |
+
model_query = event.get("model", "")
|
62 |
model_type = get_model_type(model_query)
|
63 |
+
logger.info("Model query: %s", model_query)
|
64 |
+
logger.info("Model type: %s", model_type)
|
65 |
|
66 |
# Decode the base64-encoded image data from the event
|
67 |
+
image_data = event["body"]
|
68 |
+
if event["isBase64Encoded"]:
|
69 |
image_data = base64.b64decode(image_data)
|
70 |
|
71 |
# Run detection pipeline
|
72 |
result_dict = detection_pipeline(model_type, image_data)
|
73 |
+
logger.info("API Results: %s", str(result_dict))
|
74 |
|
75 |
return {
|
76 |
+
"statusCode": 200,
|
77 |
+
"headers": {"Content-Type": "application/json"},
|
78 |
+
"body": json.dumps(result_dict),
|
|
|
|
|
79 |
}
|
80 |
except Exception as e:
|
81 |
+
logger.info("API Error: %s", str(e))
|
82 |
return {
|
83 |
+
"statusCode": 500,
|
84 |
+
"headers": {"Content-Type": "application/json"},
|
85 |
+
"body": json.dumps({"error": str(e)}),
|
86 |
+
}
|
|
|
|
tests/test_lambda.py
CHANGED
@@ -1,35 +1,41 @@
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
-
import pytest
|
4 |
import json
|
5 |
import base64
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
9 |
parent_dir = os.path.dirname(current_dir)
|
10 |
sys.path.insert(0, os.path.dirname(parent_dir))
|
11 |
|
12 |
-
from lambda_function import lambda_handler
|
13 |
-
|
14 |
|
15 |
@pytest.fixture
|
16 |
def event():
|
|
|
|
|
17 |
# Get the directory of the current test file
|
18 |
test_dir = os.path.dirname(os.path.abspath(__file__))
|
19 |
# Construct the image path relative to the test directory
|
20 |
-
image_path = os.path.join(test_dir,
|
21 |
|
22 |
# Read image data
|
23 |
-
with open(image_path,
|
24 |
image_data = image_file.read()
|
25 |
|
26 |
# Encode the image data in base64
|
27 |
-
encoded_image = base64.b64encode(image_data).decode(
|
28 |
|
29 |
# Prepare the payload
|
30 |
json_event = {
|
31 |
-
|
32 |
-
|
33 |
}
|
34 |
|
35 |
return json_event
|
@@ -37,22 +43,25 @@ def event():
|
|
37 |
|
38 |
@pytest.fixture
|
39 |
def context():
|
|
|
40 |
return None
|
41 |
|
42 |
|
43 |
def test_lambda_handler(event, context):
|
|
|
|
|
44 |
lambda_response = lambda_handler(event, context)
|
45 |
response_data = json.loads(lambda_response["body"])
|
46 |
|
47 |
-
print("lambda_response - type",type(lambda_response))
|
48 |
print("lambda_response", lambda_response)
|
49 |
print("response_data - type", type(response_data))
|
50 |
print("response_data", response_data)
|
51 |
|
52 |
response_keys = list(response_data.keys())
|
53 |
-
gt_keys = [
|
54 |
|
55 |
assert lambda_response["statusCode"] == 200
|
56 |
assert set(response_keys) == set(gt_keys), "Response keys do not match ground truth"
|
57 |
-
assert len(response_data[
|
58 |
-
assert len(response_data[
|
|
|
1 |
+
"""
|
2 |
+
Testing Lambda handler
|
3 |
+
"""
|
4 |
+
|
5 |
import os
|
6 |
import sys
|
|
|
7 |
import json
|
8 |
import base64
|
9 |
+
import pytest
|
10 |
+
|
11 |
+
from lambda_function import lambda_handler
|
12 |
|
13 |
|
14 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
15 |
parent_dir = os.path.dirname(current_dir)
|
16 |
sys.path.insert(0, os.path.dirname(parent_dir))
|
17 |
|
|
|
|
|
18 |
|
19 |
@pytest.fixture
|
20 |
def event():
|
21 |
+
"""Example json event"""
|
22 |
+
|
23 |
# Get the directory of the current test file
|
24 |
test_dir = os.path.dirname(os.path.abspath(__file__))
|
25 |
# Construct the image path relative to the test directory
|
26 |
+
image_path = os.path.join(test_dir, "data", "savanna.jpg")
|
27 |
|
28 |
# Read image data
|
29 |
+
with open(image_path, "rb") as image_file:
|
30 |
image_data = image_file.read()
|
31 |
|
32 |
# Encode the image data in base64
|
33 |
+
encoded_image = base64.b64encode(image_data).decode("utf-8")
|
34 |
|
35 |
# Prepare the payload
|
36 |
json_event = {
|
37 |
+
"body": encoded_image,
|
38 |
+
"isBase64Encoded": True,
|
39 |
}
|
40 |
|
41 |
return json_event
|
|
|
43 |
|
44 |
@pytest.fixture
|
45 |
def context():
|
46 |
+
"""Example context"""
|
47 |
return None
|
48 |
|
49 |
|
50 |
def test_lambda_handler(event, context):
|
51 |
+
"""Tests lambda handler"""
|
52 |
+
|
53 |
lambda_response = lambda_handler(event, context)
|
54 |
response_data = json.loads(lambda_response["body"])
|
55 |
|
56 |
+
print("lambda_response - type", type(lambda_response))
|
57 |
print("lambda_response", lambda_response)
|
58 |
print("response_data - type", type(response_data))
|
59 |
print("response_data", response_data)
|
60 |
|
61 |
response_keys = list(response_data.keys())
|
62 |
+
gt_keys = ["scores", "labels", "boxes"]
|
63 |
|
64 |
assert lambda_response["statusCode"] == 200
|
65 |
assert set(response_keys) == set(gt_keys), "Response keys do not match ground truth"
|
66 |
+
assert len(response_data["scores"]) == 5
|
67 |
+
assert len(response_data["labels"]) == 5
|
tests/test_ml_detection.py
CHANGED
@@ -1,31 +1,46 @@
|
|
1 |
-
|
|
|
|
|
|
|
2 |
import os
|
3 |
import pytest
|
4 |
|
|
|
|
|
5 |
|
6 |
# Test model loading
|
7 |
-
@pytest.mark.parametrize(
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
11 |
def test_load_model(test_model_uri):
|
|
|
|
|
12 |
processor, model = ml_detection.load_model(test_model_uri)
|
13 |
assert processor is not None
|
14 |
assert model is not None
|
15 |
|
16 |
|
17 |
# Test image detection
|
18 |
-
@pytest.mark.parametrize(
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
22 |
def test_object_detection(test_model_uri):
|
|
|
|
|
23 |
processor, model = ml_detection.load_model(test_model_uri)
|
24 |
|
25 |
# Get the directory of the current test file
|
26 |
test_dir = os.path.dirname(os.path.abspath(__file__))
|
27 |
# Construct the image path relative to the test directory
|
28 |
-
image_path = os.path.join(test_dir,
|
29 |
|
30 |
with open(image_path, "rb") as f:
|
31 |
image_bytes = f.read()
|
|
|
1 |
+
"""
|
2 |
+
Testing Detection module
|
3 |
+
"""
|
4 |
+
|
5 |
import os
|
6 |
import pytest
|
7 |
|
8 |
+
from detection import ml_detection
|
9 |
+
|
10 |
|
11 |
# Test model loading
|
12 |
+
@pytest.mark.parametrize(
|
13 |
+
"test_model_uri",
|
14 |
+
[
|
15 |
+
("facebook/detr-resnet-50"),
|
16 |
+
("facebook/detr-resnet-101"),
|
17 |
+
],
|
18 |
+
)
|
19 |
def test_load_model(test_model_uri):
|
20 |
+
"""Testing model loading"""
|
21 |
+
|
22 |
processor, model = ml_detection.load_model(test_model_uri)
|
23 |
assert processor is not None
|
24 |
assert model is not None
|
25 |
|
26 |
|
27 |
# Test image detection
|
28 |
+
@pytest.mark.parametrize(
|
29 |
+
"test_model_uri",
|
30 |
+
[
|
31 |
+
("facebook/detr-resnet-50"),
|
32 |
+
("facebook/detr-resnet-101"),
|
33 |
+
],
|
34 |
+
)
|
35 |
def test_object_detection(test_model_uri):
|
36 |
+
"""Testing object detection function"""
|
37 |
+
|
38 |
processor, model = ml_detection.load_model(test_model_uri)
|
39 |
|
40 |
# Get the directory of the current test file
|
41 |
test_dir = os.path.dirname(os.path.abspath(__file__))
|
42 |
# Construct the image path relative to the test directory
|
43 |
+
image_path = os.path.join(test_dir, "data", "savanna.jpg")
|
44 |
|
45 |
with open(image_path, "rb") as f:
|
46 |
image_bytes = f.read()
|
tests/test_ml_utils.py
CHANGED
@@ -1,11 +1,16 @@
|
|
1 |
-
|
|
|
|
|
|
|
2 |
import torch
|
3 |
-
import
|
4 |
|
5 |
|
6 |
# Test dictionary conversion
|
7 |
def test_convert_tensor_dict_to_json():
|
8 |
-
|
9 |
-
|
|
|
|
|
10 |
my_list = ml_utils.convert_tensor_dict_to_json(my_dict)
|
11 |
assert my_list == my_list_gt
|
|
|
1 |
+
"""
|
2 |
+
Testing Detection utility functions
|
3 |
+
"""
|
4 |
+
|
5 |
import torch
|
6 |
+
from detection import ml_utils
|
7 |
|
8 |
|
9 |
# Test dictionary conversion
|
10 |
def test_convert_tensor_dict_to_json():
|
11 |
+
"""Testing tensor to dict conversions"""
|
12 |
+
|
13 |
+
my_dict = {"scores": torch.tensor([1, 2, 3])}
|
14 |
+
my_list_gt = {"scores": [1, 2, 3]}
|
15 |
my_list = ml_utils.convert_tensor_dict_to_json(my_dict)
|
16 |
assert my_list == my_list_gt
|