File size: 2,183 Bytes
a583978
 
 
 
cb05228
f3126f3
cb05228
a79f819
a583978
 
 
 
 
9d9b5e2
a7a4721
a583978
f0b2295
baa2ff5
 
 
 
 
 
2a79ef4
8576dce
2a79ef4
 
 
 
 
 
 
 
a583978
 
a7a4721
a583978
 
 
 
 
9d9b5e2
baa2ff5
a583978
e53bc58
baa2ff5
a583978
9d9b5e2
db4cf02
8576dce
2a79ef4
baa2ff5
 
 
 
 
 
 
a583978
baa2ff5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import  Dict, List, Any
from PIL import Image
import requests
import torch
import base64
import os
from io import BytesIO
from models.blip_decoder import blip_decoder
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

# Run on the first CUDA device when one is visible to torch, otherwise on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class EndpointHandler():
    def __init__(self, path=""):
        # load the optimized model
        self.model_path = os.path.join(path,'model_large_caption.pth') 
        self.model = blip_decoder(
            pretrained=self.model_path, 
            image_size=384, 
            vit='large',
            med_config=os.path.join(path, 'configs/med_config.json')
        )
        self.model.eval()
        self.model = self.model.to(device)
        
        image_size = 384
        self.transform = transforms.Compose([
            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
            ]) 
     


    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Args:
            data (:obj:):
                includes the input data and the parameters for the inference.
        Return:
            A :obj:`dict`:. The object returned should be a dict of one list like {"caption": ["A hugging face at the office"]} containing :
                - "caption": A string corresponding to the generated caption.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

       
        image = Image.open(BytesIO(inputs))
        image = self.transform(image).unsqueeze(0).to(device)   
        with torch.no_grad():
            caption = self.model.generate(
                image, 
                sample=parameters.get('sample',True),
                top_p=parameters.get('top_p',0.9), 
                max_length=parameters.get('max_length',20), 
                min_length=parameters.get('min_length',5)
            )
        # postprocess the prediction
        return {"caption": caption}