sheshkar commited on
Commit
f867f82
verified
1 Parent(s): 3bd162c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +54 -45
handler.py CHANGED
@@ -1,45 +1,54 @@
1
- from typing import Dict, Any
2
- from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
3
- from PIL import Image
4
- import torch
5
-
6
- class EndpointHandler():
7
- def __init__(self, path=""):
8
- # 艁adowanie modelu i procesora
9
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
- self.model = AutoModelForZeroShotObjectDetection.from_pretrained(path).to(self.device)
11
- self.processor = AutoProcessor.from_pretrained(path)
12
-
13
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
14
- # Sprawd藕, czy dane wej艣ciowe zawieraj膮 wymagane pola
15
- if "image" not in data or "text" not in data:
16
- return {"error": "Payload must contain 'image' (base64 or URL) and 'text' (queries)."}
17
-
18
- # Za艂aduj obraz
19
- image = Image.open(data["image"]) if isinstance(data["image"], str) else data["image"]
20
-
21
- # Pobierz teksty zapyta艅
22
- text_queries = data["text"]
23
- if isinstance(text_queries, list):
24
- text_queries = ". ".join([t.lower().strip() + "." for t in text_queries])
25
-
26
- # Przygotuj dane wej艣ciowe
27
- inputs = self.processor(images=image, text=text_queries, return_tensors="pt").to(self.device)
28
-
29
- # Przeprowad藕 inferencj臋
30
- with torch.no_grad():
31
- outputs = self.model(**inputs)
32
-
33
- # Post-process detekcji
34
- results = self.processor.post_process_grounded_object_detection(
35
- outputs,
36
- inputs.input_ids,
37
- box_threshold=0.4,
38
- text_threshold=0.3,
39
- target_sizes=[image.size[::-1]]
40
- )
41
-
42
- # Przygotuj wynik
43
- return {
44
- "detections": results
45
- }
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
3
+ from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
+ import torch
7
+
8
+ class EndpointHandler():
9
+ def __init__(self, path=""):
10
+ # 艁adowanie modelu i procesora
11
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ self.model = AutoModelForZeroShotObjectDetection.from_pretrained(path).to(self.device)
13
+ self.processor = AutoProcessor.from_pretrained(path)
14
+
15
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
16
+ # Sprawdzamy, czy dane wej艣ciowe zawieraj膮 wymagane pola
17
+ if "inputs" not in data:
18
+ return {"error": "Payload must contain 'inputs' key with 'image' and 'text'."}
19
+
20
+ inputs = data["inputs"]
21
+ if "image" not in inputs or "text" not in inputs:
22
+ return {"error": "Payload must contain 'image' (base64 or URL) and 'text' (queries)."}
23
+
24
+ # Pobieramy obraz (URL lub Base64)
25
+ image_data = inputs["image"]
26
+ if image_data.startswith("http"): # URL
27
+ response = requests.get(image_data)
28
+ image = Image.open(BytesIO(response.content))
29
+ else:
30
+ return {"error": "Handler currently supports only URL-based images."}
31
+
32
+ # Pobieramy tekst zapyta艅
33
+ text_queries = inputs["text"]
34
+ if isinstance(text_queries, list):
35
+ text_queries = ". ".join([t.lower().strip() + "." for t in text_queries])
36
+
37
+ # Przygotowujemy dane wej艣ciowe
38
+ processed_inputs = self.processor(images=image, text=text_queries, return_tensors="pt").to(self.device)
39
+
40
+ # Przeprowadzamy inferencj臋
41
+ with torch.no_grad():
42
+ outputs = self.model(**processed_inputs)
43
+
44
+ # Post-process wynik贸w
45
+ results = self.processor.post_process_grounded_object_detection(
46
+ outputs,
47
+ processed_inputs.input_ids,
48
+ box_threshold=0.4,
49
+ text_threshold=0.3,
50
+ target_sizes=[image.size[::-1]]
51
+ )
52
+
53
+ # Zwracamy wyniki
54
+ return {"detections": results}