from typing import Dict, Any
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import requests
from io import BytesIO

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EndpointHandler:
    def __init__(self, path: str = "morthens/qwen2-vl-inference"):
        # Load the processor and model
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype="auto",
            device_map="auto"
        )
        # device_map="auto" already places the model on the available
        # device(s); calling .to(device) on an accelerate-dispatched model
        # is redundant and can raise an error, so no extra move is done here.

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Extract the input data
        image_url = data.get("image_url", "")
        text = data.get("text", "")

        # Load the image from the URL
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            # Convert to RGB so palette/RGBA images don't break preprocessing
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to fetch or process image: {str(e)}"}

        # Build a chat-style prompt so the processor inserts the image
        # placeholder tokens that Qwen2-VL expects; raw text without them
        # leaves the image features with nowhere to be merged into the prompt.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": text},
                ],
            }
        ]
        prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Preprocess the input
        inputs = self.processor(
            text=[prompt],
            images=[image],
            padding=True,
            return_tensors="pt"
        )

        # Move inputs to the correct device
        inputs = {key: value.to(device) for key, value in inputs.items()}

        # Perform inference
        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=128
        )

        # Strip the prompt tokens so only the newly generated text is decoded
        generated_ids = [
            output[len(prompt_ids):]
            for prompt_ids, output in zip(inputs["input_ids"], output_ids)
        ]

        # Decode the output
        output_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )[0]

        # Return the raw prediction
        return {"prediction": output_text}