""" Object Detection module """ import io import torch from transformers import DetrImageProcessor, DetrForObjectDetection from transformers import YolosImageProcessor, YolosForObjectDetection from PIL import Image # Load transformer-based model (Yolos or DETR) def load_model(model_uri: str): """ Load Transformer model - Doc DETR: https://huggingface.co./docs/transformers/en/model_doc/detr - Doc Yolos: https://huggingface.co./docs/transformers/en/model_doc/yolos """ if "detr" in model_uri: # you can specify the revision tag if you don't want the timm dependency processor = DetrImageProcessor.from_pretrained(model_uri, revision="no_timm") model = DetrForObjectDetection.from_pretrained(model_uri, revision="no_timm") elif "yolos" in model_uri: processor = YolosImageProcessor.from_pretrained(model_uri) model = YolosForObjectDetection.from_pretrained(model_uri) else: processor = None model = None return processor, model def object_detection(processor, model, image_bytes): """Perform object detection task""" print("Object detection prediction...") # url = "http://images.cocodataset.org/val2017/000000039769.jpg" # image = Image.open(requests.get(url, stream=True).raw) img = Image.open(io.BytesIO(image_bytes)) inputs = processor(images=img, return_tensors="pt") # print('inputs', inputs) outputs = model(**inputs) # convert outputs (bounding boxes and class logits) to COCO API # let's only keep detections with score > 0.9 target_sizes = torch.tensor([img.size[::-1]]) results = processor.post_process_object_detection( outputs, target_sizes=target_sizes, threshold=0.9 )[0] return results