# cmon2/handler.py
from io import BytesIO
from typing import Any, Dict

import requests
import torch
from PIL import Image
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor


class EndpointHandler:
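    """Inference Endpoints handler for zero-shot object detection.

    Expected payload (derived from the checks below):
        {"inputs": {"image": "<image URL>", "text": "<query>" or ["<query>", ...]}}
    """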
    def __init__(self, path: str = ""):
        # Load the model and processor once at startup
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(path).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Check that the payload contains the required fields
        if "inputs" not in data:
            return {"error": "Payload must contain 'inputs' key with 'image' and 'text'."}
        inputs = data["inputs"]
        if "image" not in inputs or "text" not in inputs:
            return {"error": "Payload must contain 'image' (URL) and 'text' (a query string or list of queries)."}

        # Fetch the image (currently only URLs are supported; Base64 is not)
        image_data = inputs["image"]
        if image_data.startswith("http"):  # URL
            response = requests.get(image_data, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            return {"error": "Handler currently supports only URL-based images."}

        # Build the text prompt: zero-shot detectors such as Grounding DINO expect
        # lowercase queries, each terminated by exactly one period
        text_queries = inputs["text"]
        if isinstance(text_queries, list):
            text_queries = " ".join(t.lower().strip().rstrip(".") + "." for t in text_queries)

        # Prepare the model inputs
        processed_inputs = self.processor(images=image, text=text_queries, return_tensors="pt").to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.model(**processed_inputs)

        # Post-process the raw outputs into scored, labeled boxes
        # (note: newer transformers releases rename box_threshold to threshold)
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            processed_inputs.input_ids,
            box_threshold=0.4,
            text_threshold=0.3,
            target_sizes=[image.size[::-1]],  # (height, width)
        )

        # Convert tensors to plain Python types so the response is JSON-serializable
        detections = [
            {
                "scores": r["scores"].tolist(),
                "labels": r["labels"],
                "boxes": r["boxes"].tolist(),
            }
            for r in results
        ]
        return {"detections": detections}