sheshkar committed on
Commit
3bd162c
·
verified ·
1 Parent(s): 95aee8d

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +45 -0
  2. requirements.txt +10 -0
handler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
3
+ from PIL import Image
4
+ import torch
5
+
6
class EndpointHandler():
    """Inference handler for zero-shot (text-grounded) object detection.

    The handler loads a model and its processor once in ``__init__`` and is
    then called once per request with a payload dictionary.

    Payload fields:
        image: a base64 string (optionally a ``data:`` URI), an http(s) URL,
            a local file path, raw image bytes, or an already-decoded image
            object exposing ``size``.
        text: a single prompt string or a list of query strings.
        box_threshold: optional float, defaults to ``DEFAULT_BOX_THRESHOLD``.
        text_threshold: optional float, defaults to ``DEFAULT_TEXT_THRESHOLD``.

    The fields may sit at the top level of the payload or be nested under an
    ``"inputs"`` key.
    """

    DEFAULT_BOX_THRESHOLD = 0.4
    DEFAULT_TEXT_THRESHOLD = 0.3

    def __init__(self, path=""):
        """Load the model and processor found at *path*."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(path).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _build_prompt(text_queries):
        """Return one lowercase prompt in which every query ends with a single period.

        ``["Cat", " dog "]`` becomes ``"cat. dog."``. Empty queries are
        dropped, so an empty input yields ``""``.
        """
        if isinstance(text_queries, str):
            text_queries = [text_queries]
        cleaned = (str(query).lower().strip().rstrip(".").strip() for query in text_queries)
        # Join with a plain space: each query already carries its own period,
        # so joining with ". " would produce a doubled "..".
        return " ".join(f"{query}." for query in cleaned if query)

    @staticmethod
    def _load_image(source):
        """Decode *source* into an RGB image.

        Raises:
            ValueError: if a string is neither a URL, a file, nor valid base64.
            OSError: if the URL/file cannot be read or is not a valid image.
        """
        # Imported locally so this class does not depend on extra file-level imports.
        import base64
        import io
        import os
        from urllib.request import urlopen

        if isinstance(source, (bytes, bytearray)):
            opened = Image.open(io.BytesIO(bytes(source)))
        elif isinstance(source, str):
            candidate = source.strip()
            if candidate.startswith(("http://", "https://")):
                # NOTE(security): this fetches a caller-supplied URL; restrict
                # the allowed hosts if the endpoint is exposed to untrusted users.
                with urlopen(candidate, timeout=10) as response:
                    opened = Image.open(io.BytesIO(response.read()))
            elif os.path.isfile(candidate):
                # Kept for backward compatibility: a plain string used to be
                # treated as a local file path.
                opened = Image.open(candidate)
            else:
                if candidate.startswith("data:"):
                    # Drop the "data:image/...;base64," header of a data URI.
                    candidate = candidate.partition(",")[2]
                encoded = "".join(candidate.split())
                opened = Image.open(io.BytesIO(base64.b64decode(encoded, validate=True)))
        else:
            # Already-decoded image object: use it as given.
            opened = source

        return opened.convert("RGB") if hasattr(opened, "convert") else opened

    @classmethod
    def _to_jsonable(cls, value):
        """Recursively turn tensors/arrays inside *value* into plain Python lists."""
        if torch.is_tensor(value):
            return value.detach().cpu().tolist()
        if hasattr(value, "tolist"):
            return value.tolist()
        if isinstance(value, dict):
            return {key: cls._to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._to_jsonable(item) for item in value]
        return value

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run detection for one request.

        Returns:
            ``{"detections": [...]}`` with JSON-serializable post-processed
            results, or ``{"error": "<message>"}`` when the payload is invalid
            or the image cannot be loaded.
        """
        payload = data
        nested = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(nested, dict) and ("image" not in data or "text" not in data):
            payload = nested

        # Check that the payload contains the required fields
        if not isinstance(payload, dict) or "image" not in payload or "text" not in payload:
            return {"error": "Payload must contain 'image' (base64 or URL) and 'text' (queries)."}

        # Build the text prompt
        prompt = self._build_prompt(payload["text"])
        if not prompt:
            return {"error": "'text' must contain at least one non-empty query."}

        try:
            box_threshold = float(payload.get("box_threshold", self.DEFAULT_BOX_THRESHOLD))
            text_threshold = float(payload.get("text_threshold", self.DEFAULT_TEXT_THRESHOLD))
        except (TypeError, ValueError):
            return {"error": "'box_threshold' and 'text_threshold' must be numbers."}

        # Load the image
        try:
            image = self._load_image(payload["image"])
        except (ValueError, OSError) as exc:
            return {"error": f"Could not load 'image': {exc}"}

        # Prepare the model inputs
        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-process the detections; image.size is reversed to build target_sizes
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[image.size[::-1]],
        )

        # Tensors are not JSON-serializable, so convert them before returning
        return {"detections": self._to_jsonable(results)}
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ addict
5
+ yapf
6
+ timm
7
+ numpy
8
+ opencv-python
9
+ supervision>=0.22.0
10
+ pillow