File size: 2,375 Bytes
baebe6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56e1504
 
 
 
baebe6f
56e1504
baebe6f
 
 
 
 
 
 
 
 
 
56e1504
baebe6f
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Any, Dict, List

import requests
import torch

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image


class EndpointHandler:
    """Request handler that generates text from a prompt plus an image with PaliGemma.

    Each request instance supplies a text prompt and an image URL. The handler
    downloads the image, runs ``model.generate`` and returns only the newly
    generated text (the prompt tokens are stripped before decoding).
    """

    # Hub repository that both the weights and the processor are loaded from.
    MODEL_ID = "google/paligemma-3b-mix-448"
    # Seconds to wait on the image server (connect/read). Without a timeout,
    # requests can block forever on a stalled host.
    IMAGE_TIMEOUT_SECONDS = 30.0

    def __init__(
        self,
        model_dir: str = "/opt/huggingface/model",
        **kwargs: Any,
    ) -> None:
        """Load the bfloat16 PaliGemma checkpoint and its processor.

        Args:
            model_dir: Model directory handed in by the caller.
                NOTE(review): currently unused - weights are always loaded from
                ``MODEL_ID``; confirm this is intended.
            **kwargs: Accepted for caller compatibility and ignored.
        """
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            revision="bfloat16",
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map="auto",
        ).eval()

        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Run generation for every instance in the request.

        Args:
            data: Request payload. ``data["instances"]`` must be a list of
                dicts, each with ``text`` (or its alias ``prompt``) and
                ``image_url``. The optional ``data["generation_kwargs"]`` is
                forwarded to ``model.generate``. Its default is
                ``{"max_new_tokens": 100, "do_sample": False}``.

        Returns:
            ``{"predictions": [...]}`` with one decoded string per instance,
            in request order.

        Raises:
            ValueError: If ``instances`` is missing or not a list, an instance
                is malformed, or an image cannot be downloaded or decoded.
        """
        if "instances" not in data or not isinstance(data["instances"], (list, tuple)):
            raise ValueError(
                "The request body must contain a key `instances` with a list of instances."
            )

        # Same settings apply to every instance, so resolve them once.
        generation_kwargs = data.get(
            "generation_kwargs", {"max_new_tokens": 100, "do_sample": False}
        )

        predictions = []
        for instance in data["instances"]:
            text, image_url = self._parse_instance(instance)
            image = self._load_image(image_url)
            predictions.append(self._generate(text, image, generation_kwargs))
        return {"predictions": predictions}

    @staticmethod
    def _parse_instance(instance: Any):
        """Return ``(text, image_url)`` for one instance without mutating it.

        ``prompt`` is accepted as an alias for ``text`` and takes precedence
        when both keys are present.
        """
        if (
            not isinstance(instance, dict)
            or ("prompt" not in instance and "text" not in instance)
            or "image_url" not in instance
        ):
            raise ValueError(
                "The request body for each instance should contain both the `text` and the `image_url` key with a valid image URL."
            )
        text = instance["prompt"] if "prompt" in instance else instance["text"]
        return text, instance["image_url"]

    def _load_image(self, image_url: Any) -> "Image.Image":
        """Download ``image_url`` and return it as a fully loaded PIL image."""
        try:
            with requests.get(
                image_url, stream=True, timeout=self.IMAGE_TIMEOUT_SECONDS
            ) as response:
                response.raise_for_status()
                # ``raw`` skips requests' content decoding. Enable it so
                # gzip/deflate-encoded responses still yield image bytes.
                response.raw.decode_content = True
                image = Image.open(response.raw)
                # Image.open is lazy. Read the pixel data now, before the
                # ``with`` block closes the connection.
                image.load()
        except Exception as e:
            raise ValueError(
                f"The provided image URL ({image_url}) cannot be downloaded (with exception {e}), make sure it's public and accessible."
            ) from e
        return image

    def _generate(
        self, text: Any, image: "Image.Image", generation_kwargs: Dict[str, Any]
    ) -> str:
        """Generate a continuation for ``text`` + ``image`` and decode it."""
        inputs = self.processor(
            text=text, images=image, return_tensors="pt"
        ).to(self.model.device)
        # generate() returns prompt tokens followed by new tokens; remember the
        # prompt length so only the continuation is decoded.
        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(**inputs, **generation_kwargs)
            # Only the first returned sequence is decoded.
            return self.processor.decode(
                generation[0][input_len:], skip_special_tokens=True
            )