alvarobartt HF staff committed on
Commit
ff47bc8
·
verified ·
1 Parent(s): 739c23c

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +106 -0
handler.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModelForVision2Seq, GenerationConfig
3
+ from transformers.image_utils import load_image
4
+
5
+ from typing import Any, Dict
6
+
7
+ import base64
8
+ import re
9
+ from copy import deepcopy
10
+
11
+
12
def is_base64(s: str) -> bool:
    """Return True when *s* survives a base64 decode/encode round trip unchanged.

    The string is decoded and re-encoded; it counts as base64 only if the
    re-encoded text equals the input exactly. Any exception raised along the
    way (bad padding, wrong input type, ...) is treated as "not base64".
    """
    try:
        raw_bytes = base64.b64decode(s)
        round_tripped = base64.b64encode(raw_bytes).decode()
        return round_tripped == s
    except Exception:
        return False
17
+
18
+
19
def is_url(s: str) -> bool:
    """Return True when *s* starts with an ``http://`` or ``https://`` URL.

    The scheme must be followed by at least one host-like character (word
    character, ``-``, ``.``) or percent-encoded byte. Matching is anchored at
    the beginning of the string only, so trailing text is not inspected.
    """
    matched = re.match(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+", s)
    return matched is not None
22
+
23
+
24
class EndpointHandler:
    """Inference handler wrapping a vision-to-sequence model and its processor.

    The handler is constructed once with a model directory / repo id and is then
    called with a request payload of the form::

        {
            "inputs": [
                {
                    "text": "<prompt>",
                    "images": ["<url or base64 image>", ...],
                    "generation_parameters": {...},  # optional
                },
                ...
            ]
        }

    and returns ``{"predictions": [<generated text per input>, ...]}``.
    """

    def __init__(
        self,
        model_dir: str = "HuggingFaceTB/SmolVLM-Instruct",
        **kwargs: Any,  # type: ignore
    ) -> None:
        """Load the processor, the model and the default generation config.

        Args:
            model_dir: Path or repo id passed to every ``from_pretrained`` call.
            **kwargs: Accepted for interface compatibility; not used.
        """
        self.processor = AutoProcessor.from_pretrained(model_dir)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            _attn_implementation="flash_attention_2",
            device_map="auto",
        ).eval()
        self.generation_config = GenerationConfig.from_pretrained(model_dir)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Generate one text prediction per entry in ``data["inputs"]``.

        Args:
            data: Request body; must contain the key ``"inputs"`` holding a list
                of dictionaries with the keys ``"text"`` and ``"images"`` and,
                optionally, ``"generation_parameters"``.

        Returns:
            A dictionary ``{"predictions": [...]}`` with the first decoded
            sequence for each input, in request order.

        Raises:
            ValueError: If the payload is malformed or an image cannot be loaded.
        """
        if "inputs" not in data:
            raise ValueError(
                "The request body must contain a key 'inputs' with a list of inputs."
            )

        if not isinstance(data["inputs"], list):
            raise ValueError(
                "The request inputs must be a list of dictionaries with the keys 'text' and 'images', being a"
                " string with the prompt and a list with the image URLs or base64 encodings, respectively; and"
                " optionally including the key 'generation_parameters' key too."
            )

        predictions = []
        for item in data["inputs"]:
            # Reject non-dict items up front so the key checks below are real
            # key lookups rather than e.g. substring tests on a string.
            if not isinstance(item, dict) or "text" not in item:
                raise ValueError(
                    "The request input body must contain the key 'text' with the prompt to use."
                )

            # 'images' must be a list AND every element must be a string; the
            # check is done before any image is fetched or decoded.
            raw_images = item.get("images")
            if not isinstance(raw_images, list) or not all(
                isinstance(entry, str) for entry in raw_images
            ):
                raise ValueError(
                    "The request input body must contain the key 'images' with a list of strings,"
                    " where each string corresponds to an image on either base64 encoding, or provided"
                    " as a valid URL (needs to be publicly accessible and contain a valid image)."
                )

            images = []
            for image in raw_images:
                try:
                    images.append(load_image(image))
                except Exception as e:
                    # Chain the original failure so the root cause stays visible.
                    raise ValueError(
                        f"Provided {image=} is not valid, please make sure that's either a base64 encoding"
                        f" of a valid image, or a publicly accesible URL to a valid image.\nFailed with {e=}."
                    ) from e

            # Work on a copy so per-request overrides never leak into the
            # defaults shared by later requests.
            generation_config = deepcopy(self.generation_config)
            generation_config.update(**(item.get("generation_parameters") or {}))

            # One image placeholder per loaded image, followed by the prompt.
            messages = [
                {
                    "role": "user",
                    "content": [{"type": "image"} for _ in images]
                    + [{"type": "text", "text": item["text"]}],
                },
            ]
            prompt = self.processor.apply_chat_template(
                messages, add_generation_prompt=True
            )
            # Tensors must live on the same device as the model before generate().
            processed_inputs = self.processor(
                text=prompt, images=images, return_tensors="pt"
            ).to(self.model.device)

            # torch.autocast takes the device *type* string ("cuda", "cpu"),
            # not a torch.device object.
            device_type = self.model.device.type
            with torch.no_grad(), torch.autocast(device_type=device_type):
                # GenerationConfig is not a mapping, so it is passed through the
                # dedicated keyword instead of being **-unpacked.
                generated_ids = self.model.generate(
                    **processed_inputs, generation_config=generation_config
                )
            generated_texts = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            predictions.append(generated_texts[0])

        return {"predictions": predictions}