|
from typing import Dict, Any |
|
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import torch |
|
|
|
class EndpointHandler():
    """Custom Inference Endpoints handler for zero-shot object detection.

    Expects request payloads of the form::

        {"inputs": {"image": "<http(s) image URL>", "text": ["cat", "dog"]}}

    and returns ``{"detections": [...]}`` on success or ``{"error": "..."}``
    for malformed payloads / failed image downloads.
    """

    def __init__(self, path=""):
        """Load model and processor from *path* onto GPU when available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(path).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run grounded object detection for one request payload.

        Returns a dict with either a ``"detections"`` key (list of
        JSON-serializable per-image results) or an ``"error"`` key.
        """
        if "inputs" not in data:
            return {"error": "Payload must contain 'inputs' key with 'image' and 'text'."}

        inputs = data["inputs"]
        if "image" not in inputs or "text" not in inputs:
            return {"error": "Payload must contain 'image' (base64 or URL) and 'text' (queries)."}

        # Guard with isinstance so a non-string payload value yields the
        # handler's error dict instead of an AttributeError.
        image_data = inputs["image"]
        if not isinstance(image_data, str) or not image_data.startswith("http"):
            return {"error": "Handler currently supports only URL-based images."}

        try:
            # Bounded timeout so a stalled download cannot hang the endpoint;
            # raise_for_status surfaces 4xx/5xx instead of handing an HTML
            # error page to PIL.
            response = requests.get(image_data, timeout=30)
            response.raise_for_status()
            # Detection models expect 3-channel input; normalize palette,
            # RGBA, or grayscale images up front.
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as exc:
            return {"error": f"Failed to fetch or decode image: {exc}"}

        # Grounded-detection text prompts are lowercase queries joined with
        # ". " and each terminated by a period.
        text_queries = inputs["text"]
        if isinstance(text_queries, list):
            text_queries = ". ".join([t.lower().strip() + "." for t in text_queries])

        processed_inputs = self.processor(images=image, text=text_queries, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**processed_inputs)

        # PIL's .size is (width, height); target_sizes wants (height, width).
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            processed_inputs.input_ids,
            box_threshold=0.4,
            text_threshold=0.3,
            target_sizes=[image.size[::-1]]
        )

        # Post-processing returns torch tensors (scores, boxes), which are
        # not JSON-serializable; convert them so the endpoint response can
        # actually be encoded.
        detections = [
            {key: value.tolist() if isinstance(value, torch.Tensor) else value
             for key, value in result.items()}
            for result in results
        ]
        return {"detections": detections}
|
|